// autodiff/multidual.rs
1//! Multi-component dual numbers for multivariable automatic
2//! differentiation.
3//!
4//! A multi-component dual number tracks a value and multiple partial
5//! derivatives simultaneously, enabling computation of gradients in a
6//! **single forward pass**.
7//!
8//! # Mathematical Background
9//!
10//! For a function f: ℝⁿ → ℝ, `MultiDual<T, N>` represents a value and
11//! its gradient ∇f = [∂f/∂x₁, ∂f/∂x₂, ..., ∂f/∂xₙ] simultaneously.
12//!
13//! Arithmetic operations extend naturally from single-variable dual
14//! numbers:
15//!
16//! - `(a + ∇a) + (b + ∇b) = (a+b) + (∇a+∇b)`
17//! - `-(a + ∇a) = -a + (-∇a)`
18//! - `(a + ∇a) - (b + ∇b) = (a-b) + (∇a-∇b)`
19//! - `(a + ∇a) * (b + ∇b) = ab + (b∇a + a∇b)`
20//! - `1/(b + ∇b) = (1/b) + (-∇b/b²)`
21//!
22//! Each operation updates all N derivative components at once,
23//! computing the full gradient in a single pass through the
24//! computation.
25//!
26//! # Example
27//!
28//! ```
29//! use autodiff::{MultiDual, gradient};
30//!
31//! // Compute ∇f for f(x, y) = x² + 2xy + y² at (3, 4)
32//! let f = |vars: [MultiDual<f64, 2>; 2]| {
33//! let [x, y] = vars;
34//! let two = MultiDual::constant(2.0);
35//! x * x + two * x * y + y * y
36//! };
37//!
38//! let point = [3.0, 4.0];
39//! let (value, grad) = gradient(f, point);
40//!
41//! assert_eq!(value, 49.0); // f(3, 4) = 9 + 24 + 16 = 49
42//! assert_eq!(grad[0], 14.0); // ∂f/∂x = 2x + 2y = 14
43//! assert_eq!(grad[1], 14.0); // ∂f/∂y = 2x + 2y = 14
44//! ```
45//!
46//! # Efficiency
47//!
48//! Computing the gradient requires **1 forward pass** with
49//! `MultiDual<T, N>`, compared to n passes if using `Dual<T>` to
50//! compute each partial derivative separately. For n=10 inputs, this
51//! is a 10x speedup.
52//!
53//! # Use Cases
54//!
//!
56//! - Gradient-based optimization (gradient descent, Newton's method)
57//! - Neural network backpropagation alternatives
58//! - Sensitivity analysis with multiple parameters
59//! - Scientific computing with multivariable functions
60
61use num_traits::{Float, One, Zero};
62use std::ops::{Add, Div, Mul, Neg, Sub};
63
64/// A multi-component dual number representing a value and N partial
65/// derivatives.
66///
67/// `MultiDual<T, N>` represents a value along with its gradient
68/// [∂f/∂x₁, ..., ∂f/∂xₙ] for forward-mode automatic differentiation
69/// of multivariable functions.
70///
71/// # Type Parameters
72///
73/// - `T`: The numeric type (typically `f64` or `f32`)
74/// - `N`: The number of input variables (compile-time constant)
75///
76/// # Examples
77///
78/// ## Creating Variables
79///
80/// ```
81/// use autodiff::MultiDual;
82///
83/// // Create the first variable x with value 3.0 (∂/∂x = 1, ∂/∂y = 0)
84/// let x = MultiDual::<f64, 2>::variable(3.0, 0);
85/// assert_eq!(x.value, 3.0);
86/// assert_eq!(x.derivs, [1.0, 0.0]);
87///
88/// // Create the second variable y with value 4.0 (∂/∂x = 0, ∂/∂y = 1)
89/// let y = MultiDual::<f64, 2>::variable(4.0, 1);
90/// assert_eq!(y.value, 4.0);
91/// assert_eq!(y.derivs, [0.0, 1.0]);
92///
93/// // Create a constant (∂/∂x = 0, ∂/∂y = 0)
94/// let c = MultiDual::<f64, 2>::constant(2.0);
95/// assert_eq!(c.value, 2.0);
96/// assert_eq!(c.derivs, [0.0, 0.0]);
97/// ```
98///
99/// ## Computing Gradients
100///
101/// ```
102/// use autodiff::MultiDual;
103///
104/// // f(x, y) = x² + y²
105/// let x = MultiDual::<f64, 2>::variable(3.0, 0);
106/// let y = MultiDual::<f64, 2>::variable(4.0, 1);
107///
108/// let f = x * x + y * y;
109///
110/// assert_eq!(f.value, 25.0); // 3² + 4² = 25
111/// assert_eq!(f.derivs[0], 6.0); // ∂f/∂x = 2x = 6
112/// assert_eq!(f.derivs[1], 8.0); // ∂f/∂y = 2y = 8
113/// ```
// Derived `PartialEq` compares the value AND every derivative component;
// derived `Copy`/`Clone` apply whenever `T: Copy` (true for f32/f64).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MultiDual<T, const N: usize> {
    /// The primal value (function output)
    pub value: T,
    /// The partial derivatives [∂f/∂x₁, ∂f/∂x₂, ..., ∂f/∂xₙ]
    pub derivs: [T; N],
}
121
122impl<T, const N: usize> MultiDual<T, N>
123where
124 T: Copy,
125{
126 /// Create a dual number with explicit value and derivatives.
127 ///
128 /// # Examples
129 ///
130 /// ```
131 /// use autodiff::MultiDual;
132 ///
133 /// let d = MultiDual::new(5.0, [1.0, 2.0, 3.0]);
134 /// assert_eq!(d.value, 5.0);
135 /// assert_eq!(d.derivs, [1.0, 2.0, 3.0]);
136 /// ```
137 pub fn new(value: T, derivs: [T; N]) -> Self {
138 Self { value, derivs }
139 }
140
141 /// Create a constant (all partial derivatives are zero).
142 ///
143 /// # Examples
144 ///
145 /// ```
146 /// use autodiff::MultiDual;
147 ///
148 /// let c = MultiDual::<f64, 3>::constant(42.0);
149 /// assert_eq!(c.value, 42.0);
150 /// assert_eq!(c.derivs, [0.0, 0.0, 0.0]);
151 /// ```
152 pub fn constant(value: T) -> Self
153 where
154 T: Zero,
155 {
156 Self {
157 value,
158 derivs: [T::zero(); N],
159 }
160 }
161
162 /// Create the i-th input variable.
163 ///
164 /// Sets `derivs[index] = 1` and all other derivatives to zero,
165 /// representing ∂xᵢ/∂xⱼ = δᵢⱼ (Kronecker delta).
166 ///
167 /// # Panics
168 ///
169 /// Panics if `index >= N`.
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use autodiff::MultiDual;
175 ///
176 /// // Create x (first variable)
177 /// let x = MultiDual::<f64, 2>::variable(3.0, 0);
178 /// assert_eq!(x.value, 3.0);
179 /// assert_eq!(x.derivs, [1.0, 0.0]);
180 ///
181 /// // Create y (second variable)
182 /// let y = MultiDual::<f64, 2>::variable(4.0, 1);
183 /// assert_eq!(y.value, 4.0);
184 /// assert_eq!(y.derivs, [0.0, 1.0]);
185 /// ```
186 pub fn variable(value: T, index: usize) -> Self
187 where
188 T: Zero + One,
189 {
190 assert!(
191 index < N,
192 "Variable index {} out of bounds for N={}",
193 index,
194 N
195 );
196 let mut derivs = [T::zero(); N];
197 derivs[index] = T::one();
198 Self { value, derivs }
199 }
200}
201
202/// Addition: `(a + ∇a) + (b + ∇b) = (a+b) + (∇a+∇b)`
203///
204/// # Examples
205///
206/// ```
207/// use autodiff::MultiDual;
208///
209/// let x = MultiDual::new(3.0, [1.0, 0.0]);
210/// let y = MultiDual::new(4.0, [0.0, 1.0]);
211/// let sum = x + y;
212///
213/// assert_eq!(sum.value, 7.0);
214/// assert_eq!(sum.derivs, [1.0, 1.0]);
215/// ```
216impl<T, const N: usize> Add for MultiDual<T, N>
217where
218 T: Add<Output = T> + Copy,
219{
220 type Output = Self;
221
222 fn add(self, rhs: Self) -> Self {
223 let mut derivs = self.derivs;
224 for (deriv, rhs_deriv) in derivs.iter_mut().zip(rhs.derivs.iter()) {
225 *deriv = *deriv + *rhs_deriv;
226 }
227 Self {
228 value: self.value + rhs.value,
229 derivs,
230 }
231 }
232}
233
234/// Subtraction: `(a + ∇a) - (b + ∇b) = (a-b) + (∇a-∇b)`
235///
236/// # Examples
237///
238/// ```
239/// use autodiff::MultiDual;
240///
241/// let x = MultiDual::new(7.0, [2.0, 3.0]);
242/// let y = MultiDual::new(4.0, [1.0, 2.0]);
243/// let diff = x - y;
244///
245/// assert_eq!(diff.value, 3.0);
246/// assert_eq!(diff.derivs, [1.0, 1.0]);
247/// ```
248impl<T, const N: usize> Sub for MultiDual<T, N>
249where
250 T: Sub<Output = T> + Copy,
251{
252 type Output = Self;
253
254 fn sub(self, rhs: Self) -> Self {
255 let mut derivs = self.derivs;
256 for (deriv, rhs_deriv) in derivs.iter_mut().zip(rhs.derivs.iter()) {
257 *deriv = *deriv - *rhs_deriv;
258 }
259 Self {
260 value: self.value - rhs.value,
261 derivs,
262 }
263 }
264}
265
266/// Multiplication: `(a + ∇a) * (b + ∇b) = ab + (b∇a + a∇b)`
267///
268/// This implements the product rule automatically.
269///
270/// # Examples
271///
272/// ```
273/// use autodiff::MultiDual;
274///
275/// // f(x, y) = x * y at (3, 4)
276/// let x = MultiDual::<f64, 2>::variable(3.0, 0);
277/// let y = MultiDual::<f64, 2>::variable(4.0, 1);
278/// let product = x * y;
279///
280/// assert_eq!(product.value, 12.0); // 3 * 4
281/// assert_eq!(product.derivs[0], 4.0); // ∂(xy)/∂x = y = 4
282/// assert_eq!(product.derivs[1], 3.0); // ∂(xy)/∂y = x = 3
283/// ```
284impl<T, const N: usize> Mul for MultiDual<T, N>
285where
286 T: Mul<Output = T> + Add<Output = T> + Copy,
287{
288 type Output = Self;
289
290 fn mul(self, rhs: Self) -> Self {
291 let mut derivs = [self.value; N];
292 for (deriv, (self_deriv, rhs_deriv)) in derivs
293 .iter_mut()
294 .zip(self.derivs.iter().zip(rhs.derivs.iter()))
295 {
296 // Product rule: (f*g)' = f'*g + f*g'
297 *deriv = *self_deriv * rhs.value + self.value * *rhs_deriv;
298 }
299 Self {
300 value: self.value * rhs.value,
301 derivs,
302 }
303 }
304}
305
306/// Negation: `-(a + ∇a) = -a + (-∇a)`
307///
308/// # Examples
309///
310/// ```
311/// use autodiff::MultiDual;
312///
313/// let x = MultiDual::new(3.0, [1.0, 2.0]);
314/// let neg_x = -x;
315///
316/// assert_eq!(neg_x.value, -3.0);
317/// assert_eq!(neg_x.derivs, [-1.0, -2.0]);
318/// ```
319impl<T, const N: usize> Neg for MultiDual<T, N>
320where
321 T: Neg<Output = T> + Copy,
322{
323 type Output = Self;
324
325 fn neg(self) -> Self {
326 let mut derivs = self.derivs;
327 for deriv in &mut derivs {
328 *deriv = -*deriv;
329 }
330 Self {
331 value: -self.value,
332 derivs,
333 }
334 }
335}
336
337impl<T, const N: usize> MultiDual<T, N>
338where
339 T: Float,
340{
341 /// Reciprocal: `1/(b + ∇b) = (1/b) + (-∇b/b²)`
342 ///
343 /// Implements the rule: (1/g)′ = -g′/g² (note: g ≠ 0)
344 ///
345 /// # Examples
346 ///
347 /// ```
348 /// use autodiff::MultiDual;
349 ///
350 /// // f(x, y) = 1/x at (2, 3)
351 /// let x = MultiDual::<f64, 2>::variable(2.0, 0);
352 /// let recip_x = x.recip();
353 ///
354 /// assert_eq!(recip_x.value, 0.5); // 1/2
355 /// assert!((recip_x.derivs[0] + 0.25).abs() < 1e-10); // -1/4
356 /// assert_eq!(recip_x.derivs[1], 0.0); // ∂(1/x)/∂y = 0
357 /// ```
358 pub fn recip(self) -> Self {
359 let recip_val = self.value.recip();
360 let recip_val_sq = recip_val * recip_val;
361
362 let mut derivs = self.derivs;
363 for deriv in &mut derivs {
364 *deriv = -*deriv * recip_val_sq;
365 }
366
367 Self {
368 value: recip_val,
369 derivs,
370 }
371 }
372
373 /// Exponential function: `exp(a + ∇a) = exp(a) + (exp(a) · ∇a)`
374 ///
375 /// Implements the chain rule: `∂/∂xᵢ(exp(f)) = exp(f) · ∂f/∂xᵢ`
376 ///
377 /// # Examples
378 ///
379 /// ```
380 /// use autodiff::MultiDual;
381 ///
382 /// // f(x, y) = exp(x) at (0, 1)
383 /// let x = MultiDual::<f64, 2>::variable(0.0, 0);
384 /// let f = x.exp();
385 ///
386 /// assert_eq!(f.value, 1.0); // exp(0) = 1
387 /// assert_eq!(f.derivs[0], 1.0); // ∂exp(x)/∂x at x=0 is exp(0) = 1
388 /// assert_eq!(f.derivs[1], 0.0); // ∂exp(x)/∂y = 0
389 /// ```
390 pub fn exp(self) -> Self {
391 let exp_val = self.value.exp();
392 let mut derivs = self.derivs;
393 for deriv in &mut derivs {
394 *deriv = *deriv * exp_val;
395 }
396 Self {
397 value: exp_val,
398 derivs,
399 }
400 }
401
402 /// Natural logarithm: `ln(a + ∇a) = ln(a) + (∇a / a)`
403 ///
404 /// Implements the chain rule: `∂/∂xᵢ(ln(f)) = (∂f/∂xᵢ) / f`
405 ///
406 /// # Examples
407 ///
408 /// ```
409 /// use autodiff::MultiDual;
410 ///
411 /// // f(x, y) = ln(x) at (1, 2)
412 /// let x = MultiDual::<f64, 2>::variable(1.0, 0);
413 /// let f = x.ln();
414 ///
415 /// assert!((f.value - 0.0_f64).abs() < 1e-12); // ln(1) = 0
416 /// assert!((f.derivs[0] - 1.0_f64).abs() < 1e-12); // ∂ln(x)/∂x at x=1 is 1/1 = 1
417 /// assert_eq!(f.derivs[1], 0.0); // ∂ln(x)/∂y = 0
418 /// ```
419 pub fn ln(self) -> Self {
420 let ln_val = self.value.ln();
421 let mut derivs = self.derivs;
422 for deriv in &mut derivs {
423 *deriv = *deriv / self.value;
424 }
425 Self {
426 value: ln_val,
427 derivs,
428 }
429 }
430
431 /// Sine function: `sin(a + ∇a) = sin(a) + (cos(a) · ∇a)`
432 ///
433 /// Implements the chain rule: `∂/∂xᵢ(sin(f)) = cos(f) · ∂f/∂xᵢ`
434 ///
435 /// # Examples
436 ///
437 /// ```
438 /// use autodiff::MultiDual;
439 ///
440 /// // f(x, y) = sin(x) at (0, 1)
441 /// let x = MultiDual::<f64, 2>::variable(0.0, 0);
442 /// let f = x.sin();
443 ///
444 /// assert!((f.value - 0.0_f64).abs() < 1e-12); // sin(0) = 0
445 /// assert!((f.derivs[0] - 1.0_f64).abs() < 1e-12); // ∂sin(x)/∂x at x=0 is cos(0) = 1
446 /// assert_eq!(f.derivs[1], 0.0); // ∂sin(x)/∂y = 0
447 /// ```
448 pub fn sin(self) -> Self {
449 let sin_val = self.value.sin();
450 let cos_val = self.value.cos();
451 let mut derivs = self.derivs;
452 for deriv in &mut derivs {
453 *deriv = *deriv * cos_val;
454 }
455 Self {
456 value: sin_val,
457 derivs,
458 }
459 }
460
461 /// Cosine function: `cos(a + ∇a) = cos(a) + (-sin(a) · ∇a)`
462 ///
463 /// Implements the chain rule: `∂/∂xᵢ(cos(f)) = -sin(f) · ∂f/∂xᵢ`
464 ///
465 /// # Examples
466 ///
467 /// ```
468 /// use autodiff::MultiDual;
469 ///
470 /// // f(x, y) = cos(x) at (0, 1)
471 /// let x = MultiDual::<f64, 2>::variable(0.0, 0);
472 /// let f = x.cos();
473 ///
474 /// assert!((f.value - 1.0_f64).abs() < 1e-12); // cos(0) = 1
475 /// assert!((f.derivs[0] - 0.0_f64).abs() < 1e-12); // ∂cos(x)/∂x at x=0 is -sin(0) = 0
476 /// assert_eq!(f.derivs[1], 0.0); // ∂cos(x)/∂y = 0
477 /// ```
478 pub fn cos(self) -> Self {
479 let cos_val = self.value.cos();
480 let sin_val = self.value.sin();
481 let mut derivs = self.derivs;
482 for deriv in &mut derivs {
483 *deriv = -*deriv * sin_val;
484 }
485 Self {
486 value: cos_val,
487 derivs,
488 }
489 }
490
491 /// Square root: `sqrt(a + ∇a) = sqrt(a) + (∇a / (2·sqrt(a)))`
492 ///
493 /// Implements the chain rule: `∂/∂xᵢ(√f) = (∂f/∂xᵢ) / (2√f)`
494 ///
495 /// # Examples
496 ///
497 /// ```
498 /// use autodiff::MultiDual;
499 ///
500 /// // f(x, y) = √x at (4, 1)
501 /// let x = MultiDual::<f64, 2>::variable(4.0, 0);
502 /// let f = x.sqrt();
503 ///
504 /// assert_eq!(f.value, 2.0); // √4 = 2
505 /// assert_eq!(f.derivs[0], 0.25); // ∂√x/∂x at x=4 is 1/(2·2) = 0.25
506 /// assert_eq!(f.derivs[1], 0.0); // ∂√x/∂y = 0
507 /// ```
508 pub fn sqrt(self) -> Self {
509 let sqrt_val = self.value.sqrt();
510 let two_sqrt = sqrt_val + sqrt_val; // 2 * sqrt(value)
511 let mut derivs = self.derivs;
512 for deriv in &mut derivs {
513 *deriv = *deriv / two_sqrt;
514 }
515 Self {
516 value: sqrt_val,
517 derivs,
518 }
519 }
520}
521
522/// Compute the gradient of a scalar multivariable function in a
523/// single forward pass.
524///
525/// Given a function `f: ℝⁿ → ℝ` and a point in ℝⁿ, computes both the
526/// function value and its gradient ∇f = [∂f/∂x₁, ..., ∂f/∂xₙ] at that
527/// point.
528///
529/// This is the primary high-level API for computing gradients with
530/// MultiDual. It automatically seeds the input variables and
531/// evaluates the function once.
532///
533/// # Type Parameters
534///
535/// - `T`: The numeric type (typically `f64` or `f32`)
536/// - `F`: A function that takes N `MultiDual` inputs and returns a
537/// `MultiDual` output
538/// - `N`: The number of input variables (compile-time constant)
539///
540/// # Arguments
541///
542/// - `f`: The function to differentiate
543/// - `point`: The point at which to evaluate the gradient
544///
545/// # Returns
546///
547/// A tuple `(value, gradient)` where:
548/// - `value`: The function value f(point)
549/// - `gradient`: The gradient ∇f evaluated at point
550///
551/// # Examples
552///
553/// ## Quadratic Function
554///
555/// ```
556/// use autodiff::{MultiDual, gradient};
557///
558/// // f(x, y) = x² + 2xy + y² at (3, 4)
559/// let f = |vars: [MultiDual<f64, 2>; 2]| {
560/// let [x, y] = vars;
561/// let two = MultiDual::constant(2.0);
562/// x * x + two * x * y + y * y
563/// };
564///
565/// let point = [3.0, 4.0];
566/// let (value, grad) = gradient(f, point);
567///
568/// assert_eq!(value, 49.0); // f(3, 4) = 9 + 24 + 16
569/// assert_eq!(grad[0], 14.0); // ∂f/∂x = 2x + 2y = 14
570/// assert_eq!(grad[1], 14.0); // ∂f/∂y = 2x + 2y = 14
571/// ```
572///
573/// ## With Transcendental Functions
574///
575/// ```
576/// use autodiff::{MultiDual, gradient};
577///
578/// // f(x, y, z) = x² + y·exp(z) at (1, 2, 0)
579/// let f = |vars: [MultiDual<f64, 3>; 3]| {
580/// let [x, y, z] = vars;
581/// x * x + y * z.exp()
582/// };
583///
584/// let point = [1.0, 2.0, 0.0];
585/// let (value, grad) = gradient(f, point);
586///
587/// assert_eq!(value, 3.0); // 1 + 2·1 = 3
588/// assert_eq!(grad[0], 2.0); // ∂f/∂x = 2x = 2
589/// assert_eq!(grad[1], 1.0); // ∂f/∂y = exp(z) = 1
590/// assert_eq!(grad[2], 2.0); // ∂f/∂z = y·exp(z) = 2
591/// ```
592///
593/// ## Rosenbrock Function (optimization benchmark)
594///
595/// ```
596/// use autodiff::{MultiDual, gradient};
597///
598/// // Rosenbrock: f(x, y) = (1-x)² + 100(y-x²)²
599/// let rosenbrock = |vars: [MultiDual<f64, 2>; 2]| {
600/// let [x, y] = vars;
601/// let one = MultiDual::constant(1.0);
602/// let hundred = MultiDual::constant(100.0);
603///
604/// let term1 = one - x;
605/// let term2 = y - x * x;
606/// term1 * term1 + hundred * term2 * term2
607/// };
608///
609/// let point = [1.0, 1.0]; // Global minimum
610/// let (value, grad) = gradient(rosenbrock, point);
611///
612/// assert_eq!(value, 0.0); // Minimum value is 0
613/// assert_eq!(grad[0], 0.0); // Gradient is zero at minimum
614/// assert_eq!(grad[1], 0.0);
615/// ```
616pub fn gradient<T, F, const N: usize>(f: F, point: [T; N]) -> (T, [T; N])
617where
618 T: Float,
619 F: Fn([MultiDual<T, N>; N]) -> MultiDual<T, N>,
620{
621 // Seed input variables: each gets its value from point with the
622 // appropriate unit vector for derivatives
623 let vars = std::array::from_fn(|i| MultiDual::variable(point[i], i));
624
625 // Single forward pass through the computation
626 let result = f(vars);
627
628 // Return both the function value and the gradient
629 (result.value, result.derivs)
630}
631
632/// Division: `(a + ∇a) / (b + ∇b) = (a + ∇a) * (1/(b + ∇b))`
633///
634/// Implements division via reciprocal (composition of operations).
635///
636/// # Examples
637///
638/// ```
639/// use autodiff::MultiDual;
640///
641/// // f(x, y) = x / y at (6, 2)
642/// let x = MultiDual::<f64, 2>::variable(6.0, 0);
643/// let y = MultiDual::<f64, 2>::variable(2.0, 1);
644/// let quotient = x / y;
645///
646/// assert_eq!(quotient.value, 3.0); // 6/2
647/// assert_eq!(quotient.derivs[0], 0.5); // ∂(x/y)/∂x = 1/y = 0.5
648/// assert_eq!(quotient.derivs[1], -1.5); // ∂(x/y)/∂y = -x/y² = -1.5
649/// ```
650#[allow(clippy::suspicious_arithmetic_impl)]
651impl<T, const N: usize> Div for MultiDual<T, N>
652where
653 T: Float,
654{
655 type Output = Self;
656
657 fn div(self, rhs: Self) -> Self {
658 self * rhs.recip()
659 }
660}
661
#[cfg(test)]
mod tests {
    // Unit tests cover: constructors, the four arithmetic operators,
    // negation/reciprocal, the transcendental methods, and the
    // high-level `gradient` driver (including multi-variable and
    // composed/chain-rule cases).
    use super::*;

