autodiff/
dual.rs

1//! Dual numbers for forward-mode automatic differentiation.
2//!
3//! A dual number represents a value and its derivative simultaneously,
4//! enabling automatic computation of derivatives through operator
5//! overloading.
6//!
7//! # Mathematical Background
8//!
9//! A dual number has the form `a + a′·ε` where `ε² = 0` (and `a′` denotes
10//! the derivative with respect to the seeded input). Arithmetic operations
11//! on dual numbers follow these algebraic rules:
12//!
13//! - `(a + a′·ε) + (b + b′·ε) = (a+b) + (a′+b′)·ε`
14//! - `-(a + a′·ε) = -a + (-a′)·ε`
15//! - `(a + a′·ε) - (b + b′·ε) = (a-b) + (a′-b′)·ε`
16//! - `(a + a′·ε) * (b + b′·ε) = ab + (a′b + ab′)·ε`
17//! - `1/(b + b′·ε) = (1/b) + (-b′/b²)·ε`
18//! - `(a + a′·ε) / (b + b′·ε) = (a + a′·ε) * (1/(b + b′·ε))`
19//!
20//! The chain rule emerges implicitly from composing these operations—you
21//! never write it down explicitly.
22//!
23//! This is **forward-mode** automatic differentiation: we compute the
24//! derivative as we compute the function value.
25//!
26//! # Example
27//!
28//! ```
29//! use autodiff::Dual;
30//!
31//! // Compute f(x) = x² + 2x at x=3
32//! let x = Dual::variable(3.0);  // x with derivative dx/dx = 1
33//!
34//! let f = x * x + Dual::constant(2.0) * x;
35//!
36//! assert_eq!(f.value, 15.0);    // f(3) = 9 + 6 = 15
37//! assert_eq!(f.deriv, 8.0);     // f'(3) = 2*3 + 2 = 8
38//! ```
39//!
40//! # Reusable Functions
41//!
42//! The idiomatic way to compute derivatives at multiple points is to
43//! write the function once and evaluate it with different seeds:
44//!
45//! ```
46//! use autodiff::Dual;
47//!
48//! // Define the function once
49//! fn f(x: Dual<f64>) -> Dual<f64> {
50//!     x * x + Dual::constant(2.0) * x
51//! }
52//!
53//! // Helper to evaluate f and f' at a point
54//! fn eval_f_and_df(a: f64) -> (f64, f64) {
55//!     let y = f(Dual::variable(a));
56//!     (y.value, y.deriv)
57//! }
58//!
59//! // Evaluate at multiple points without rewriting the expression
60//! let (v1, dv1) = eval_f_and_df(3.0);
61//! assert_eq!(v1, 15.0);   // f(3) = 15
62//! assert_eq!(dv1, 8.0);   // f'(3) = 8
63//!
64//! let (v2, dv2) = eval_f_and_df(5.0);
65//! assert_eq!(v2, 35.0);   // f(5) = 35
66//! assert_eq!(dv2, 12.0);  // f'(5) = 12
67//! ```
68//!
69//! # Supported Operations
70//!
71//! - **Arithmetic**: `+`, `-`, `*`, `/`, negation
//! - **Elementary functions**: `recip`, `exp`, `ln`, `sin`, `cos`, `sqrt`
73//! - Derivatives propagate automatically; chain rule emerges from composition
74//!
75//! # Use Cases
76//!
77//! - Computing derivatives of scalar functions
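//! - Gradient-based optimization and neural networks with few inputs
//!   (forward mode yields the derivative for one seeded input per pass;
//!   reverse mode is usually preferred when there are many parameters)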
79//! - Sensitivity analysis
80//! - Physics simulations requiring derivatives
81
82use num_traits::{Float, One, Zero};
83use std::ops::{Add, Div, Mul, Neg, Sub};
84
/// A number of the form `value + deriv·ε`, where `ε² = 0`.
///
/// Carrying the derivative next to the value lets every arithmetic
/// operation update both components at once, so derivatives flow through
/// any expression built out of `Dual`s without extra bookkeeping.
///
/// # Type Parameter
///
/// - `T`: the scalar type of both components (usually `f64` or `f32`)
///
/// # Examples
///
/// ## Squaring a variable
///
/// ```
/// use autodiff::Dual;
///
/// let x = Dual::variable(5.0);
/// let y = x * x; // y = x²
///
/// assert_eq!(y.value, 25.0); // 5² = 25
/// assert_eq!(y.deriv, 10.0); // (x²)′ = 2x = 10 at x = 5
/// ```
///
/// ## Chain rule through a product
///
/// ```
/// use autodiff::Dual;
///
/// // f(x) = (x + 1) * (x + 2)
/// let x = Dual::variable(3.0);
/// let f = (x + Dual::constant(1.0)) * (x + Dual::constant(2.0));
///
/// assert_eq!(f.value, 20.0); // 4 * 5
/// assert_eq!(f.deriv, 9.0);  // f′(x) = 2x + 3 = 9 at x = 3
/// ```
///
/// ## A cubic polynomial
///
/// ```
/// use autodiff::Dual;
///
/// // f(x) = x³ - 2x + 1
/// let x = Dual::variable(2.0);
/// let f = x * x * x - Dual::constant(2.0) * x + Dual::constant(1.0);
///
/// assert_eq!(f.value, 5.0);  // 8 - 4 + 1
/// assert_eq!(f.deriv, 10.0); // f′(x) = 3x² - 2 = 10 at x = 2
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Dual<T> {
    /// Primal part `a` of `a + a′·ε`.
    pub value: T,
    /// Tangent part `a′`: the derivative with respect to the seeded input.
    pub deriv: T,
}
143
144impl<T> Dual<T> {
145    /// Create a new dual number with explicit value and derivative.
146    ///
147    /// # Example
148    ///
149    /// ```
150    /// use autodiff::Dual;
151    ///
152    /// let d = Dual::new(3.0, 1.0);
153    /// assert_eq!(d.value, 3.0);
154    /// assert_eq!(d.deriv, 1.0);
155    /// ```
156    pub fn new(value: T, deriv: T) -> Self {
157        Dual { value, deriv }
158    }
159
160    /// Create a constant (derivative = 0).
161    ///
162    /// Use this for literal values in your computation.
163    ///
164    /// # Example
165    ///
166    /// ```
167    /// use autodiff::Dual;
168    ///
169    /// let c = Dual::constant(5.0);
170    /// assert_eq!(c.value, 5.0);
171    /// assert_eq!(c.deriv, 0.0);
172    /// ```
173    pub fn constant(value: T) -> Self
174    where
175        T: Zero,
176    {
177        Dual {
178            value,
179            deriv: T::zero(),
180        }
181    }
182
183    /// Create a variable (derivative = 1).
184    ///
185    /// Use this for the input variable you're differentiating with
186    /// respect to.
187    ///
188    /// # Example
189    ///
190    /// ```
191    /// use autodiff::Dual;
192    ///
193    /// let x = Dual::variable(3.0);
194    /// assert_eq!(x.value, 3.0);
195    /// assert_eq!(x.deriv, 1.0);  // dx/dx = 1
196    /// ```
197    pub fn variable(value: T) -> Self
198    where
199        T: One,
200    {
201        Dual {
202            value,
203            deriv: T::one(),
204        }
205    }
206
207    /// Reciprocal (multiplicative inverse).
208    ///
209    /// For `g = b + b′·ε` with `b ≠ 0`, computes:
210    ///
211    /// `g⁻¹ = (1/b) + (-b′/b²)·ε`
212    ///
213    /// This encodes the derivative rule: `(1/g)′ = -g′/g²`.
214    ///
215    /// # Example
216    ///
217    /// ```
218    /// use autodiff::Dual;
219    ///
220    /// // f(x) = 1/x at x=2
221    /// let x = Dual::variable(2.0);
222    /// let f = x.recip();
223    ///
224    /// assert_eq!(f.value, 0.5);      // 1/2 = 0.5
225    /// assert_eq!(f.deriv, -0.25);    // d/dx(1/x) at x=2 is -1/4
226    /// ```
227    pub fn recip(self) -> Self
228    where
229        T: One + Div<Output = T> + Mul<Output = T> + Neg<Output = T> + Clone,
230    {
231        let b = self.value.clone();
232        let b_squared = b.clone() * b.clone();
233
234        Dual {
235            value: T::one() / b.clone(),
236            deriv: -(self.deriv / b_squared),
237        }
238    }
239
240    /// Exponential function.
241    ///
242    /// For `f = a + a′·ε`, computes `e^f = e^a + (a′·e^a)·ε`.
243    ///
244    /// This encodes the derivative: `d/dx(e^f) = f′·e^f`.
245    ///
246    /// # Example
247    ///
248    /// ```
249    /// use autodiff::Dual;
250    ///
251    /// // f(x) = e^x at x=0
252    /// let x = Dual::variable(0.0);
253    /// let f = x.exp();
254    ///
255    /// assert_eq!(f.value, 1.0);      // e^0 = 1
256    /// assert_eq!(f.deriv, 1.0);      // d/dx(e^x) at x=0 is e^0 = 1
257    /// ```
258    pub fn exp(self) -> Self
259    where
260        T: Float,
261    {
262        let exp_val = self.value.exp();
263        Dual {
264            value: exp_val,
265            deriv: self.deriv * exp_val,
266        }
267    }
268
269    /// Natural logarithm.
270    ///
271    /// For `f = a + a′·ε`, computes `ln(f) = ln(a) + (a′/a)·ε`.
272    ///
273    /// This encodes the derivative: `d/dx(ln f) = f′/f`.
274    ///
275    /// # Example
276    ///
277    /// ```
278    /// use autodiff::Dual;
279    ///
280    /// // f(x) = ln(x) at x=1
281    /// let x = Dual::variable(1.0);
282    /// let f = x.ln();
283    ///
284    /// assert!((f.value - 0.0_f64).abs() < 1e-12);   // ln(1) = 0
285    /// assert!((f.deriv - 1.0_f64).abs() < 1e-12);   // d/dx(ln x) at x=1 is 1/1 = 1
286    /// ```
287    pub fn ln(self) -> Self
288    where
289        T: Float,
290    {
291        Dual {
292            value: self.value.ln(),
293            deriv: self.deriv / self.value,
294        }
295    }
296
297    /// Sine function.
298    ///
299    /// For `f = a + a′·ε`, computes `sin(f) = sin(a) + (a′·cos(a))·ε`.
300    ///
301    /// This encodes the derivative: `d/dx(sin f) = f′·cos f`.
302    ///
303    /// # Example
304    ///
305    /// ```
306    /// use autodiff::Dual;
307    ///
308    /// // f(x) = sin(x) at x=0
309    /// let x = Dual::variable(0.0);
310    /// let f = x.sin();
311    ///
312    /// assert!((f.value - 0.0_f64).abs() < 1e-12);   // sin(0) = 0
313    /// assert!((f.deriv - 1.0_f64).abs() < 1e-12);   // d/dx(sin x) at x=0 is cos(0) = 1
314    /// ```
315    pub fn sin(self) -> Self
316    where
317        T: Float,
318    {
319        Dual {
320            value: self.value.sin(),
321            deriv: self.deriv * self.value.cos(),
322        }
323    }
324
325    /// Cosine function.
326    ///
327    /// For `f = a + a′·ε`, computes `cos(f) = cos(a) + (-a′·sin(a))·ε`.
328    ///
329    /// This encodes the derivative: `d/dx(cos f) = -f′·sin f`.
330    ///
331    /// # Example
332    ///
333    /// ```
334    /// use autodiff::Dual;
335    ///
336    /// // f(x) = cos(x) at x=0
337    /// let x = Dual::variable(0.0);
338    /// let f = x.cos();
339    ///
340    /// assert!((f.value - 1.0_f64).abs() < 1e-12);   // cos(0) = 1
341    /// assert!((f.deriv - 0.0_f64).abs() < 1e-12);   // d/dx(cos x) at x=0 is -sin(0) = 0
342    /// ```
343    pub fn cos(self) -> Self
344    where
345        T: Float,
346    {
347        Dual {
348            value: self.value.cos(),
349            deriv: -self.deriv * self.value.sin(),
350        }
351    }
352
353    /// Square root.
354    ///
355    /// For `f = a + a′·ε`, computes `√f = √a + (a′/(2√a))·ε`.
356    ///
357    /// This encodes the derivative: `d/dx(√f) = f′/(2√f)`.
358    ///
359    /// # Example
360    ///
361    /// ```
362    /// use autodiff::Dual;
363    ///
364    /// // f(x) = √x at x=4
365    /// let x = Dual::variable(4.0);
366    /// let f = x.sqrt();
367    ///
368    /// assert_eq!(f.value, 2.0);      // √4 = 2
369    /// assert_eq!(f.deriv, 0.25);     // d/dx(√x) at x=4 is 1/(2*2) = 0.25
370    /// ```
371    pub fn sqrt(self) -> Self
372    where
373        T: Float,
374    {
375        let sqrt_val = self.value.sqrt();
376        Dual {
377            value: sqrt_val,
378            deriv: self.deriv / (sqrt_val + sqrt_val),
379        }
380    }
381}
382
383/// Addition: (a + a′·ε) + (b + b′·ε) = (a+b) + (a′+b′)·ε
384impl<T: Add<Output = T>> Add for Dual<T> {
385    type Output = Dual<T>;
386
387    fn add(self, rhs: Self) -> Self::Output {
388        Dual {
389            value: self.value + rhs.value,
390            deriv: self.deriv + rhs.deriv,
391        }
392    }
393}
394
395/// Subtraction: (a + a′·ε) - (b + b′·ε) = (a-b) + (a′-b′)·ε
396impl<T: Sub<Output = T>> Sub for Dual<T> {
397    type Output = Dual<T>;
398
399    fn sub(self, rhs: Self) -> Self::Output {
400        Dual {
401            value: self.value - rhs.value,
402            deriv: self.deriv - rhs.deriv,
403        }
404    }
405}
406
407/// Multiplication: (a + a′·ε) * (b + b′·ε) = ab + (a′b + ab′)·ε
408///
409/// This implements the product rule: d/dx(f·g) = f′·g + f·g′
410impl<T: Mul<Output = T> + Add<Output = T> + Clone> Mul for Dual<T> {
411    type Output = Dual<T>;
412
413    fn mul(self, rhs: Self) -> Self::Output {
414        Dual {
415            value: self.value.clone() * rhs.value.clone(),
416            // Product rule: f′·g + f·g′
417            deriv: self.deriv * rhs.value + self.value * rhs.deriv,
418        }
419    }
420}
421
422/// Division: `f / g = f * (1/g)`.
423///
424/// The quotient rule emerges automatically from the product rule
425/// (in `Mul`) composed with the reciprocal rule (in `recip`).
426#[allow(clippy::suspicious_arithmetic_impl)]
427impl<T> Div for Dual<T>
428where
429    T: One + Div<Output = T> + Mul<Output = T> + Add<Output = T> + Neg<Output = T> + Clone,
430{
431    type Output = Dual<T>;
432
433    fn div(self, rhs: Self) -> Self::Output {
434        self * rhs.recip()
435    }
436}
437
438/// Negation: -(a + a′·ε) = -a + (-a′)·ε
439impl<T: Neg<Output = T>> Neg for Dual<T> {
440    type Output = Dual<T>;
441
442    fn neg(self) -> Self::Output {
443        Dual {
444            value: -self.value,
445            deriv: -self.deriv,
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for results that pass through inexact floating-point steps.
    const EPS: f64 = 1e-10;

    /// Assert that `d` carries exactly the given value and derivative.
    fn assert_exact(d: Dual<f64>, value: f64, deriv: f64) {
        assert_eq!(d.value, value, "value mismatch");
        assert_eq!(d.deriv, deriv, "derivative mismatch");
    }

    #[test]
    fn constant_has_zero_derivative() {
        assert_exact(Dual::constant(5.0), 5.0, 0.0);
    }

    #[test]
    fn variable_has_unit_derivative() {
        assert_exact(Dual::variable(3.0), 3.0, 1.0);
    }

    #[test]
    fn addition_works() {
        // d/dx(x + 5) = 1
        assert_exact(Dual::variable(3.0_f64) + Dual::constant(5.0), 8.0, 1.0);
    }

    #[test]
    fn multiplication_implements_product_rule() {
        // d/dx(x²) at x = 3 is 2·3 = 6
        let x = Dual::variable(3.0_f64);
        assert_exact(x * x, 9.0, 6.0);
    }

    #[test]
    fn recip_implements_inverse_rule() {
        // d/dx(1/x) at x = 2 is -1/4
        assert_exact(Dual::variable(2.0_f64).recip(), 0.5, -0.25);
    }

    #[test]
    fn division_via_recip() {
        // Same function as above, but through the `/` operator.
        assert_exact(Dual::constant(1.0_f64) / Dual::variable(2.0), 0.5, -0.25);
    }

    #[test]
    fn division_quotient_rule() {
        // f(x) = (x+1)/(x+2) at x = 3: f = 4/5 = 0.8, f′ = 1/(x+2)² = 1/25 = 0.04
        let x = Dual::variable(3.0_f64);
        let y = (x + Dual::constant(1.0)) / (x + Dual::constant(2.0));

        assert_eq!(y.value, 0.8);
        assert!((y.deriv - 0.04_f64).abs() < EPS);
    }

    #[test]
    fn chain_rule_example() {
        // f(x) = (x + 1)(x + 2), f′(x) = 2x + 3, evaluated at x = 3
        let x = Dual::variable(3.0_f64);
        let f = (x + Dual::constant(1.0)) * (x + Dual::constant(2.0));
        assert_exact(f, 20.0, 9.0);
    }

    #[test]
    fn polynomial_example() {
        // f(x) = x³ - 2x + 1, f′(x) = 3x² - 2, evaluated at x = 2
        let x = Dual::variable(2.0_f64);
        let cube = x * x * x;
        let f = cube - Dual::constant(2.0) * x + Dual::constant(1.0);
        assert_exact(f, 5.0, 10.0);
    }

    #[test]
    fn negation_works() {
        assert_exact(-Dual::variable(3.0_f64), -3.0, -1.0);
    }

    #[test]
    fn subtraction_works() {
        // d/dx(x - 2) = 1
        assert_exact(Dual::variable(5.0_f64) - Dual::constant(2.0), 3.0, 1.0);
    }

    #[test]
    fn exp_works() {
        // e^0 = 1 and (e^x)′ = e^x = 1 at x = 0
        assert_exact(Dual::variable(0.0_f64).exp(), 1.0, 1.0);
    }

    #[test]
    fn ln_works() {
        // ln(1) = 0 and (ln x)′ = 1/x = 1 at x = 1
        assert_exact(Dual::variable(1.0_f64).ln(), 0.0, 1.0);
    }

    #[test]
    fn sin_works() {
        // sin(0) = 0 and (sin x)′ = cos(0) = 1
        assert_exact(Dual::variable(0.0_f64).sin(), 0.0, 1.0);
    }

    #[test]
    fn cos_works() {
        // cos(0) = 1 and (cos x)′ = -sin(0) = 0
        assert_exact(Dual::variable(0.0_f64).cos(), 1.0, 0.0);
    }

    #[test]
    fn sqrt_works() {
        // √4 = 2 and (√x)′ = 1/(2·2) = 0.25 at x = 4
        assert_exact(Dual::variable(4.0_f64).sqrt(), 2.0, 0.25);
    }

    #[test]
    fn chain_rule_with_transcendentals() {
        // f(x) = sin(2x), f′(x) = 2·cos(2x), so f′(0) = 2
        let x = Dual::variable(0.0_f64);
        assert_exact((Dual::constant(2.0) * x).sin(), 0.0, 2.0);
    }

    #[test]
    fn exp_of_polynomial() {
        // f(x) = e^(x²), f′(x) = 2x·e^(x²); at x = 1 that is (e, 2e)
        let x = Dual::variable(1.0_f64);
        let f = (x * x).exp();
        let e = 1.0_f64.exp();

        assert_eq!(f.value, e);
        assert!((f.deriv - 2.0 * e).abs() < EPS);
    }

    #[test]
    fn reusable_function_pattern() {
        // Write the function once, then seed it at several points.
        fn f(x: Dual<f64>) -> Dual<f64> {
            x * x + Dual::constant(2.0) * x
        }

        // (input, f(input), f′(input)) for f(x) = x² + 2x, f′(x) = 2x + 2
        let cases = [(3.0, 15.0, 8.0), (5.0, 35.0, 12.0), (0.0, 0.0, 2.0)];
        for &(a, value, deriv) in cases.iter() {
            assert_exact(f(Dual::variable(a)), value, deriv);
        }
    }
}