// autodiff/multidual.rs

1//! Multi-component dual numbers for multivariable automatic
2//! differentiation.
3//!
4//! A multi-component dual number tracks a value and multiple partial
5//! derivatives simultaneously, enabling computation of gradients in a
6//! **single forward pass**.
7//!
8//! # Mathematical Background
9//!
10//! For a function f: ℝⁿ → ℝ, `MultiDual<T, N>` represents a value and
11//! its gradient ∇f = [∂f/∂x₁, ∂f/∂x₂, ..., ∂f/∂xₙ] simultaneously.
12//!
13//! Arithmetic operations extend naturally from single-variable dual
14//! numbers:
15//!
16//! - `(a + ∇a) + (b + ∇b) = (a+b) + (∇a+∇b)`
17//! - `-(a + ∇a) = -a + (-∇a)`
18//! - `(a + ∇a) - (b + ∇b) = (a-b) + (∇a-∇b)`
19//! - `(a + ∇a) * (b + ∇b) = ab + (b∇a + a∇b)`
20//! - `1/(b + ∇b) = (1/b) + (-∇b/b²)`
21//!
22//! Each operation updates all N derivative components at once,
23//! computing the full gradient in a single pass through the
24//! computation.
25//!
26//! # Example
27//!
28//! ```
29//! use autodiff::{MultiDual, gradient};
30//!
31//! // Compute ∇f for f(x, y) = x² + 2xy + y² at (3, 4)
32//! let f = |vars: [MultiDual<f64, 2>; 2]| {
33//!     let [x, y] = vars;
34//!     let two = MultiDual::constant(2.0);
35//!     x * x + two * x * y + y * y
36//! };
37//!
38//! let point = [3.0, 4.0];
39//! let (value, grad) = gradient(f, point);
40//!
41//! assert_eq!(value, 49.0);   // f(3, 4) = 9 + 24 + 16 = 49
42//! assert_eq!(grad[0], 14.0); // ∂f/∂x = 2x + 2y = 14
43//! assert_eq!(grad[1], 14.0); // ∂f/∂y = 2x + 2y = 14
44//! ```
45//!
46//! # Efficiency
47//!
48//! Computing the gradient requires **1 forward pass** with
49//! `MultiDual<T, N>`, compared to n passes if using `Dual<T>` to
50//! compute each partial derivative separately. For n=10 inputs, this
51//! is a 10x speedup.
52//!
53//! # Use Cases
54//!
55
56//! - Gradient-based optimization (gradient descent, Newton's method)
57//! - Neural network backpropagation alternatives
58//! - Sensitivity analysis with multiple parameters
59//! - Scientific computing with multivariable functions
60
61use num_traits::{Float, One, Zero};
62use std::ops::{Add, Div, Mul, Neg, Sub};
63
/// A multi-component dual number representing a value and N partial
/// derivatives.
///
/// `MultiDual<T, N>` represents a value along with its gradient
/// [∂f/∂x₁, ..., ∂f/∂xₙ] for forward-mode automatic differentiation
/// of multivariable functions.
///
/// # Type Parameters
///
/// - `T`: The numeric type (typically `f64` or `f32`)
/// - `N`: The number of input variables (compile-time constant)
///
/// Because `Clone`/`Copy` are derived, `MultiDual` is `Copy` whenever
/// `T` is `Copy`, so values can be reused freely inside arithmetic
/// expressions without explicit cloning.
///
/// # Examples
///
/// ## Creating Variables
///
/// ```
/// use autodiff::MultiDual;
///
/// // Create the first variable x with value 3.0 (∂/∂x = 1, ∂/∂y = 0)
/// let x = MultiDual::<f64, 2>::variable(3.0, 0);
/// assert_eq!(x.value, 3.0);
/// assert_eq!(x.derivs, [1.0, 0.0]);
///
/// // Create the second variable y with value 4.0 (∂/∂x = 0, ∂/∂y = 1)
/// let y = MultiDual::<f64, 2>::variable(4.0, 1);
/// assert_eq!(y.value, 4.0);
/// assert_eq!(y.derivs, [0.0, 1.0]);
///
/// // Create a constant (∂/∂x = 0, ∂/∂y = 0)
/// let c = MultiDual::<f64, 2>::constant(2.0);
/// assert_eq!(c.value, 2.0);
/// assert_eq!(c.derivs, [0.0, 0.0]);
/// ```
///
/// ## Computing Gradients
///
/// ```
/// use autodiff::MultiDual;
///
/// // f(x, y) = x² + y²
/// let x = MultiDual::<f64, 2>::variable(3.0, 0);
/// let y = MultiDual::<f64, 2>::variable(4.0, 1);
///
/// let f = x * x + y * y;
///
/// assert_eq!(f.value, 25.0);      // 3² + 4² = 25
/// assert_eq!(f.derivs[0], 6.0);   // ∂f/∂x = 2x = 6
/// assert_eq!(f.derivs[1], 8.0);   // ∂f/∂y = 2y = 8
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiDual<T, const N: usize> {
    /// The primal value (function output)
    pub value: T,
    /// The partial derivatives [∂f/∂x₁, ∂f/∂x₂, ..., ∂f/∂xₙ]
    pub derivs: [T; N],
}
121
122impl<T, const N: usize> MultiDual<T, N>
123where
124    T: Copy,
125{
126    /// Create a dual number with explicit value and derivatives.
127    ///
128    /// # Examples
129    ///
130    /// ```
131    /// use autodiff::MultiDual;
132    ///
133    /// let d = MultiDual::new(5.0, [1.0, 2.0, 3.0]);
134    /// assert_eq!(d.value, 5.0);
135    /// assert_eq!(d.derivs, [1.0, 2.0, 3.0]);
136    /// ```
137    pub fn new(value: T, derivs: [T; N]) -> Self {
138        Self { value, derivs }
139    }
140
141    /// Create a constant (all partial derivatives are zero).
142    ///
143    /// # Examples
144    ///
145    /// ```
146    /// use autodiff::MultiDual;
147    ///
148    /// let c = MultiDual::<f64, 3>::constant(42.0);
149    /// assert_eq!(c.value, 42.0);
150    /// assert_eq!(c.derivs, [0.0, 0.0, 0.0]);
151    /// ```
152    pub fn constant(value: T) -> Self
153    where
154        T: Zero,
155    {
156        Self {
157            value,
158            derivs: [T::zero(); N],
159        }
160    }
161
162    /// Create the i-th input variable.
163    ///
164    /// Sets `derivs[index] = 1` and all other derivatives to zero,
165    /// representing ∂xᵢ/∂xⱼ = δᵢⱼ (Kronecker delta).
166    ///
167    /// # Panics
168    ///
169    /// Panics if `index >= N`.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use autodiff::MultiDual;
175    ///
176    /// // Create x (first variable)
177    /// let x = MultiDual::<f64, 2>::variable(3.0, 0);
178    /// assert_eq!(x.value, 3.0);
179    /// assert_eq!(x.derivs, [1.0, 0.0]);
180    ///
181    /// // Create y (second variable)
182    /// let y = MultiDual::<f64, 2>::variable(4.0, 1);
183    /// assert_eq!(y.value, 4.0);
184    /// assert_eq!(y.derivs, [0.0, 1.0]);
185    /// ```
186    pub fn variable(value: T, index: usize) -> Self
187    where
188        T: Zero + One,
189    {
190        assert!(
191            index < N,
192            "Variable index {} out of bounds for N={}",
193            index,
194            N
195        );
196        let mut derivs = [T::zero(); N];
197        derivs[index] = T::one();
198        Self { value, derivs }
199    }
200}
201
202/// Addition: `(a + ∇a) + (b + ∇b) = (a+b) + (∇a+∇b)`
203///
204/// # Examples
205///
206/// ```
207/// use autodiff::MultiDual;
208///
209/// let x = MultiDual::new(3.0, [1.0, 0.0]);
210/// let y = MultiDual::new(4.0, [0.0, 1.0]);
211/// let sum = x + y;
212///
213/// assert_eq!(sum.value, 7.0);
214/// assert_eq!(sum.derivs, [1.0, 1.0]);
215/// ```
216impl<T, const N: usize> Add for MultiDual<T, N>
217where
218    T: Add<Output = T> + Copy,
219{
220    type Output = Self;
221
222    fn add(self, rhs: Self) -> Self {
223        let mut derivs = self.derivs;
224        for (deriv, rhs_deriv) in derivs.iter_mut().zip(rhs.derivs.iter()) {
225            *deriv = *deriv + *rhs_deriv;
226        }
227        Self {
228            value: self.value + rhs.value,
229            derivs,
230        }
231    }
232}
233
234/// Subtraction: `(a + ∇a) - (b + ∇b) = (a-b) + (∇a-∇b)`
235///
236/// # Examples
237///
238/// ```
239/// use autodiff::MultiDual;
240///
241/// let x = MultiDual::new(7.0, [2.0, 3.0]);
242/// let y = MultiDual::new(4.0, [1.0, 2.0]);
243/// let diff = x - y;
244///
245/// assert_eq!(diff.value, 3.0);
246/// assert_eq!(diff.derivs, [1.0, 1.0]);
247/// ```
248impl<T, const N: usize> Sub for MultiDual<T, N>
249where
250    T: Sub<Output = T> + Copy,
251{
252    type Output = Self;
253
254    fn sub(self, rhs: Self) -> Self {
255        let mut derivs = self.derivs;
256        for (deriv, rhs_deriv) in derivs.iter_mut().zip(rhs.derivs.iter()) {
257            *deriv = *deriv - *rhs_deriv;
258        }
259        Self {
260            value: self.value - rhs.value,
261            derivs,
262        }
263    }
264}
265
266/// Multiplication: `(a + ∇a) * (b + ∇b) = ab + (b∇a + a∇b)`
267///
268/// This implements the product rule automatically.
269///
270/// # Examples
271///
272/// ```
273/// use autodiff::MultiDual;
274///
275/// // f(x, y) = x * y at (3, 4)
276/// let x = MultiDual::<f64, 2>::variable(3.0, 0);
277/// let y = MultiDual::<f64, 2>::variable(4.0, 1);
278/// let product = x * y;
279///
280/// assert_eq!(product.value, 12.0);      // 3 * 4
281/// assert_eq!(product.derivs[0], 4.0);   // ∂(xy)/∂x = y = 4
282/// assert_eq!(product.derivs[1], 3.0);   // ∂(xy)/∂y = x = 3
283/// ```
284impl<T, const N: usize> Mul for MultiDual<T, N>
285where
286    T: Mul<Output = T> + Add<Output = T> + Copy,
287{
288    type Output = Self;
289
290    fn mul(self, rhs: Self) -> Self {
291        let mut derivs = [self.value; N];
292        for (deriv, (self_deriv, rhs_deriv)) in derivs
293            .iter_mut()
294            .zip(self.derivs.iter().zip(rhs.derivs.iter()))
295        {
296            // Product rule: (f*g)' = f'*g + f*g'
297            *deriv = *self_deriv * rhs.value + self.value * *rhs_deriv;
298        }
299        Self {
300            value: self.value * rhs.value,
301            derivs,
302        }
303    }
304}
305
306/// Negation: `-(a + ∇a) = -a + (-∇a)`
307///
308/// # Examples
309///
310/// ```
311/// use autodiff::MultiDual;
312///
313/// let x = MultiDual::new(3.0, [1.0, 2.0]);
314/// let neg_x = -x;
315///
316/// assert_eq!(neg_x.value, -3.0);
317/// assert_eq!(neg_x.derivs, [-1.0, -2.0]);
318/// ```
319impl<T, const N: usize> Neg for MultiDual<T, N>
320where
321    T: Neg<Output = T> + Copy,
322{
323    type Output = Self;
324
325    fn neg(self) -> Self {
326        let mut derivs = self.derivs;
327        for deriv in &mut derivs {
328            *deriv = -*deriv;
329        }
330        Self {
331            value: -self.value,
332            derivs,
333        }
334    }
335}
336
337impl<T, const N: usize> MultiDual<T, N>
338where
339    T: Float,
340{
341    /// Reciprocal: `1/(b + ∇b) = (1/b) + (-∇b/b²)`
342    ///
343    /// Implements the rule: (1/g)′ = -g′/g² (note: g ≠ 0)
344    ///
345    /// # Examples
346    ///
347    /// ```
348    /// use autodiff::MultiDual;
349    ///
350    /// // f(x, y) = 1/x at (2, 3)
351    /// let x = MultiDual::<f64, 2>::variable(2.0, 0);
352    /// let recip_x = x.recip();
353    ///
354    /// assert_eq!(recip_x.value, 0.5);        // 1/2
355    /// assert!((recip_x.derivs[0] + 0.25).abs() < 1e-10);  // -1/4
356    /// assert_eq!(recip_x.derivs[1], 0.0);    // ∂(1/x)/∂y = 0
357    /// ```
358    pub fn recip(self) -> Self {
359        let recip_val = self.value.recip();
360        let recip_val_sq = recip_val * recip_val;
361
362        let mut derivs = self.derivs;
363        for deriv in &mut derivs {
364            *deriv = -*deriv * recip_val_sq;
365        }
366
367        Self {
368            value: recip_val,
369            derivs,
370        }
371    }
372
373    /// Exponential function: `exp(a + ∇a) = exp(a) + (exp(a) · ∇a)`
374    ///
375    /// Implements the chain rule: `∂/∂xᵢ(exp(f)) = exp(f) · ∂f/∂xᵢ`
376    ///
377    /// # Examples
378    ///
379    /// ```
380    /// use autodiff::MultiDual;
381    ///
382    /// // f(x, y) = exp(x) at (0, 1)
383    /// let x = MultiDual::<f64, 2>::variable(0.0, 0);
384    /// let f = x.exp();
385    ///
386    /// assert_eq!(f.value, 1.0);       // exp(0) = 1
387    /// assert_eq!(f.derivs[0], 1.0);   // ∂exp(x)/∂x at x=0 is exp(0) = 1
388    /// assert_eq!(f.derivs[1], 0.0);   // ∂exp(x)/∂y = 0
389    /// ```
390    pub fn exp(self) -> Self {
391        let exp_val = self.value.exp();
392        let mut derivs = self.derivs;
393        for deriv in &mut derivs {
394            *deriv = *deriv * exp_val;
395        }
396        Self {
397            value: exp_val,
398            derivs,
399        }
400    }
401
402    /// Natural logarithm: `ln(a + ∇a) = ln(a) + (∇a / a)`
403    ///
404    /// Implements the chain rule: `∂/∂xᵢ(ln(f)) = (∂f/∂xᵢ) / f`
405    ///
406    /// # Examples
407    ///
408    /// ```
409    /// use autodiff::MultiDual;
410    ///
411    /// // f(x, y) = ln(x) at (1, 2)
412    /// let x = MultiDual::<f64, 2>::variable(1.0, 0);
413    /// let f = x.ln();
414    ///
415    /// assert!((f.value - 0.0_f64).abs() < 1e-12);    // ln(1) = 0
416    /// assert!((f.derivs[0] - 1.0_f64).abs() < 1e-12); // ∂ln(x)/∂x at x=1 is 1/1 = 1
417    /// assert_eq!(f.derivs[1], 0.0);               // ∂ln(x)/∂y = 0
418    /// ```
419    pub fn ln(self) -> Self {
420        let ln_val = self.value.ln();
421        let mut derivs = self.derivs;
422        for deriv in &mut derivs {
423            *deriv = *deriv / self.value;
424        }
425        Self {
426            value: ln_val,
427            derivs,
428        }
429    }
430
431    /// Sine function: `sin(a + ∇a) = sin(a) + (cos(a) · ∇a)`
432    ///
433    /// Implements the chain rule: `∂/∂xᵢ(sin(f)) = cos(f) · ∂f/∂xᵢ`
434    ///
435    /// # Examples
436    ///
437    /// ```
438    /// use autodiff::MultiDual;
439    ///
440    /// // f(x, y) = sin(x) at (0, 1)
441    /// let x = MultiDual::<f64, 2>::variable(0.0, 0);
442    /// let f = x.sin();
443    ///
444    /// assert!((f.value - 0.0_f64).abs() < 1e-12);    // sin(0) = 0
445    /// assert!((f.derivs[0] - 1.0_f64).abs() < 1e-12); // ∂sin(x)/∂x at x=0 is cos(0) = 1
446    /// assert_eq!(f.derivs[1], 0.0);               // ∂sin(x)/∂y = 0
447    /// ```
448    pub fn sin(self) -> Self {
449        let sin_val = self.value.sin();
450        let cos_val = self.value.cos();
451        let mut derivs = self.derivs;
452        for deriv in &mut derivs {
453            *deriv = *deriv * cos_val;
454        }
455        Self {
456            value: sin_val,
457            derivs,
458        }
459    }
460
461    /// Cosine function: `cos(a + ∇a) = cos(a) + (-sin(a) · ∇a)`
462    ///
463    /// Implements the chain rule: `∂/∂xᵢ(cos(f)) = -sin(f) · ∂f/∂xᵢ`
464    ///
465    /// # Examples
466    ///
467    /// ```
468    /// use autodiff::MultiDual;
469    ///
470    /// // f(x, y) = cos(x) at (0, 1)
471    /// let x = MultiDual::<f64, 2>::variable(0.0, 0);
472    /// let f = x.cos();
473    ///
474    /// assert!((f.value - 1.0_f64).abs() < 1e-12);    // cos(0) = 1
475    /// assert!((f.derivs[0] - 0.0_f64).abs() < 1e-12); // ∂cos(x)/∂x at x=0 is -sin(0) = 0
476    /// assert_eq!(f.derivs[1], 0.0);               // ∂cos(x)/∂y = 0
477    /// ```
478    pub fn cos(self) -> Self {
479        let cos_val = self.value.cos();
480        let sin_val = self.value.sin();
481        let mut derivs = self.derivs;
482        for deriv in &mut derivs {
483            *deriv = -*deriv * sin_val;
484        }
485        Self {
486            value: cos_val,
487            derivs,
488        }
489    }
490
491    /// Square root: `sqrt(a + ∇a) = sqrt(a) + (∇a / (2·sqrt(a)))`
492    ///
493    /// Implements the chain rule: `∂/∂xᵢ(√f) = (∂f/∂xᵢ) / (2√f)`
494    ///
495    /// # Examples
496    ///
497    /// ```
498    /// use autodiff::MultiDual;
499    ///
500    /// // f(x, y) = √x at (4, 1)
501    /// let x = MultiDual::<f64, 2>::variable(4.0, 0);
502    /// let f = x.sqrt();
503    ///
504    /// assert_eq!(f.value, 2.0);       // √4 = 2
505    /// assert_eq!(f.derivs[0], 0.25);  // ∂√x/∂x at x=4 is 1/(2·2) = 0.25
506    /// assert_eq!(f.derivs[1], 0.0);   // ∂√x/∂y = 0
507    /// ```
508    pub fn sqrt(self) -> Self {
509        let sqrt_val = self.value.sqrt();
510        let two_sqrt = sqrt_val + sqrt_val; // 2 * sqrt(value)
511        let mut derivs = self.derivs;
512        for deriv in &mut derivs {
513            *deriv = *deriv / two_sqrt;
514        }
515        Self {
516            value: sqrt_val,
517            derivs,
518        }
519    }
520}
521
522/// Compute the gradient of a scalar multivariable function in a
523/// single forward pass.
524///
525/// Given a function `f: ℝⁿ → ℝ` and a point in ℝⁿ, computes both the
526/// function value and its gradient ∇f = [∂f/∂x₁, ..., ∂f/∂xₙ] at that
527/// point.
528///
529/// This is the primary high-level API for computing gradients with
530/// MultiDual. It automatically seeds the input variables and
531/// evaluates the function once.
532///
533/// # Type Parameters
534///
535/// - `T`: The numeric type (typically `f64` or `f32`)
536/// - `F`: A function that takes N `MultiDual` inputs and returns a
537///   `MultiDual` output
538/// - `N`: The number of input variables (compile-time constant)
539///
540/// # Arguments
541///
542/// - `f`: The function to differentiate
543/// - `point`: The point at which to evaluate the gradient
544///
545/// # Returns
546///
547/// A tuple `(value, gradient)` where:
548/// - `value`: The function value f(point)
549/// - `gradient`: The gradient ∇f evaluated at point
550///
551/// # Examples
552///
553/// ## Quadratic Function
554///
555/// ```
556/// use autodiff::{MultiDual, gradient};
557///
558/// // f(x, y) = x² + 2xy + y² at (3, 4)
559/// let f = |vars: [MultiDual<f64, 2>; 2]| {
560///     let [x, y] = vars;
561///     let two = MultiDual::constant(2.0);
562///     x * x + two * x * y + y * y
563/// };
564///
565/// let point = [3.0, 4.0];
566/// let (value, grad) = gradient(f, point);
567///
568/// assert_eq!(value, 49.0);    // f(3, 4) = 9 + 24 + 16
569/// assert_eq!(grad[0], 14.0);  // ∂f/∂x = 2x + 2y = 14
570/// assert_eq!(grad[1], 14.0);  // ∂f/∂y = 2x + 2y = 14
571/// ```
572///
573/// ## With Transcendental Functions
574///
575/// ```
576/// use autodiff::{MultiDual, gradient};
577///
578/// // f(x, y, z) = x² + y·exp(z) at (1, 2, 0)
579/// let f = |vars: [MultiDual<f64, 3>; 3]| {
580///     let [x, y, z] = vars;
581///     x * x + y * z.exp()
582/// };
583///
584/// let point = [1.0, 2.0, 0.0];
585/// let (value, grad) = gradient(f, point);
586///
587/// assert_eq!(value, 3.0);     // 1 + 2·1 = 3
588/// assert_eq!(grad[0], 2.0);   // ∂f/∂x = 2x = 2
589/// assert_eq!(grad[1], 1.0);   // ∂f/∂y = exp(z) = 1
590/// assert_eq!(grad[2], 2.0);   // ∂f/∂z = y·exp(z) = 2
591/// ```
592///
593/// ## Rosenbrock Function (optimization benchmark)
594///
595/// ```
596/// use autodiff::{MultiDual, gradient};
597///
598/// // Rosenbrock: f(x, y) = (1-x)² + 100(y-x²)²
599/// let rosenbrock = |vars: [MultiDual<f64, 2>; 2]| {
600///     let [x, y] = vars;
601///     let one = MultiDual::constant(1.0);
602///     let hundred = MultiDual::constant(100.0);
603///
604///     let term1 = one - x;
605///     let term2 = y - x * x;
606///     term1 * term1 + hundred * term2 * term2
607/// };
608///
609/// let point = [1.0, 1.0];  // Global minimum
610/// let (value, grad) = gradient(rosenbrock, point);
611///
612/// assert_eq!(value, 0.0);           // Minimum value is 0
613/// assert_eq!(grad[0], 0.0);         // Gradient is zero at minimum
614/// assert_eq!(grad[1], 0.0);
615/// ```
616pub fn gradient<T, F, const N: usize>(f: F, point: [T; N]) -> (T, [T; N])
617where
618    T: Float,
619    F: Fn([MultiDual<T, N>; N]) -> MultiDual<T, N>,
620{
621    // Seed input variables: each gets its value from point with the
622    // appropriate unit vector for derivatives
623    let vars = std::array::from_fn(|i| MultiDual::variable(point[i], i));
624
625    // Single forward pass through the computation
626    let result = f(vars);
627
628    // Return both the function value and the gradient
629    (result.value, result.derivs)
630}
631
632/// Division: `(a + ∇a) / (b + ∇b) = (a + ∇a) * (1/(b + ∇b))`
633///
634/// Implements division via reciprocal (composition of operations).
635///
636/// # Examples
637///
638/// ```
639/// use autodiff::MultiDual;
640///
641/// // f(x, y) = x / y at (6, 2)
642/// let x = MultiDual::<f64, 2>::variable(6.0, 0);
643/// let y = MultiDual::<f64, 2>::variable(2.0, 1);
644/// let quotient = x / y;
645///
646/// assert_eq!(quotient.value, 3.0);         // 6/2
647/// assert_eq!(quotient.derivs[0], 0.5);     // ∂(x/y)/∂x = 1/y = 0.5
648/// assert_eq!(quotient.derivs[1], -1.5);    // ∂(x/y)/∂y = -x/y² = -1.5
649/// ```
650#[allow(clippy::suspicious_arithmetic_impl)]
651impl<T, const N: usize> Div for MultiDual<T, N>
652where
653    T: Float,
654{
655    type Output = Self;
656
657    fn div(self, rhs: Self) -> Self {
658        self * rhs.recip()
659    }
660}
661
#[cfg(test)]
mod tests {
    // Unit tests covering constructors, operator impls, transcendental
    // rules, and the high-level `gradient` driver. Expected values are
    // hand-derived from the calculus rules documented on each impl.
    use super::*;