    #[test]
    fn constant_has_zero_derivatives() {
        let c = MultiDual::<f64, 3>::constant(42.0);
        assert_eq!(c.value, 42.0);
        assert_eq!(c.derivs, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn variable_sets_correct_derivative() {
        // Each variable i carries the unit derivative vector eᵢ.
        let x = MultiDual::<f64, 3>::variable(3.0, 0);
        assert_eq!(x.value, 3.0);
        assert_eq!(x.derivs, [1.0, 0.0, 0.0]);

        let y = MultiDual::<f64, 3>::variable(4.0, 1);
        assert_eq!(y.value, 4.0);
        assert_eq!(y.derivs, [0.0, 1.0, 0.0]);

        let z = MultiDual::<f64, 3>::variable(5.0, 2);
        assert_eq!(z.value, 5.0);
        assert_eq!(z.derivs, [0.0, 0.0, 1.0]);
    }

    #[test]
    fn addition_works() {
        // Values and derivatives add componentwise.
        let x = MultiDual::new(3.0, [1.0, 0.0]);
        let y = MultiDual::new(4.0, [0.0, 1.0]);
        let sum = x + y;

        assert_eq!(sum.value, 7.0);
        assert_eq!(sum.derivs, [1.0, 1.0]);
    }

    #[test]
    fn subtraction_works() {
        // Values and derivatives subtract componentwise.
        let x = MultiDual::new(7.0, [2.0, 3.0]);
        let y = MultiDual::new(4.0, [1.0, 2.0]);
        let diff = x - y;

        assert_eq!(diff.value, 3.0);
        assert_eq!(diff.derivs, [1.0, 1.0]);
    }

    #[test]
    fn negation_works() {
        // Negation flips the value and every derivative component.
        let x = MultiDual::new(3.0, [1.0, 2.0, 3.0]);
        let neg_x = -x;

        assert_eq!(neg_x.value, -3.0);
        assert_eq!(neg_x.derivs, [-1.0, -2.0, -3.0]);
    }

    #[test]
    fn multiplication_implements_product_rule() {
        // f(x, y) = x * y at (3, 4)
        let x = MultiDual::<f64, 2>::variable(3.0, 0);
        let y = MultiDual::<f64, 2>::variable(4.0, 1);
        let product = x * y;

        assert_eq!(product.value, 12.0); // 3 * 4
        assert_eq!(product.derivs[0], 4.0); // ∂(xy)/∂x = y = 4
        assert_eq!(product.derivs[1], 3.0); // ∂(xy)/∂y = x = 3
    }

    #[test]
    fn recip_implements_inverse_rule() {
        // f(x, y) = 1/x at (2, 3)
        let x = MultiDual::<f64, 2>::variable(2.0, 0);
        let recip_x = x.recip();

        assert_eq!(recip_x.value, 0.5); // 1/2
        assert!((recip_x.derivs[0] + 0.25).abs() < 1e-10); // -1/4
        assert_eq!(recip_x.derivs[1], 0.0); // ∂(1/x)/∂y = 0
    }

    #[test]
    fn division_quotient_rule() {
        // f(x, y) = x / y at (6, 2)
        let x = MultiDual::<f64, 2>::variable(6.0, 0);
        let y = MultiDual::<f64, 2>::variable(2.0, 1);
        let quotient = x / y;

        assert_eq!(quotient.value, 3.0); // 6/2
        assert_eq!(quotient.derivs[0], 0.5); // ∂(x/y)/∂x = 1/y = 0.5
        assert_eq!(quotient.derivs[1], -1.5); // ∂(x/y)/∂y = -x/y² = -1.5
    }

    #[test]
    fn polynomial_gradient() {
        // f(x, y) = x² + 2xy + y² at (3, 4)
        // ∂f/∂x = 2x + 2y = 14, ∂f/∂y = 2x + 2y = 14
        let x = MultiDual::<f64, 2>::variable(3.0, 0);
        let y = MultiDual::<f64, 2>::variable(4.0, 1);
        let two = MultiDual::<f64, 2>::constant(2.0);

        let f = x * x + two * x * y + y * y;

        assert_eq!(f.value, 49.0); // 9 + 24 + 16
        assert_eq!(f.derivs[0], 14.0); // 2x + 2y
        assert_eq!(f.derivs[1], 14.0); // 2x + 2y
    }

    #[test]
    fn three_variable_sum_of_squares() {
        // f(x, y, z) = x² + y² + z² at (1, 2, 3)
        let x = MultiDual::<f64, 3>::variable(1.0, 0);
        let y = MultiDual::<f64, 3>::variable(2.0, 1);
        let z = MultiDual::<f64, 3>::variable(3.0, 2);

        let f = x * x + y * y + z * z;

        assert_eq!(f.value, 14.0); // 1 + 4 + 9
        assert_eq!(f.derivs[0], 2.0); // ∂f/∂x = 2x = 2
        assert_eq!(f.derivs[1], 4.0); // ∂f/∂y = 2y = 4
        assert_eq!(f.derivs[2], 6.0); // ∂f/∂z = 2z = 6
    }

    #[test]
    fn exp_gradient() {
        // f(x, y) = exp(x) at (0, 1)
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let f = x.exp();

        assert_eq!(f.value, 1.0); // exp(0) = 1
        assert_eq!(f.derivs[0], 1.0); // ∂exp(x)/∂x at x=0 is exp(0) = 1
        assert_eq!(f.derivs[1], 0.0); // ∂exp(x)/∂y = 0

        // f(x, y) = exp(x + y) at (1, 2) — exp composed with addition
        let x = MultiDual::<f64, 2>::variable(1.0, 0);
        let y = MultiDual::<f64, 2>::variable(2.0, 1);
        let f = (x + y).exp();

        let exp_3 = 3.0_f64.exp();
        assert!((f.value - exp_3).abs() < 1e-10); // exp(3)
        assert!((f.derivs[0] - exp_3).abs() < 1e-10); // ∂exp(x+y)/∂x = exp(x+y)
        assert!((f.derivs[1] - exp_3).abs() < 1e-10); // ∂exp(x+y)/∂y = exp(x+y)
    }

    #[test]
    fn ln_gradient() {
        // f(x, y) = ln(x) at (1, 2)
        let x = MultiDual::<f64, 2>::variable(1.0, 0);
        let f = x.ln();

        assert!((f.value - 0.0).abs() < 1e-12); // ln(1) = 0
        assert!((f.derivs[0] - 1.0).abs() < 1e-12); // ∂ln(x)/∂x at x=1 is 1/1 = 1
        assert_eq!(f.derivs[1], 0.0); // ∂ln(x)/∂y = 0

        // f(x, y) = ln(x * y) at (2, 3) — ln composed with product rule
        let x = MultiDual::<f64, 2>::variable(2.0, 0);
        let y = MultiDual::<f64, 2>::variable(3.0, 1);
        let f = (x * y).ln();

        assert!((f.value - 6.0_f64.ln()).abs() < 1e-10); // ln(6)
        assert!((f.derivs[0] - 0.5).abs() < 1e-10); // ∂ln(xy)/∂x = 1/x = 0.5
        assert!((f.derivs[1] - 1.0 / 3.0).abs() < 1e-10); // ∂ln(xy)/∂y = 1/y = 1/3
    }

    #[test]
    fn sin_gradient() {
        // f(x, y) = sin(x) at (0, 1)
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let f = x.sin();

        assert!((f.value - 0.0).abs() < 1e-12); // sin(0) = 0
        assert!((f.derivs[0] - 1.0).abs() < 1e-12); // ∂sin(x)/∂x at x=0 is cos(0) = 1
        assert_eq!(f.derivs[1], 0.0); // ∂sin(x)/∂y = 0
    }

    #[test]
    fn cos_gradient() {
        // f(x, y) = cos(x) at (0, 1)
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let f = x.cos();

        assert!((f.value - 1.0).abs() < 1e-12); // cos(0) = 1
        assert!((f.derivs[0] - 0.0).abs() < 1e-12); // ∂cos(x)/∂x at x=0 is -sin(0) = 0
        assert_eq!(f.derivs[1], 0.0); // ∂cos(x)/∂y = 0
    }

    #[test]
    fn sqrt_gradient() {
        // f(x, y) = √x at (4, 1)
        let x = MultiDual::<f64, 2>::variable(4.0, 0);
        let f = x.sqrt();

        assert_eq!(f.value, 2.0); // √4 = 2
        assert_eq!(f.derivs[0], 0.25); // ∂√x/∂x at x=4 is 1/(2·2) = 0.25
        assert_eq!(f.derivs[1], 0.0); // ∂√x/∂y = 0
    }

    #[test]
    fn mixed_transcendental_gradient() {
        // f(x, y) = sin(x) * exp(y) at (0, 0)
        // ∂f/∂x = cos(x) * exp(y) at (0, 0) = 1 * 1 = 1
        // ∂f/∂y = sin(x) * exp(y) at (0, 0) = 0 * 1 = 0
        let x = MultiDual::<f64, 2>::variable(0.0, 0);
        let y = MultiDual::<f64, 2>::variable(0.0, 1);
        let f = x.sin() * y.exp();

        assert!((f.value - 0.0).abs() < 1e-12); // sin(0) * exp(0) = 0 * 1 = 0
        assert!((f.derivs[0] - 1.0).abs() < 1e-12); // cos(0) * exp(0) = 1
        assert!((f.derivs[1] - 0.0).abs() < 1e-12); // sin(0) * exp(0) = 0
    }

    #[test]
    fn chain_rule_with_transcendentals() {
        // f(x, y) = exp(x²) at (1, 0)
        // ∂f/∂x = 2x * exp(x²) at x=1 is 2 * e
        let x = MultiDual::<f64, 2>::variable(1.0, 0);
        let f = (x * x).exp();

        let e = 1.0_f64.exp();
        assert!((f.value - e).abs() < 1e-10); // exp(1)
        assert!((f.derivs[0] - 2.0 * e).abs() < 1e-10); // 2 * exp(1)
        assert_eq!(f.derivs[1], 0.0);

        // f(x, y, z) = √(x² + y² + z²) at (3, 4, 0)
        // This is the Euclidean norm
        // ∂f/∂x = x / √(x² + y² + z²) = 3/5
        // ∂f/∂y = y / √(x² + y² + z²) = 4/5
        // ∂f/∂z = z / √(x² + y² + z²) = 0/5 = 0
        let x = MultiDual::<f64, 3>::variable(3.0, 0);
        let y = MultiDual::<f64, 3>::variable(4.0, 1);
        let z = MultiDual::<f64, 3>::variable(0.0, 2);
        let f = (x * x + y * y + z * z).sqrt();

        assert_eq!(f.value, 5.0); // √(9 + 16 + 0) = 5
        assert_eq!(f.derivs[0], 0.6); // 3/5
        assert_eq!(f.derivs[1], 0.8); // 4/5
        assert_eq!(f.derivs[2], 0.0); // 0/5
    }

    #[test]
    fn gradient_quadratic_2d() {
        use crate::gradient;

        // f(x, y) = x² + 2xy + y² at (3, 4)
        let f = |vars: [MultiDual<f64, 2>; 2]| {
            let [x, y] = vars;
            let two = MultiDual::constant(2.0);
            x * x + two * x * y + y * y
        };

        let point = [3.0, 4.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 49.0); // f(3, 4) = 9 + 24 + 16
        assert_eq!(grad[0], 14.0); // ∂f/∂x = 2x + 2y = 14
        assert_eq!(grad[1], 14.0); // ∂f/∂y = 2x + 2y = 14
    }

    #[test]
    fn gradient_with_transcendentals() {
        use crate::gradient;

        // f(x, y, z) = x² + y·exp(z) at (1, 2, 0)
        let f = |vars: [MultiDual<f64, 3>; 3]| {
            let [x, y, z] = vars;
            x * x + y * z.exp()
        };

        let point = [1.0, 2.0, 0.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 3.0); // 1 + 2·1 = 3
        assert_eq!(grad[0], 2.0); // ∂f/∂x = 2x = 2
        assert_eq!(grad[1], 1.0); // ∂f/∂y = exp(z) = 1
        assert_eq!(grad[2], 2.0); // ∂f/∂z = y·exp(z) = 2
    }

    #[test]
    fn gradient_rosenbrock() {
        use crate::gradient;

        // Rosenbrock: f(x, y) = (1-x)² + 100(y-x²)²
        let rosenbrock = |vars: [MultiDual<f64, 2>; 2]| {
            let [x, y] = vars;
            let one = MultiDual::constant(1.0);
            let hundred = MultiDual::constant(100.0);

            let term1 = one - x;
            let term2 = y - x * x;
            term1 * term1 + hundred * term2 * term2
        };

        // Test at global minimum (1, 1)
        let point = [1.0, 1.0];
        let (value, grad) = gradient(rosenbrock, point);

        assert_eq!(value, 0.0); // Minimum value is 0
        assert_eq!(grad[0], 0.0); // Gradient is zero at minimum
        assert_eq!(grad[1], 0.0);

        // Test at another point (0, 0)
        let point = [0.0, 0.0];
        let (value, grad) = gradient(rosenbrock, point);

        assert_eq!(value, 1.0); // f(0, 0) = 1 + 0 = 1
        assert_eq!(grad[0], -2.0); // ∂f/∂x at (0,0) = -2(1-x) - 400x(y-x²) = -2
        assert_eq!(grad[1], 0.0); // ∂f/∂y at (0,0) = 200(y-x²) = 0
    }

    #[test]
    fn gradient_euclidean_norm() {
        use crate::gradient;

        // f(x, y, z) = √(x² + y² + z²) at (3, 4, 0)
        let euclidean_norm = |vars: [MultiDual<f64, 3>; 3]| {
            let [x, y, z] = vars;
            (x * x + y * y + z * z).sqrt()
        };

        let point = [3.0, 4.0, 0.0];
        let (value, grad) = gradient(euclidean_norm, point);

        assert_eq!(value, 5.0); // √(9 + 16 + 0) = 5
        assert_eq!(grad[0], 0.6); // x/‖x‖ = 3/5
        assert_eq!(grad[1], 0.8); // y/‖x‖ = 4/5
        assert_eq!(grad[2], 0.0); // z/‖x‖ = 0/5
    }

    #[test]
    fn gradient_single_variable() {
        use crate::gradient;

        // Degenerate N=1 case: gradient reduces to the ordinary derivative.
        // f(x) = x³ at x=2
        // f'(x) = 3x² = 12
        let f = |vars: [MultiDual<f64, 1>; 1]| {
            let [x] = vars;
            x * x * x
        };

        let point = [2.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 8.0); // 2³ = 8
        assert_eq!(grad[0], 12.0); // 3 * 2² = 12
    }

    #[test]
    fn gradient_high_dimensional() {
        use crate::gradient;

        // f(x₁, x₂, x₃, x₄, x₅) = Σᵢ xᵢ² at (1, 2, 3, 4, 5)
        // ∂f/∂xᵢ = 2xᵢ
        let f = |vars: [MultiDual<f64, 5>; 5]| {
            let [x1, x2, x3, x4, x5] = vars;
            x1 * x1 + x2 * x2 + x3 * x3 + x4 * x4 + x5 * x5
        };

        let point = [1.0, 2.0, 3.0, 4.0, 5.0];
        let (value, grad) = gradient(f, point);

        assert_eq!(value, 55.0); // 1 + 4 + 9 + 16 + 25
        assert_eq!(grad[0], 2.0); // 2 * 1
        assert_eq!(grad[1], 4.0); // 2 * 2
        assert_eq!(grad[2], 6.0); // 2 * 3
        assert_eq!(grad[3], 8.0); // 2 * 4
        assert_eq!(grad[4], 10.0); // 2 * 5
    }

    #[test]
    fn gradient_mixed_operations() {
        use crate::gradient;

        // f(x, y) = sin(x) * exp(y) + ln(x + y) at (1, 0)
        let f = |vars: [MultiDual<f64, 2>; 2]| {
            let [x, y] = vars;
            x.sin() * y.exp() + (x + y).ln()
        };

        let point = [1.0, 0.0];
        let (value, grad) = gradient(f, point);

        let sin_1 = 1.0_f64.sin();
        let cos_1 = 1.0_f64.cos();

        assert!((value - (sin_1 + 0.0)).abs() < 1e-10); // sin(1)*1 + ln(1) = sin(1)
        // ∂f/∂x = cos(x)*exp(y) + 1/(x+y) at (1,0) = cos(1) + 1
        assert!((grad[0] - (cos_1 + 1.0)).abs() < 1e-10);
        // ∂f/∂y = sin(x)*exp(y) + 1/(x+y) at (1,0) = sin(1) + 1
        assert!((grad[1] - (sin_1 + 1.0)).abs() < 1e-10);
    }
}
1049}