autodiff/dual.rs
1//! Dual numbers for forward-mode automatic differentiation.
2//!
3//! A dual number represents a value and its derivative simultaneously,
4//! enabling automatic computation of derivatives through operator
5//! overloading.
6//!
7//! # Mathematical Background
8//!
9//! A dual number has the form `a + a′·ε` where `ε² = 0` (and `a′` denotes
10//! the derivative with respect to the seeded input). Arithmetic operations
11//! on dual numbers follow these algebraic rules:
12//!
13//! - `(a + a′·ε) + (b + b′·ε) = (a+b) + (a′+b′)·ε`
14//! - `-(a + a′·ε) = -a + (-a′)·ε`
15//! - `(a + a′·ε) - (b + b′·ε) = (a-b) + (a′-b′)·ε`
16//! - `(a + a′·ε) * (b + b′·ε) = ab + (a′b + ab′)·ε`
17//! - `1/(b + b′·ε) = (1/b) + (-b′/b²)·ε`
18//! - `(a + a′·ε) / (b + b′·ε) = (a + a′·ε) * (1/(b + b′·ε))`
19//!
20//! The chain rule emerges implicitly from composing these operations—you
21//! never write it down explicitly.
22//!
23//! This is **forward-mode** automatic differentiation: we compute the
24//! derivative as we compute the function value.
25//!
26//! # Example
27//!
28//! ```
29//! use autodiff::Dual;
30//!
31//! // Compute f(x) = x² + 2x at x=3
32//! let x = Dual::variable(3.0); // x with derivative dx/dx = 1
33//!
34//! let f = x * x + Dual::constant(2.0) * x;
35//!
36//! assert_eq!(f.value, 15.0); // f(3) = 9 + 6 = 15
37//! assert_eq!(f.deriv, 8.0); // f'(3) = 2*3 + 2 = 8
38//! ```
39//!
40//! # Reusable Functions
41//!
42//! The idiomatic way to compute derivatives at multiple points is to
43//! write the function once and evaluate it with different seeds:
44//!
45//! ```
46//! use autodiff::Dual;
47//!
48//! // Define the function once
49//! fn f(x: Dual<f64>) -> Dual<f64> {
50//! x * x + Dual::constant(2.0) * x
51//! }
52//!
53//! // Helper to evaluate f and f' at a point
54//! fn eval_f_and_df(a: f64) -> (f64, f64) {
55//! let y = f(Dual::variable(a));
56//! (y.value, y.deriv)
57//! }
58//!
59//! // Evaluate at multiple points without rewriting the expression
60//! let (v1, dv1) = eval_f_and_df(3.0);
61//! assert_eq!(v1, 15.0); // f(3) = 15
62//! assert_eq!(dv1, 8.0); // f'(3) = 8
63//!
64//! let (v2, dv2) = eval_f_and_df(5.0);
65//! assert_eq!(v2, 35.0); // f(5) = 35
66//! assert_eq!(dv2, 12.0); // f'(5) = 12
67//! ```
68//!
69//! # Supported Operations
70//!
//! - **Arithmetic**: `+`, `-`, `*`, `/`, negation, and `recip` (reciprocal)
//! - **Elementary functions**: `exp`, `ln`, `sin`, `cos`, `sqrt`
73//! - Derivatives propagate automatically; chain rule emerges from composition
74//!
75//! # Use Cases
76//!
77//! - Computing derivatives of scalar functions
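//! - Gradient-based optimization over a small number of inputs (forward mode costs one evaluation per input direction; models with many parameters, such as neural networks, usually use reverse mode)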
79//! - Sensitivity analysis
80//! - Physics simulations requiring derivatives
81
82use num_traits::{Float, One, Zero};
83use std::ops::{Add, Div, Mul, Neg, Sub};
84
/// A value paired with its derivative: the dual number `value + deriv·ε`,
/// where `ε² = 0`.
///
/// Because `ε² = 0`, multiplying two dual numbers discards the second-order
/// term, so the `deriv` component of every result is the first derivative of
/// that result with respect to the input that was seeded with a nonzero
/// `deriv`. Seed the variable of interest with [`Dual::variable`], wrap every
/// other number with [`Dual::constant`], and combine them with ordinary
/// operators.
///
/// # Type Parameter
///
/// - `T`: the scalar type shared by both components (usually `f64` or `f32`)
///
/// # Examples
///
/// ## Basic Usage
///
/// ```
/// use autodiff::Dual;
///
/// let x = Dual::variable(5.0);
/// let y = x * x; // y = x²
///
/// assert_eq!(y.value, 25.0); // 5² = 25
/// assert_eq!(y.deriv, 10.0); // d/dx(x²) = 2x, which is 10 at x = 5
/// ```
///
/// ## Chain Rule
///
/// ```
/// use autodiff::Dual;
///
/// // f(x) = (x + 1) * (x + 2)
/// let x = Dual::variable(3.0);
/// let f = (x + Dual::constant(1.0)) * (x + Dual::constant(2.0));
///
/// assert_eq!(f.value, 20.0); // 4 * 5 = 20
/// assert_eq!(f.deriv, 9.0); // f'(x) = 2x + 3, which is 9 at x = 3
/// ```
///
/// ## Multiple Operations
///
/// ```
/// use autodiff::Dual;
///
/// // f(x) = x³ - 2x + 1
/// let x = Dual::variable(2.0);
/// let x2 = x * x;
/// let x3 = x2 * x;
/// let f = x3 - Dual::constant(2.0) * x + Dual::constant(1.0);
///
/// assert_eq!(f.value, 5.0); // 8 - 4 + 1 = 5
/// assert_eq!(f.deriv, 10.0); // f'(x) = 3x² - 2, which is 10 at x = 2
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dual<T> {
    /// Primal component: the value of the expression.
    pub value: T,
    /// Tangent component: the derivative of `value` with respect to the
    /// seeded input.
    pub deriv: T,
}
143
144impl<T> Dual<T> {
145 /// Create a new dual number with explicit value and derivative.
146 ///
147 /// # Example
148 ///
149 /// ```
150 /// use autodiff::Dual;
151 ///
152 /// let d = Dual::new(3.0, 1.0);
153 /// assert_eq!(d.value, 3.0);
154 /// assert_eq!(d.deriv, 1.0);
155 /// ```
156 pub fn new(value: T, deriv: T) -> Self {
157 Dual { value, deriv }
158 }
159
160 /// Create a constant (derivative = 0).
161 ///
162 /// Use this for literal values in your computation.
163 ///
164 /// # Example
165 ///
166 /// ```
167 /// use autodiff::Dual;
168 ///
169 /// let c = Dual::constant(5.0);
170 /// assert_eq!(c.value, 5.0);
171 /// assert_eq!(c.deriv, 0.0);
172 /// ```
173 pub fn constant(value: T) -> Self
174 where
175 T: Zero,
176 {
177 Dual {
178 value,
179 deriv: T::zero(),
180 }
181 }
182
183 /// Create a variable (derivative = 1).
184 ///
185 /// Use this for the input variable you're differentiating with
186 /// respect to.
187 ///
188 /// # Example
189 ///
190 /// ```
191 /// use autodiff::Dual;
192 ///
193 /// let x = Dual::variable(3.0);
194 /// assert_eq!(x.value, 3.0);
195 /// assert_eq!(x.deriv, 1.0); // dx/dx = 1
196 /// ```
197 pub fn variable(value: T) -> Self
198 where
199 T: One,
200 {
201 Dual {
202 value,
203 deriv: T::one(),
204 }
205 }
206
207 /// Reciprocal (multiplicative inverse).
208 ///
209 /// For `g = b + b′·ε` with `b ≠ 0`, computes:
210 ///
211 /// `g⁻¹ = (1/b) + (-b′/b²)·ε`
212 ///
213 /// This encodes the derivative rule: `(1/g)′ = -g′/g²`.
214 ///
215 /// # Example
216 ///
217 /// ```
218 /// use autodiff::Dual;
219 ///
220 /// // f(x) = 1/x at x=2
221 /// let x = Dual::variable(2.0);
222 /// let f = x.recip();
223 ///
224 /// assert_eq!(f.value, 0.5); // 1/2 = 0.5
225 /// assert_eq!(f.deriv, -0.25); // d/dx(1/x) at x=2 is -1/4
226 /// ```
227 pub fn recip(self) -> Self
228 where
229 T: One + Div<Output = T> + Mul<Output = T> + Neg<Output = T> + Clone,
230 {
231 let b = self.value.clone();
232 let b_squared = b.clone() * b.clone();
233
234 Dual {
235 value: T::one() / b.clone(),
236 deriv: -(self.deriv / b_squared),
237 }
238 }
239
240 /// Exponential function.
241 ///
242 /// For `f = a + a′·ε`, computes `e^f = e^a + (a′·e^a)·ε`.
243 ///
244 /// This encodes the derivative: `d/dx(e^f) = f′·e^f`.
245 ///
246 /// # Example
247 ///
248 /// ```
249 /// use autodiff::Dual;
250 ///
251 /// // f(x) = e^x at x=0
252 /// let x = Dual::variable(0.0);
253 /// let f = x.exp();
254 ///
255 /// assert_eq!(f.value, 1.0); // e^0 = 1
256 /// assert_eq!(f.deriv, 1.0); // d/dx(e^x) at x=0 is e^0 = 1
257 /// ```
258 pub fn exp(self) -> Self
259 where
260 T: Float,
261 {
262 let exp_val = self.value.exp();
263 Dual {
264 value: exp_val,
265 deriv: self.deriv * exp_val,
266 }
267 }
268
269 /// Natural logarithm.
270 ///
271 /// For `f = a + a′·ε`, computes `ln(f) = ln(a) + (a′/a)·ε`.
272 ///
273 /// This encodes the derivative: `d/dx(ln f) = f′/f`.
274 ///
275 /// # Example
276 ///
277 /// ```
278 /// use autodiff::Dual;
279 ///
280 /// // f(x) = ln(x) at x=1
281 /// let x = Dual::variable(1.0);
282 /// let f = x.ln();
283 ///
284 /// assert!((f.value - 0.0_f64).abs() < 1e-12); // ln(1) = 0
285 /// assert!((f.deriv - 1.0_f64).abs() < 1e-12); // d/dx(ln x) at x=1 is 1/1 = 1
286 /// ```
287 pub fn ln(self) -> Self
288 where
289 T: Float,
290 {
291 Dual {
292 value: self.value.ln(),
293 deriv: self.deriv / self.value,
294 }
295 }
296
297 /// Sine function.
298 ///
299 /// For `f = a + a′·ε`, computes `sin(f) = sin(a) + (a′·cos(a))·ε`.
300 ///
301 /// This encodes the derivative: `d/dx(sin f) = f′·cos f`.
302 ///
303 /// # Example
304 ///
305 /// ```
306 /// use autodiff::Dual;
307 ///
308 /// // f(x) = sin(x) at x=0
309 /// let x = Dual::variable(0.0);
310 /// let f = x.sin();
311 ///
312 /// assert!((f.value - 0.0_f64).abs() < 1e-12); // sin(0) = 0
313 /// assert!((f.deriv - 1.0_f64).abs() < 1e-12); // d/dx(sin x) at x=0 is cos(0) = 1
314 /// ```
315 pub fn sin(self) -> Self
316 where
317 T: Float,
318 {
319 Dual {
320 value: self.value.sin(),
321 deriv: self.deriv * self.value.cos(),
322 }
323 }
324
325 /// Cosine function.
326 ///
327 /// For `f = a + a′·ε`, computes `cos(f) = cos(a) + (-a′·sin(a))·ε`.
328 ///
329 /// This encodes the derivative: `d/dx(cos f) = -f′·sin f`.
330 ///
331 /// # Example
332 ///
333 /// ```
334 /// use autodiff::Dual;
335 ///
336 /// // f(x) = cos(x) at x=0
337 /// let x = Dual::variable(0.0);
338 /// let f = x.cos();
339 ///
340 /// assert!((f.value - 1.0_f64).abs() < 1e-12); // cos(0) = 1
341 /// assert!((f.deriv - 0.0_f64).abs() < 1e-12); // d/dx(cos x) at x=0 is -sin(0) = 0
342 /// ```
343 pub fn cos(self) -> Self
344 where
345 T: Float,
346 {
347 Dual {
348 value: self.value.cos(),
349 deriv: -self.deriv * self.value.sin(),
350 }
351 }
352
353 /// Square root.
354 ///
355 /// For `f = a + a′·ε`, computes `√f = √a + (a′/(2√a))·ε`.
356 ///
357 /// This encodes the derivative: `d/dx(√f) = f′/(2√f)`.
358 ///
359 /// # Example
360 ///
361 /// ```
362 /// use autodiff::Dual;
363 ///
364 /// // f(x) = √x at x=4
365 /// let x = Dual::variable(4.0);
366 /// let f = x.sqrt();
367 ///
368 /// assert_eq!(f.value, 2.0); // √4 = 2
369 /// assert_eq!(f.deriv, 0.25); // d/dx(√x) at x=4 is 1/(2*2) = 0.25
370 /// ```
371 pub fn sqrt(self) -> Self
372 where
373 T: Float,
374 {
375 let sqrt_val = self.value.sqrt();
376 Dual {
377 value: sqrt_val,
378 deriv: self.deriv / (sqrt_val + sqrt_val),
379 }
380 }
381}
382
383/// Addition: (a + a′·ε) + (b + b′·ε) = (a+b) + (a′+b′)·ε
384impl<T: Add<Output = T>> Add for Dual<T> {
385 type Output = Dual<T>;
386
387 fn add(self, rhs: Self) -> Self::Output {
388 Dual {
389 value: self.value + rhs.value,
390 deriv: self.deriv + rhs.deriv,
391 }
392 }
393}
394
395/// Subtraction: (a + a′·ε) - (b + b′·ε) = (a-b) + (a′-b′)·ε
396impl<T: Sub<Output = T>> Sub for Dual<T> {
397 type Output = Dual<T>;
398
399 fn sub(self, rhs: Self) -> Self::Output {
400 Dual {
401 value: self.value - rhs.value,
402 deriv: self.deriv - rhs.deriv,
403 }
404 }
405}
406
407/// Multiplication: (a + a′·ε) * (b + b′·ε) = ab + (a′b + ab′)·ε
408///
409/// This implements the product rule: d/dx(f·g) = f′·g + f·g′
410impl<T: Mul<Output = T> + Add<Output = T> + Clone> Mul for Dual<T> {
411 type Output = Dual<T>;
412
413 fn mul(self, rhs: Self) -> Self::Output {
414 Dual {
415 value: self.value.clone() * rhs.value.clone(),
416 // Product rule: f′·g + f·g′
417 deriv: self.deriv * rhs.value + self.value * rhs.deriv,
418 }
419 }
420}
421
422/// Division: `f / g = f * (1/g)`.
423///
424/// The quotient rule emerges automatically from the product rule
425/// (in `Mul`) composed with the reciprocal rule (in `recip`).
426#[allow(clippy::suspicious_arithmetic_impl)]
427impl<T> Div for Dual<T>
428where
429 T: One + Div<Output = T> + Mul<Output = T> + Add<Output = T> + Neg<Output = T> + Clone,
430{
431 type Output = Dual<T>;
432
433 fn div(self, rhs: Self) -> Self::Output {
434 self * rhs.recip()
435 }
436}
437
438/// Negation: -(a + a′·ε) = -a + (-a′)·ε
439impl<T: Neg<Output = T>> Neg for Dual<T> {
440 type Output = Dual<T>;
441
442 fn neg(self) -> Self::Output {
443 Dual {
444 value: -self.value,
445 deriv: -self.deriv,
446 }
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for checks whose expected value is itself the
    /// result of a rounded floating-point computation.
    const TOL: f64 = 1e-12;

    /// Assert that `actual` lies within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn constant_has_zero_derivative() {
        let c = Dual::constant(5.0);
        assert_eq!(c.value, 5.0);
        assert_eq!(c.deriv, 0.0);
    }

    #[test]
    fn variable_has_unit_derivative() {
        let x = Dual::variable(3.0);
        assert_eq!(x.value, 3.0);
        assert_eq!(x.deriv, 1.0);
    }

    #[test]
    fn addition_works() {
        let x = Dual::variable(3.0);
        let c = Dual::constant(5.0);
        let y = x + c;

        assert_eq!(y.value, 8.0);
        assert_eq!(y.deriv, 1.0); // d/dx(x + 5) = 1
    }

    #[test]
    fn multiplication_implements_product_rule() {
        // f(x) = x * x = x²
        let x = Dual::variable(3.0);
        let y = x * x;

        assert_eq!(y.value, 9.0);
        assert_eq!(y.deriv, 6.0); // d/dx(x²) at x=3 is 2*3 = 6
    }

    #[test]
    fn recip_implements_inverse_rule() {
        // f(x) = 1/x at x=2
        let x = Dual::variable(2.0);
        let y = x.recip();

        assert_eq!(y.value, 0.5);
        assert_eq!(y.deriv, -0.25); // d/dx(1/x) at x=2 is -1/4
    }

    #[test]
    fn division_via_recip() {
        // f(x) = 1/x at x=2, using division operator
        let x = Dual::variable(2.0);
        let one = Dual::constant(1.0);
        let y = one / x;

        assert_eq!(y.value, 0.5);
        assert_eq!(y.deriv, -0.25); // d/dx(1/x) at x=2 is -1/4
    }

    #[test]
    fn division_quotient_rule() {
        // f(x) = (x+1)/(x+2) at x=3
        // f(x) = 4/5 = 0.8
        // f'(x) = [(x+2) - (x+1)]/(x+2)² = 1/(x+2)² = 1/25 = 0.04
        let x = Dual::variable(3.0);
        let num = x + Dual::constant(1.0);
        let den = x + Dual::constant(2.0);
        let y = num / den;

        assert_eq!(y.value, 0.8);
        assert!((y.deriv - 0.04_f64).abs() < 1e-10); // floating point tolerance
    }

    #[test]
    fn chain_rule_example() {
        // f(x) = (x + 1) * (x + 2)
        let x = Dual::variable(3.0);
        let f = (x + Dual::constant(1.0)) * (x + Dual::constant(2.0));

        assert_eq!(f.value, 20.0); // (3+1)*(3+2) = 20
        assert_eq!(f.deriv, 9.0); // f'(x) = 2x+3, f'(3) = 9
    }

    #[test]
    fn polynomial_example() {
        // f(x) = x³ - 2x + 1 at x=2
        let x = Dual::variable(2.0);
        let x2 = x * x;
        let x3 = x2 * x;
        let f = x3 - Dual::constant(2.0) * x + Dual::constant(1.0);

        assert_eq!(f.value, 5.0); // 8 - 4 + 1 = 5
        assert_eq!(f.deriv, 10.0); // f'(x) = 3x² - 2, f'(2) = 10
    }

    #[test]
    fn negation_works() {
        let x = Dual::variable(3.0);
        let y = -x;

        assert_eq!(y.value, -3.0);
        assert_eq!(y.deriv, -1.0);
    }

    #[test]
    fn subtraction_works() {
        let x = Dual::variable(5.0);
        let c = Dual::constant(2.0);
        let y = x - c;

        assert_eq!(y.value, 3.0);
        assert_eq!(y.deriv, 1.0); // d/dx(x - 2) = 1
    }

    #[test]
    fn exp_works() {
        // f(x) = e^x at x=0
        let x = Dual::variable(0.0);
        let f = x.exp();

        assert_eq!(f.value, 1.0); // e^0 = 1
        assert_eq!(f.deriv, 1.0); // d/dx(e^x) at x=0 is e^0 = 1
    }

    #[test]
    fn ln_works() {
        // f(x) = ln(x) at x=1
        let x = Dual::variable(1.0);
        let f = x.ln();

        assert_eq!(f.value, 0.0); // ln(1) = 0
        assert_eq!(f.deriv, 1.0); // d/dx(ln x) at x=1 is 1/1 = 1
    }

    #[test]
    fn sin_works() {
        // f(x) = sin(x) at x=0
        let x = Dual::variable(0.0);
        let f = x.sin();

        assert_eq!(f.value, 0.0); // sin(0) = 0
        assert_eq!(f.deriv, 1.0); // d/dx(sin x) at x=0 is cos(0) = 1
    }

    #[test]
    fn cos_works() {
        // f(x) = cos(x) at x=0
        let x = Dual::variable(0.0);
        let f = x.cos();

        assert_eq!(f.value, 1.0); // cos(0) = 1
        assert_eq!(f.deriv, 0.0); // d/dx(cos x) at x=0 is -sin(0) = 0
    }

    #[test]
    fn sqrt_works() {
        // f(x) = √x at x=4
        let x = Dual::variable(4.0);
        let f = x.sqrt();

        assert_eq!(f.value, 2.0); // √4 = 2
        assert_eq!(f.deriv, 0.25); // d/dx(√x) at x=4 is 1/(2*2) = 0.25
    }

    #[test]
    fn chain_rule_with_transcendentals() {
        // f(x) = sin(2x) at x=0
        // f'(x) = 2*cos(2x), f'(0) = 2*1 = 2
        let x = Dual::variable(0.0);
        let two_x = Dual::constant(2.0) * x;
        let f = two_x.sin();

        assert_eq!(f.value, 0.0);
        assert_eq!(f.deriv, 2.0);
    }

    #[test]
    fn exp_of_polynomial() {
        // f(x) = e^(x²) at x=1
        // f'(x) = 2x * e^(x²), f'(1) = 2 * e
        let x = Dual::variable(1.0);
        let x_squared = x * x;
        let f = x_squared.exp();

        let e = 1.0_f64.exp();
        assert_eq!(f.value, e);
        assert!((f.deriv - 2.0 * e).abs() < 1e-10);
    }

    #[test]
    fn reusable_function_pattern() {
        // Define a function once
        fn f(x: Dual<f64>) -> Dual<f64> {
            x * x + Dual::constant(2.0) * x
        }

        // Helper to evaluate f and f' at a point
        fn eval_f_and_df(a: f64) -> (f64, f64) {
            let y = f(Dual::variable(a));
            (y.value, y.deriv)
        }

        // Evaluate at multiple points
        let (v1, dv1) = eval_f_and_df(3.0);
        assert_eq!(v1, 15.0); // f(3) = 9 + 6 = 15
        assert_eq!(dv1, 8.0); // f'(3) = 2*3 + 2 = 8

        let (v2, dv2) = eval_f_and_df(5.0);
        assert_eq!(v2, 35.0); // f(5) = 25 + 10 = 35
        assert_eq!(dv2, 12.0); // f'(5) = 2*5 + 2 = 12

        let (v3, dv3) = eval_f_and_df(0.0);
        assert_eq!(v3, 0.0); // f(0) = 0
        assert_eq!(dv3, 2.0); // f'(0) = 2
    }

    // The single-point tests above sit at x = 0 or x = 1, where several wrong
    // derivatives coincide with the right ones: -sin(0) = sin(0) = 0 hides a
    // dropped sign in `cos`, e^0 = 1 hides a missing factor in `exp`, and
    // 1/1 = 1 hides a missing division in `ln`. The tests below use generic
    // points so such mistakes are caught.

    #[test]
    fn sin_derivative_at_generic_point() {
        let x = Dual::variable(1.0);
        let f = x.sin();

        assert_close(f.value, 1.0_f64.sin());
        assert_close(f.deriv, 1.0_f64.cos());
    }

    #[test]
    fn cos_derivative_at_generic_point() {
        let x = Dual::variable(1.0);
        let f = x.cos();

        assert_close(f.value, 1.0_f64.cos());
        assert_close(f.deriv, -(1.0_f64.sin()));
    }

    #[test]
    fn exp_derivative_at_generic_point() {
        let x = Dual::variable(2.0);
        let f = x.exp();

        let e_squared = 2.0_f64.exp();
        assert_close(f.value, e_squared);
        assert_close(f.deriv, e_squared);
    }

    #[test]
    fn ln_derivative_at_generic_point() {
        let x = Dual::variable(4.0);
        let f = x.ln();

        assert_close(f.value, 4.0_f64.ln());
        assert_close(f.deriv, 0.25); // 1/x at x=4
    }

    #[test]
    fn sqrt_derivative_at_generic_point() {
        // d/dx √x = 1/(2√x); at x = 9 that is 1/6.
        let x = Dual::variable(9.0);
        let f = x.sqrt();

        assert_close(f.value, 3.0);
        assert_close(f.deriv, 1.0 / 6.0);
    }

    #[test]
    fn derivative_scales_with_seed() {
        // Seeding with deriv = 3 yields 3·f'(a): the tangent is linear in the seed.
        let x = Dual::new(2.0, 3.0);
        let f = x * x; // f'(2) = 4

        assert_eq!(f.value, 4.0);
        assert_eq!(f.deriv, 12.0);
    }

    #[test]
    fn division_with_both_operands_varying() {
        // f(x) = x / x² = 1/x, so f'(2) = -1/4.
        let x = Dual::variable(2.0);
        let f = x / (x * x);

        assert_close(f.value, 0.5);
        assert_close(f.deriv, -0.25);
    }

    #[test]
    fn composed_transcendentals() {
        // f(x) = ln(√x) = ½·ln x, so f'(x) = 1/(2x); at x = 4 that is 0.125.
        let x = Dual::variable(4.0);
        let f = x.sqrt().ln();

        assert_close(f.value, 0.5 * 4.0_f64.ln());
        assert_close(f.deriv, 0.125);
    }

    #[test]
    fn pythagorean_identity_has_zero_derivative() {
        // sin²x + cos²x = 1 for every x, so its derivative must vanish.
        let x = Dual::variable(0.7);
        let (s, c) = (x.sin(), x.cos());
        let f = s * s + c * c;

        assert_close(f.value, 1.0);
        assert_close(f.deriv, 0.0);
    }
}