    #[test]
    fn constant_has_zero_derivatives() {
        let c = MultiDual::<f64, 3>::constant(42.0);
        assert_eq!(c.value, 42.0);
        assert_eq!(c.derivs, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn variable_sets_correct_derivative() {
        // Each variable's derivative array is the matching unit vector.
        let x = MultiDual::<f64, 3>::variable(3.0, 0);
        assert_eq!(x.value, 3.0);
        assert_eq!(x.derivs, [1.0, 0.0, 0.0]);

        let y = MultiDual::<f64, 3>::variable(4.0, 1);
        assert_eq!(y.value, 4.0);
        assert_eq!(y.derivs, [0.0, 1.0, 0.0]);

        let z = MultiDual::<f64, 3>::variable(5.0, 2);
        assert_eq!(z.value, 5.0);
        assert_eq!(z.derivs, [0.0, 0.0, 1.0]);
    }

    #[test]
    fn addition_works() {
        let x = MultiDual::new(3.0, [1.0, 0.0]);
        let y = MultiDual::new(4.0, [0.0, 1.0]);
        let sum = x + y;

        assert_eq!(sum.value, 7.0);
        assert_eq!(sum.derivs, [1.0, 1.0]);
    }

    #[test]
    fn subtraction_works() {
        let x = MultiDual::new(7.0, [2.0, 3.0]);
        let y = MultiDual::new(4.0, [1.0, 2.0]);
        let diff = x - y;

        assert_eq!(diff.value, 3.0);
        assert_eq!(diff.derivs, [1.0, 1.0]);
    }

    #[test]
    fn negation_works() {
        let x = MultiDual::new(3.0, [1.0, 2.0, 3.0]);
        let neg_x = -x;

        assert_eq!(neg_x.value, -3.0);
        assert_eq!(neg_x.derivs, [-1.0, -2.0, -3.0]);
    }

    #[test]
    fn multiplication_implements_product_rule() {
        // f(x, y) = x * y at (3, 4)
        let x = MultiDual::<f64, 2>::variable(3.0, 0);
        let y = MultiDual::<f64, 2>::variable(4.0, 1);
        let product = x * y;

        assert_eq!(product.value, 12.0); // 3 * 4
        assert_eq!(product.derivs[0], 4.0); // ∂(xy)/∂x = y = 4
        assert_eq!(product.derivs[1], 3.0); // ∂(xy)/∂y = x = 3
    }

    #[test]
    fn recip_implements_inverse_rule() {
        // f(x, y) = 1/x at (2, 3)
        let x = MultiDual::<f64, 2>::variable(2.0, 0);
        let recip_x = x.recip();

        assert_eq!(recip_x.value, 0.5); // 1/2
        assert!((recip_x.derivs[0] + 0.25).abs() < 1e-10); // -1/4
        assert_eq!(recip_x.derivs[1], 0.0); // ∂(1/x)/∂y = 0
    }

    #[test]
    fn division_quotient_rule() {
        // f(x, y) = x / y at (6, 2)
        let x = MultiDual::<f64, 2>::variable(6.0, 0);
        let y = MultiDual::<f64, 2>::variable(2.0, 1);
        let quotient = x / y;

        assert_eq!(quotient.value, 3.0); // 6/2
        assert_eq!(quotient.derivs[0], 0.5); // ∂(x/y)/∂x = 1/y = 0.5
        assert_eq!(quotient.derivs[1], -1.5); // ∂(x/y)/∂y = -x/y² = -1.5
    }

    #[test]
    fn polynomial_gradient() {
        // f(x, y) = x² + 2xy + y² at (3, 4)
        // ∂f/∂x = 2x + 2y = 14, ∂f/∂y = 2x + 2y = 14
        let x = MultiDual::<f64, 2>::variable(3.0, 0);
        let y = MultiDual::<f64, 2>::variable(4.0, 1);
        let two = MultiDual::<f64, 2>::constant(2.0);

        let f = x * x + two * x * y + y * y;

        assert_eq!(f.value, 49.0); // 9 + 24 + 16
        assert_eq!(f.derivs[0], 14.0); // 2x + 2y
        assert_eq!(f.derivs[1], 14.0); // 2x + 2y
    }

    #[test]
    fn three_variable_sum_of_squares() {
        // f(x, y, z) = x² + y² + z² at (1, 2, 3)
        let x = MultiDual::<f64, 3>::variable(1.0, 0);
        let y = MultiDual::<f64, 3>::variable(2.0, 1);
        let z = MultiDual::<f64, 3>::variable(3.0, 2);

        let f = x * x + y * y + z * z;

        assert_eq!(f.value, 14.0); // 1 + 4 + 9
        assert_eq!(f.derivs[0], 2.0); // ∂f/∂x = 2x = 2
        assert_eq!(f.derivs[1], 4.0); // ∂f/∂y = 2y = 4
        assert_eq!(f.derivs[2], 6.0); // ∂f/∂z = 2z = 6
    }

    #[test]
    fn exp_gradient() {
        // f(x, y) = exp(x) at (0, 1)
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let f = x.exp();

        assert_eq!(f.value, 1.0); // exp(0) = 1
        assert_eq!(f.derivs[0], 1.0); // ∂exp(x)/∂x at x=0 is exp(0) = 1
        assert_eq!(f.derivs[1], 0.0); // ∂exp(x)/∂y = 0

        // f(x, y) = exp(x + y) at (1, 2)
        let x = MultiDual::<f64, 2>::variable(1.0, 0);
        let y = MultiDual::<f64, 2>::variable(2.0, 1);
        let f = (x + y).exp();

        let exp_3 = 3.0_f64.exp();
        assert!((f.value - exp_3).abs() < 1e-10); // exp(3)
        assert!((f.derivs[0] - exp_3).abs() < 1e-10); // ∂exp(x+y)/∂x = exp(x+y)
        assert!((f.derivs[1] - exp_3).abs() < 1e-10); // ∂exp(x+y)/∂y = exp(x+y)
    }

    #[test]
    fn ln_gradient() {
        // f(x, y) = ln(x) at (1, 2)
        let x = MultiDual::<f64, 2>::variable(1.0, 0);
        let f = x.ln();

        assert!((f.value - 0.0).abs() < 1e-12); // ln(1) = 0
        assert!((f.derivs[0] - 1.0).abs() < 1e-12); // ∂ln(x)/∂x at x=1 is 1/1 = 1
        assert_eq!(f.derivs[1], 0.0); // ∂ln(x)/∂y = 0

        // f(x, y) = ln(x * y) at (2, 3)
        let x = MultiDual::<f64, 2>::variable(2.0, 0);
        let y = MultiDual::<f64, 2>::variable(3.0, 1);
        let f = (x * y).ln();

        assert!((f.value - 6.0_f64.ln()).abs() < 1e-10); // ln(6)
        assert!((f.derivs[0] - 0.5).abs() < 1e-10); // ∂ln(xy)/∂x = 1/x = 0.5
        assert!((f.derivs[1] - 1.0 / 3.0).abs() < 1e-10); // ∂ln(xy)/∂y = 1/y = 1/3
    }

    #[test]
    fn sin_gradient() {
        // f(x, y) = sin(x) at (0, 1)
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let f = x.sin();

        assert!((f.value - 0.0).abs() < 1e-12); // sin(0) = 0
        assert!((f.derivs[0] - 1.0).abs() < 1e-12); // ∂sin(x)/∂x at x=0 is cos(0) = 1
        assert_eq!(f.derivs[1], 0.0); // ∂sin(x)/∂y = 0
    }

    #[test]
    fn cos_gradient() {
        // f(x, y) = cos(x) at (0, 1)
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let f = x.cos();

        assert!((f.value - 1.0).abs() < 1e-12); // cos(0) = 1
        assert!((f.derivs[0] - 0.0).abs() < 1e-12); // ∂cos(x)/∂x at x=0 is -sin(0) = 0
        assert_eq!(f.derivs[1], 0.0); // ∂cos(x)/∂y = 0
    }

    #[test]
    fn sqrt_gradient() {
        // f(x, y) = √x at (4, 1)
        let x = MultiDual::<f64, 2>::variable(4.0, 0);
        let f = x.sqrt();

        assert_eq!(f.value, 2.0); // √4 = 2
        assert_eq!(f.derivs[0], 0.25); // ∂√x/∂x at x=4 is 1/(2·2) = 0.25
        assert_eq!(f.derivs[1], 0.0); // ∂√x/∂y = 0
    }

    #[test]
    fn mixed_transcendental_gradient() {
        // f(x, y) = sin(x) * exp(y) at (0, 0)
        // ∂f/∂x = cos(x) * exp(y) at (0, 0) = 1 * 1 = 1
        // ∂f/∂y = sin(x) * exp(y) at (0, 0) = 0 * 1 = 0
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let y = MultiDual::<f64, 2>::variable(0.0, 1);
        let f = x.sin() * y.exp();

        assert!((f.value - 0.0).abs() < 1e-12); // sin(0) * exp(0) = 0 * 1 = 0
        assert!((f.derivs[0] - 1.0).abs() < 1e-12); // cos(0) * exp(0) = 1
        assert!((f.derivs[1] - 0.0).abs() < 1e-12); // sin(0) * exp(0) = 0
    }

    #[test]
    fn chain_rule_with_transcendentals() {
        // f(x, y) = exp(x²) at (1, 0)
        // ∂f/∂x = 2x * exp(x²) at x=1 is 2 * e
        let x = MultiDual::<f64, 2>::variable(1.0, 0);
        let f = (x * x).exp();

        let e = 1.0_f64.exp();
        assert!((f.value - e).abs() < 1e-10); // exp(1)
        assert!((f.derivs[0] - 2.0 * e).abs() < 1e-10); // 2 * exp(1)
        assert_eq!(f.derivs[1], 0.0);

        // f(x, y, z) = √(x² + y² + z²) at (3, 4, 0)
        // This is the Euclidean norm
        // ∂f/∂x = x / √(x² + y² + z²) = 3/5
        // ∂f/∂y = y / √(x² + y² + z²) = 4/5
        // ∂f/∂z = z / √(x² + y² + z²) = 0/5 = 0
        let x = MultiDual::<f64, 3>::variable(3.0, 0);
        let y = MultiDual::<f64, 3>::variable(4.0, 1);
        let z = MultiDual::<f64, 3>::variable(0.0, 2);
        let f = (x * x + y * y + z * z).sqrt();

        assert_eq!(f.value, 5.0); // √(9 + 16 + 0) = 5
        assert_eq!(f.derivs[0], 0.6); // 3/5
        assert_eq!(f.derivs[1], 0.8); // 4/5
        assert_eq!(f.derivs[2], 0.0); // 0/5
    }

    // Tests below exercise the high-level `gradient` driver rather than
    // the operators directly.

    #[test]
    fn gradient_quadratic_2d() {
        use crate::gradient;

        // f(x, y) = x² + 2xy + y² at (3, 4)
        let f = |vars: [MultiDual<f64, 2>; 2]| {
            let [x, y] = vars;
            let two = MultiDual::constant(2.0);
            x * x + two * x * y + y * y
        };

        let point = [3.0, 4.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 49.0); // f(3, 4) = 9 + 24 + 16
        assert_eq!(grad[0], 14.0); // ∂f/∂x = 2x + 2y = 14
        assert_eq!(grad[1], 14.0); // ∂f/∂y = 2x + 2y = 14
    }

    #[test]
    fn gradient_with_transcendentals() {
        use crate::gradient;

        // f(x, y, z) = x² + y·exp(z) at (1, 2, 0)
        let f = |vars: [MultiDual<f64, 3>; 3]| {
            let [x, y, z] = vars;
            x * x + y * z.exp()
        };

        let point = [1.0, 2.0, 0.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 3.0); // 1 + 2·1 = 3
        assert_eq!(grad[0], 2.0); // ∂f/∂x = 2x = 2
        assert_eq!(grad[1], 1.0); // ∂f/∂y = exp(z) = 1
        assert_eq!(grad[2], 2.0); // ∂f/∂z = y·exp(z) = 2
    }

    #[test]
    fn gradient_rosenbrock() {
        use crate::gradient;

        // Rosenbrock: f(x, y) = (1-x)² + 100(y-x²)²
        let rosenbrock = |vars: [MultiDual<f64, 2>; 2]| {
            let [x, y] = vars;
            let one = MultiDual::constant(1.0);
            let hundred = MultiDual::constant(100.0);

            let term1 = one - x;
            let term2 = y - x * x;
            term1 * term1 + hundred * term2 * term2
        };

        // Test at global minimum (1, 1)
        let point = [1.0, 1.0];
        let (value, grad) = gradient(rosenbrock, point);

        assert_eq!(value, 0.0); // Minimum value is 0
        assert_eq!(grad[0], 0.0); // Gradient is zero at minimum
        assert_eq!(grad[1], 0.0);

        // Test at another point (0, 0)
        let point = [0.0, 0.0];
        let (value, grad) = gradient(rosenbrock, point);

        assert_eq!(value, 1.0); // f(0, 0) = 1 + 0 = 1
        assert_eq!(grad[0], -2.0); // ∂f/∂x at (0,0) = -2(1-x) - 400x(y-x²) = -2
        assert_eq!(grad[1], 0.0); // ∂f/∂y at (0,0) = 200(y-x²) = 0
    }

    #[test]
    fn gradient_euclidean_norm() {
        use crate::gradient;

        // f(x, y, z) = √(x² + y² + z²) at (3, 4, 0)
        let euclidean_norm = |vars: [MultiDual<f64, 3>; 3]| {
            let [x, y, z] = vars;
            (x * x + y * y + z * z).sqrt()
        };

        let point = [3.0, 4.0, 0.0];
        let (value, grad) = gradient(euclidean_norm, point);

        assert_eq!(value, 5.0); // √(9 + 16 + 0) = 5
        assert_eq!(grad[0], 0.6); // x/‖x‖ = 3/5
        assert_eq!(grad[1], 0.8); // y/‖x‖ = 4/5
        assert_eq!(grad[2], 0.0); // z/‖x‖ = 0/5
    }

    #[test]
    fn gradient_single_variable() {
        use crate::gradient;

        // f(x) = x³ at x=2
        // f'(x) = 3x² = 12
        let f = |vars: [MultiDual<f64, 1>; 1]| {
            let [x] = vars;
            x * x * x
        };

        let point = [2.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 8.0); // 2³ = 8
        assert_eq!(grad[0], 12.0); // 3 * 2² = 12
    }

    #[test]
    fn gradient_high_dimensional() {
        use crate::gradient;

        // f(x₁, x₂, x₃, x₄, x₅) = Σᵢ xᵢ² at (1, 2, 3, 4, 5)
        // ∂f/∂xᵢ = 2xᵢ
        let f = |vars: [MultiDual<f64, 5>; 5]| {
            let [x1, x2, x3, x4, x5] = vars;
            x1 * x1 + x2 * x2 + x3 * x3 + x4 * x4 + x5 * x5
        };

        let point = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 55.0); // 1 + 4 + 9 + 16 + 25
        assert_eq!(grad[0], 2.0); // 2 * 1
        assert_eq!(grad[1], 4.0); // 2 * 2
        assert_eq!(grad[2], 6.0); // 2 * 3
        assert_eq!(grad[3], 8.0); // 2 * 4
        assert_eq!(grad[4], 10.0); // 2 * 5
    }

    #[test]
    fn gradient_mixed_operations() {
        use crate::gradient;

        // f(x, y) = sin(x) * exp(y) + ln(x + y) at (1, 0)
        let f = |vars: [MultiDual<f64, 2>; 2]| {
            let [x, y] = vars;
            x.sin() * y.exp() + (x + y).ln()
        };

        let point = [1.0, 0.0];
        let (value, grad) = gradient(f, point);

        let sin_1 = 1.0_f64.sin();
        let cos_1 = 1.0_f64.cos();

        assert!((value - (sin_1 + 0.0)).abs() < 1e-10); // sin(1)*1 + ln(1) = sin(1)
        // ∂f/∂x = cos(x)*exp(y) + 1/(x+y) at (1,0) = cos(1) + 1
        assert!((grad[0] - (cos_1 + 1.0)).abs() < 1e-10);
        // ∂f/∂y = sin(x)*exp(y) + 1/(x+y) at (1,0) = sin(1) + 1
        assert!((grad[1] - (sin_1 + 1.0)).abs() < 1e-10);
    }
}
1049}