autodiff/
tape.rs

1//! Tape-based reverse-mode automatic differentiation.
2//!
3//! Reverse-mode AD computes gradients by recording operations during
4//! a forward pass, then propagating gradients backward through the
5//! recorded tape. This is efficient for functions f: ℝⁿ → ℝ where n
6//! is large (e.g., neural networks with millions of parameters).
7//!
8//! # How It Works
9//!
10//! 1. Create a [`Tape`] to record operations
11//! 2. Create variables with [`Tape::var`]
12//! 3. Compute using arithmetic operations (recorded on the tape)
13//! 4. Call [`Var::backward`] to propagate gradients
14//! 5. Query gradients with [`Gradients::get`]
15//!
16//! # Example
17//!
18//! ```
19//! use autodiff::Tape;
20//!
21//! let tape = Tape::new();
22//! let x = tape.var(3.0);
23//! let y = x.clone() * x.clone();  // y = x²
24//!
25//! let grads = y.backward();
26//! assert_eq!(y.value(), 9.0);
27//! assert_eq!(grads.get(&x), 6.0);  // dy/dx = 2x = 6
28//! ```
29//!
30//! # Functional API
31//!
32//! For simple cases, use [`reverse_diff`] or [`reverse_gradient`]:
33//!
34//! ```
35//! use autodiff::reverse_diff;
36//!
37//! let (val, deriv) = reverse_diff(|x| x.clone() * x, 3.0);
38//! assert_eq!(val, 9.0);
39//! assert_eq!(deriv, 6.0);
40//! ```
41//!
42//! # When to Use Reverse-Mode
43//!
44//! - **Reverse-mode** (this module): Efficient when outputs << inputs
45//!   (e.g., loss function with many parameters)
46//! - **Forward-mode** ([`crate::dual`]): Efficient when inputs <<
47//!   outputs (e.g., sensitivity of many outputs to one parameter)
48
49use num_traits::Float;
50use std::cell::RefCell;
51use std::ops::{Add, Div, Mul, Neg, Sub};
52use std::rc::Rc;
53
/// The computation tape that records operations for reverse-mode AD.
///
/// Create a tape with [`Tape::new`], then create variables on it with
/// [`Tape::var`].
///
/// # Examples
///
/// ```
/// use autodiff::Tape;
///
/// let tape = Tape::new();
/// let x = tape.var(3.0);
/// let y = x.clone() * x.clone();  // y = x²
///
/// let grads = y.backward();
/// assert_eq!(y.value(), 9.0);
/// assert_eq!(grads.get(&x), 6.0);  // dy/dx = 2x = 6
/// ```
#[derive(Clone)]
pub struct Tape<T> {
    // Shared, interior-mutable storage: cloning a `Tape` (or a `Var`
    // holding one) is a cheap Rc bump, and all clones record onto the
    // same underlying tape.
    inner: Rc<RefCell<TapeInner<T>>>,
}
76
77impl<T> Tape<T> {
78    /// Creates a new empty tape.
79    pub fn new() -> Self {
80        Self {
81            inner: Rc::new(RefCell::new(TapeInner::new())),
82        }
83    }
84}
85
86impl<T: Float> Tape<T> {
87    /// Creates a differentiable variable on this tape.
88    pub fn var(&self, value: T) -> Var<T> {
89        let idx = self.inner.borrow_mut().push_value(value);
90        Var {
91            tape: self.clone(),
92            idx,
93        }
94    }
95
96    fn constant(&self, value: T) -> Var<T> {
97        let idx = self.inner.borrow_mut().push_value(value);
98        Var {
99            tape: self.clone(),
100            idx,
101        }
102    }
103}
104
/// A differentiable variable for reverse-mode automatic
/// differentiation.
///
/// # Examples
///
/// Using [`reverse_diff`] for reusable functions:
///
/// ```
/// use autodiff::{reverse_diff, Var};
///
/// // Define f(x) = (x+1)(x-1) = x² - 1
/// let f = |x: Var<f64>| (x.clone() + 1.0) * (x - 1.0);
///
/// let (val, deriv) = reverse_diff(f, 3.0);
/// assert_eq!(val, 8.0);    // f(3) = 8
/// assert_eq!(deriv, 6.0);  // f'(3) = 2x = 6
/// ```
///
/// ```
/// use autodiff::{reverse_diff, Var};
///
/// // Define f(x) = x² + x
/// let f = |x: Var<f64>| x.clone() * x.clone() + x;
///
/// let (val, deriv) = reverse_diff(f, 3.0);
/// assert_eq!(val, 12.0);   // f(3) = 12
/// assert_eq!(deriv, 7.0);  // f'(3) = 2x + 1 = 7
/// ```
#[derive(Clone)]
pub struct Var<T> {
    // Handle to the tape this variable was created on; clones share it.
    tape: Tape<T>,
    // Index of this variable's slot in the tape's `vals`/`grads` vectors.
    idx: usize,
}
138
139impl<T: Copy> Var<T> {
140    /// Returns the value of this variable.
141    pub fn value(&self) -> T {
142        self.tape.inner.borrow().vals[self.idx]
143    }
144}
145
146impl<T: Float> Var<T> {
147    /// Computes gradients by backpropagation from this variable.
148    ///
149    /// Returns a [`Gradients`] object that can be queried for the
150    /// gradient with respect to any variable.
151    ///
152    /// # Examples
153    ///
154    /// ```
155    /// use autodiff::Tape;
156    ///
157    /// let tape = Tape::new();
158    /// let x = tape.var(3.0);
159    /// let y = x.clone() * x.clone();  // y = x²
160    ///
161    /// let grads = y.backward();
162    /// assert_eq!(grads.get(&x), 6.0);  // dy/dx = 2x = 6
163    /// ```
164    pub fn backward(&self) -> Gradients<T> {
165        self.tape.inner.borrow_mut().backward_from(self.idx);
166        Gradients {
167            tape: self.tape.clone(),
168        }
169    }
170
171    /// Computes the reciprocal `1/self`.
172    pub fn recip(self) -> Self {
173        unary(self, OpKind::Recip, |a| T::one() / a)
174    }
175
176    /// Computes `e^self`.
177    pub fn exp(self) -> Self {
178        unary(self, OpKind::Exp, |a| a.exp())
179    }
180
181    /// Computes `sin(self)`.
182    pub fn sin(self) -> Self {
183        unary(self, OpKind::Sin, |a| a.sin())
184    }
185
186    /// Computes `cos(self)`.
187    pub fn cos(self) -> Self {
188        unary(self, OpKind::Cos, |a| a.cos())
189    }
190
191    /// Computes `ln(self)`.
192    pub fn ln(self) -> Self {
193        unary(self, OpKind::Ln, |a| a.ln())
194    }
195
196    /// Computes `sqrt(self)`.
197    pub fn sqrt(self) -> Self {
198        unary(self, OpKind::Sqrt, |a| a.sqrt())
199    }
200}
201
/// The gradients computed by [`Var::backward`].
///
/// Query individual gradients using [`get`](Gradients::get).
///
/// # Example
///
/// ```
/// use autodiff::Tape;
///
/// let tape = Tape::new();
/// let x = tape.var(3.0);
/// let y = tape.var(4.0);
/// let z = x.clone() * y.clone();  // z = x * y
///
/// let grads = z.backward();
/// assert_eq!(grads.get(&x), 4.0);  // dz/dx = y
/// assert_eq!(grads.get(&y), 3.0);  // dz/dy = x
/// ```
pub struct Gradients<T> {
    // Gradients live in the tape itself; this is just a handle to read
    // them. NOTE(review): a later `backward` on the same tape rewrites
    // the gradient slots, which this handle will then observe.
    tape: Tape<T>,
}
223
224impl<T: Copy> Gradients<T> {
225    /// Returns the gradient with respect to the given variable.
226    pub fn get(&self, var: &Var<T>) -> T {
227        self.tape.inner.borrow().grads[var.idx]
228    }
229}
230
// Backing storage for a tape: parallel value/gradient arrays plus the
// recorded operation list.
struct TapeInner<T> {
    // Value of each tape entry, indexed by `Var::idx`.
    vals: Vec<T>,
    // Gradient (adjoint) slot for each entry; filled by `backward_from`.
    grads: Vec<T>,
    // Operations in the order they were recorded during the forward pass.
    ops: Vec<Op>,
}

impl<T> TapeInner<T> {
    // Creates an empty tape with no values, gradients, or operations.
    fn new() -> Self {
        Self {
            vals: Vec::new(),
            grads: Vec::new(),
            ops: Vec::new(),
        }
    }

    // Appends a recorded operation to the tape.
    fn push_op(&mut self, op: Op) {
        self.ops.push(op)
    }
}
250
impl<T: Float> TapeInner<T> {
    // Appends a value with a zero-initialized gradient slot and returns
    // its tape index.
    fn push_value(&mut self, v: T) -> usize {
        let idx = self.vals.len();
        self.vals.push(v);
        self.grads.push(T::zero());
        idx
    }

    // Reverse sweep: propagates gradients from the output at index
    // `out` back through every recorded operation.
    //
    // Seeds d(out)/d(out) = 1, then walks ops newest-to-oldest,
    // accumulating each op's chain-rule contribution into its operands'
    // gradient slots. Ops not on a path to `out` have a zero output
    // gradient and contribute nothing. All gradients are zeroed first,
    // so repeated calls do not accumulate across invocations.
    fn backward_from(&mut self, out: usize) {
        for g in &mut self.grads {
            *g = T::zero();
        }
        self.grads[out] = T::one();

        for op in self.ops.iter().rev() {
            // Gradient of the final output w.r.t. this op's output.
            let go = self.grads[op.out];

            match op.kind {
                // out = a + b: ∂/∂a = 1, ∂/∂b = 1
                OpKind::Add => {
                    self.grads[op.a] = self.grads[op.a] + go;
                    self.grads[op.b] = self.grads[op.b] + go;
                }
                // out = a - b: ∂/∂a = 1, ∂/∂b = -1
                OpKind::Sub => {
                    self.grads[op.a] = self.grads[op.a] + go;
                    self.grads[op.b] = self.grads[op.b] - go;
                }
                // out = a * b: ∂/∂a = b, ∂/∂b = a
                OpKind::Mul => {
                    let a = self.vals[op.a];
                    let b = self.vals[op.b];
                    self.grads[op.a] = self.grads[op.a] + go * b;
                    self.grads[op.b] = self.grads[op.b] + go * a;
                }
                // out = a / b: ∂/∂a = 1/b, ∂/∂b = -a/b²
                OpKind::Div => {
                    let a = self.vals[op.a];
                    let b = self.vals[op.b];
                    self.grads[op.a] = self.grads[op.a] + go / b;
                    self.grads[op.b] = self.grads[op.b] + go * (-(a) / (b * b));
                }
                // out = -a: ∂/∂a = -1
                OpKind::Neg => {
                    self.grads[op.a] = self.grads[op.a] - go;
                }
                // out = 1/a: ∂/∂a = -1/a²
                OpKind::Recip => {
                    let a = self.vals[op.a];
                    self.grads[op.a] = self.grads[op.a] + go * (-(T::one()) / (a * a));
                }
                // out = e^a: ∂/∂a = e^a — reuse the stored output value.
                OpKind::Exp => {
                    let out = self.vals[op.out];
                    self.grads[op.a] = self.grads[op.a] + go * out;
                }
                // out = sin(a): ∂/∂a = cos(a)
                OpKind::Sin => {
                    let a = self.vals[op.a];
                    self.grads[op.a] = self.grads[op.a] + go * a.cos();
                }
                // out = cos(a): ∂/∂a = -sin(a)
                OpKind::Cos => {
                    let a = self.vals[op.a];
                    self.grads[op.a] = self.grads[op.a] - go * a.sin();
                }
                // out = ln(a): ∂/∂a = 1/a
                OpKind::Ln => {
                    let a = self.vals[op.a];
                    self.grads[op.a] = self.grads[op.a] + go / a;
                }
                // out = √a: ∂/∂a = 1/(2√a) — reuse the stored output.
                OpKind::Sqrt => {
                    let out = self.vals[op.out];
                    self.grads[op.a] = self.grads[op.a] + go / (out + out);
                }
            }
        }
    }
}
320
// A single operation recorded on the tape during the forward pass.
struct Op {
    // Which operation this is; selects the reverse rule in
    // `TapeInner::backward_from`.
    kind: OpKind,
    // Tape index of the operation's result.
    out: usize,
    // Tape index of the first (or only) operand.
    a: usize,
    // Tape index of the second operand; set to 0 and ignored for
    // unary operations.
    b: usize,
}

// The supported operations. Binary: Add/Sub/Mul/Div. Unary: the rest.
enum OpKind {
    Add,
    Sub,
    Mul,
    Div,
    Neg,
    Recip,
    Exp,
    Sin,
    Cos,
    Ln,
    Sqrt,
}
341
342fn unary<T: Float>(x: Var<T>, kind: OpKind, f: impl FnOnce(T) -> T) -> Var<T> {
343    let tape = x.tape.clone();
344    let outv = {
345        let t = tape.inner.borrow();
346        f(t.vals[x.idx])
347    };
348
349    let out = {
350        let mut t = tape.inner.borrow_mut();
351        let out = t.push_value(outv);
352        t.push_op(Op {
353            kind,
354            out,
355            a: x.idx,
356            b: 0,
357        });
358        out
359    };
360
361    Var { tape, idx: out }
362}
363
364fn binary<T: Float>(lhs: Var<T>, rhs: Var<T>, kind: OpKind, f: impl FnOnce(T, T) -> T) -> Var<T> {
365    assert!(
366        Rc::ptr_eq(&lhs.tape.inner, &rhs.tape.inner),
367        "Vars must share a tape"
368    );
369    let tape = lhs.tape.clone();
370
371    let (a, b, outv) = {
372        let t = tape.inner.borrow();
373        let av = t.vals[lhs.idx];
374        let bv = t.vals[rhs.idx];
375        (lhs.idx, rhs.idx, f(av, bv))
376    };
377
378    let out = {
379        let mut t = tape.inner.borrow_mut();
380        let out = t.push_value(outv);
381        t.push_op(Op { kind, out, a, b });
382        out
383    };
384
385    Var { tape, idx: out }
386}
387
// === Var ∘ Var operators ===
//
// Each operator records itself on the shared tape via `binary`/`unary`;
// the matching derivative rule lives in `TapeInner::backward_from`.

impl<T: Float> Add for Var<T> {
    type Output = Var<T>;
    fn add(self, rhs: Self) -> Self::Output {
        binary(self, rhs, OpKind::Add, |a, b| a + b)
    }
}

impl<T: Float> Sub for Var<T> {
    type Output = Var<T>;
    fn sub(self, rhs: Self) -> Self::Output {
        binary(self, rhs, OpKind::Sub, |a, b| a - b)
    }
}

impl<T: Float> Mul for Var<T> {
    type Output = Var<T>;
    fn mul(self, rhs: Self) -> Self::Output {
        binary(self, rhs, OpKind::Mul, |a, b| a * b)
    }
}

impl<T: Float> Div for Var<T> {
    type Output = Var<T>;
    fn div(self, rhs: Self) -> Self::Output {
        binary(self, rhs, OpKind::Div, |a, b| a / b)
    }
}

impl<T: Float> Neg for Var<T> {
    type Output = Var<T>;
    fn neg(self) -> Self::Output {
        unary(self, OpKind::Neg, |a| -a)
    }
}
422
// === Var ∘ scalar operators ===
//
// The scalar is pushed onto the tape as a constant and the Var ∘ Var
// impl above does the rest. Only `Var op scalar` is provided; the
// mirrored `scalar op Var` form is not implemented here.

impl<T: Float> Add<T> for Var<T> {
    type Output = Var<T>;
    fn add(self, c: T) -> Self::Output {
        let cvar = self.tape.constant(c);
        self + cvar
    }
}

impl<T: Float> Sub<T> for Var<T> {
    type Output = Var<T>;
    fn sub(self, c: T) -> Self::Output {
        let cvar = self.tape.constant(c);
        self - cvar
    }
}

impl<T: Float> Mul<T> for Var<T> {
    type Output = Var<T>;
    fn mul(self, c: T) -> Self::Output {
        let cvar = self.tape.constant(c);
        self * cvar
    }
}

impl<T: Float> Div<T> for Var<T> {
    type Output = Var<T>;
    fn div(self, c: T) -> Self::Output {
        let cvar = self.tape.constant(c);
        self / cvar
    }
}
454
455/// Computes the value and derivative of a function using reverse-mode
456/// AD.
457///
458/// This is the reverse-mode equivalent of forward-mode
459/// differentiation. The function `f` is evaluated at `x`, and both
460/// the value and derivative are returned.
461///
462/// # Examples
463///
464/// ```
465/// use autodiff::reverse_diff;
466///
467/// // f(x) = x² at x = 3
468/// let (val, deriv) = reverse_diff(|x| x.clone() * x, 3.0);
469/// assert_eq!(val, 9.0);    // f(3) = 9
470/// assert_eq!(deriv, 6.0);  // f'(3) = 2x = 6
471/// ```
472///
473/// Reuse the same function at different points:
474///
475/// ```
476/// use autodiff::{reverse_diff, Var};
477///
478/// let f = |x: Var<f64>| x.clone() * x.clone() - x;
479///
480/// let (v1, d1) = reverse_diff(f, 2.0);
481/// let (v2, d2) = reverse_diff(f, 5.0);
482///
483/// assert_eq!((v1, d1), (2.0, 3.0));   // f(2) = 2, f'(2) = 3
484/// assert_eq!((v2, d2), (20.0, 9.0));  // f(5) = 20, f'(5) = 9
485/// ```
486pub fn reverse_diff<T, F>(f: F, x: T) -> (T, T)
487where
488    T: Float,
489    F: FnOnce(Var<T>) -> Var<T>,
490{
491    let tape = Tape::new();
492    let var = tape.var(x);
493    let var_clone = var.clone();
494    let result = f(var);
495    let grads = result.backward();
496    (result.value(), grads.get(&var_clone))
497}
498
499/// Computes the value and gradient of a multivariable function using
500/// reverse-mode AD.
501///
502/// This is the reverse-mode equivalent of
503/// [`gradient`](crate::gradient) for functions f: ℝⁿ → ℝ.
504///
505/// # Examples
506///
507/// ```
508/// use autodiff::{reverse_gradient, Var};
509///
510/// // f(x, y) = x² + x*y at (3, 4)
511/// let f = |[x, y]: [Var<f64>; 2]| x.clone() * x.clone() + x * y;
512///
513/// let (val, grad) = reverse_gradient(f, [3.0, 4.0]);
514/// assert_eq!(val, 21.0);       // f(3, 4) = 9 + 12 = 21
515/// assert_eq!(grad[0], 10.0);   // ∂f/∂x = 2x + y = 10
516/// assert_eq!(grad[1], 3.0);    // ∂f/∂y = x = 3
517/// ```
518pub fn reverse_gradient<T, F, const N: usize>(f: F, point: [T; N]) -> (T, [T; N])
519where
520    T: Float,
521    F: FnOnce([Var<T>; N]) -> Var<T>,
522{
523    let tape = Tape::new();
524    let vars: [Var<T>; N] = std::array::from_fn(|i| tape.var(point[i]));
525    let vars_clone = vars.clone();
526    let result = f(vars);
527    let grads = result.backward();
528    (
529        result.value(),
530        std::array::from_fn(|i| grads.get(&vars_clone[i])),
531    )
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    // Each test pins both the value and the derivative/gradient of an
    // expression, collectively exercising every `OpKind` plus the
    // `reverse_diff`/`reverse_gradient` wrappers and fan-out
    // (gradient accumulation when a variable is used more than once).

    // === Basic operations ===

    #[test]
    fn addition_works() {
        // f(x) = x + 5 at x=3
        let (val, deriv) = reverse_diff(|x| x + 5.0, 3.0);
        assert_eq!(val, 8.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn subtraction_works() {
        // f(x) = x - 2 at x=5
        let (val, deriv) = reverse_diff(|x| x - 2.0, 5.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn multiplication_works() {
        // f(x) = x * x = x² at x=3
        let (val, deriv) = reverse_diff(|x| x.clone() * x, 3.0);
        assert_eq!(val, 9.0);
        assert_eq!(deriv, 6.0); // d/dx(x²) = 2x = 6
    }

    #[test]
    fn division_works() {
        // f(x) = x / 2 at x=6
        let (val, deriv) = reverse_diff(|x| x / 2.0, 6.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 0.5); // d/dx(x/2) = 0.5
    }

    #[test]
    fn negation_works() {
        // f(x) = -x at x=3
        let (val, deriv) = reverse_diff(|x| -x, 3.0);
        assert_eq!(val, -3.0);
        assert_eq!(deriv, -1.0);
    }

    // === Var-Var operations ===

    #[test]
    fn var_var_addition() {
        // f(x) = x + x = 2x at x=3
        let (val, deriv) = reverse_diff(|x| x.clone() + x, 3.0);
        assert_eq!(val, 6.0);
        assert_eq!(deriv, 2.0);
    }

    #[test]
    fn var_var_subtraction() {
        // f(x) = x - x = 0 at x=3
        let (val, deriv) = reverse_diff(|x| x.clone() - x, 3.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 0.0);
    }

    #[test]
    fn var_var_division() {
        // f(x) = x / x = 1 at x=3
        // f'(x) = 0 (derivative of constant 1)
        let (val, deriv) = reverse_diff(|x| x.clone() / x, 3.0);
        assert_eq!(val, 1.0);
        assert!((deriv - 0.0).abs() < 1e-10);
    }

    // === Scalar operations (Var op T) ===

    #[test]
    fn var_add_scalar() {
        let (val, deriv) = reverse_diff(|x| x + 10.0, 5.0);
        assert_eq!(val, 15.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn var_sub_scalar() {
        let (val, deriv) = reverse_diff(|x| x - 3.0, 10.0);
        assert_eq!(val, 7.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn var_mul_scalar() {
        let (val, deriv) = reverse_diff(|x| x * 3.0, 4.0);
        assert_eq!(val, 12.0);
        assert_eq!(deriv, 3.0);
    }

    #[test]
    fn var_div_scalar() {
        let (val, deriv) = reverse_diff(|x| x / 4.0, 12.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 0.25);
    }

    // === Unary functions ===

    #[test]
    fn recip_works() {
        // f(x) = 1/x at x=2
        let (val, deriv) = reverse_diff(|x| x.recip(), 2.0);
        assert_eq!(val, 0.5);
        assert_eq!(deriv, -0.25); // d/dx(1/x) = -1/x² = -0.25
    }

    #[test]
    fn exp_works() {
        // f(x) = e^x at x=0
        let (val, deriv) = reverse_diff(|x| x.exp(), 0.0);
        assert_eq!(val, 1.0);
        assert_eq!(deriv, 1.0); // d/dx(e^x) = e^x = 1 at x=0
    }

    #[test]
    fn exp_at_one() {
        // f(x) = e^x at x=1
        let (val, deriv) = reverse_diff(|x| x.exp(), 1.0);
        let e = 1.0_f64.exp();
        assert!((val - e).abs() < 1e-10);
        assert!((deriv - e).abs() < 1e-10);
    }

    #[test]
    fn sin_works() {
        // f(x) = sin(x) at x=0
        let (val, deriv) = reverse_diff(|x| x.sin(), 0.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 1.0); // cos(0) = 1
    }

    #[test]
    fn cos_works() {
        // f(x) = cos(x) at x=0
        let (val, deriv) = reverse_diff(|x| x.cos(), 0.0);
        assert_eq!(val, 1.0);
        assert_eq!(deriv, 0.0); // -sin(0) = 0
    }

    #[test]
    fn ln_works() {
        // f(x) = ln(x) at x=1
        let (val, deriv) = reverse_diff(|x| x.ln(), 1.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 1.0); // 1/x = 1 at x=1
    }

    #[test]
    fn ln_at_e() {
        // f(x) = ln(x) at x=e
        let e = 1.0_f64.exp();
        let (val, deriv) = reverse_diff(|x| x.ln(), e);
        assert!((val - 1.0).abs() < 1e-10);
        assert!((deriv - 1.0 / e).abs() < 1e-10);
    }

    #[test]
    fn sqrt_works() {
        // f(x) = √x at x=4
        let (val, deriv) = reverse_diff(|x| x.sqrt(), 4.0);
        assert_eq!(val, 2.0);
        assert_eq!(deriv, 0.25); // 1/(2√x) = 1/4
    }

    // === Fan-out (variable used multiple times) ===

    #[test]
    fn fan_out_addition() {
        // f(x) = x + x + x = 3x at x=2
        let (val, deriv) = reverse_diff(|x| x.clone() + x.clone() + x, 2.0);
        assert_eq!(val, 6.0);
        assert_eq!(deriv, 3.0);
    }

    #[test]
    fn fan_out_mixed() {
        // f(x) = x² + x at x=3
        // f'(x) = 2x + 1 = 7
        let (val, deriv) = reverse_diff(|x| x.clone() * x.clone() + x, 3.0);
        assert_eq!(val, 12.0);
        assert_eq!(deriv, 7.0);
    }

    #[test]
    fn fan_out_product() {
        // f(x) = x * x * x = x³ at x=2
        // f'(x) = 3x² = 12
        let (val, deriv) = reverse_diff(|x| x.clone() * x.clone() * x, 2.0);
        assert_eq!(val, 8.0);
        assert_eq!(deriv, 12.0);
    }

    // === Complex expressions ===

    #[test]
    fn polynomial() {
        // f(x) = x³ - 2x + 1 at x=2
        // f'(x) = 3x² - 2 = 10
        let (val, deriv) = reverse_diff(
            |x| {
                let x2 = x.clone() * x.clone();
                let x3 = x2 * x.clone();
                x3 - x * 2.0 + 1.0
            },
            2.0,
        );
        assert_eq!(val, 5.0);
        assert_eq!(deriv, 10.0);
    }

    #[test]
    fn quotient_rule() {
        // f(x) = (x+1)/(x+2) at x=3
        // f(3) = 4/5 = 0.8
        // f'(x) = 1/(x+2)² = 1/25 = 0.04
        let (val, deriv) = reverse_diff(
            |x| {
                let num = x.clone() + 1.0;
                let den = x + 2.0;
                num / den
            },
            3.0,
        );
        assert_eq!(val, 0.8);
        assert!((deriv - 0.04).abs() < 1e-10);
    }

    #[test]
    fn chain_rule_product() {
        // f(x) = (x+1)(x-1) = x² - 1 at x=3
        // f'(x) = 2x = 6
        let (val, deriv) = reverse_diff(|x| (x.clone() + 1.0) * (x - 1.0), 3.0);
        assert_eq!(val, 8.0);
        assert_eq!(deriv, 6.0);
    }

    #[test]
    fn chain_rule_transcendental() {
        // f(x) = sin(2x) at x=0
        // f'(x) = 2*cos(2x) = 2 at x=0
        let (val, deriv) = reverse_diff(|x| (x * 2.0).sin(), 0.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 2.0);
    }

    #[test]
    fn exp_of_square() {
        // f(x) = e^(x²) at x=1
        // f'(x) = 2x * e^(x²) = 2e at x=1
        let (val, deriv) = reverse_diff(|x| (x.clone() * x).exp(), 1.0);
        let e = 1.0_f64.exp();
        assert!((val - e).abs() < 1e-10);
        assert!((deriv - 2.0 * e).abs() < 1e-10);
    }

    #[test]
    fn ln_of_square() {
        // f(x) = ln(x²) = 2*ln(x) at x=e
        // f'(x) = 2/x
        let e = 1.0_f64.exp();
        let (val, deriv) = reverse_diff(|x| (x.clone() * x).ln(), e);
        assert!((val - 2.0).abs() < 1e-10);
        assert!((deriv - 2.0 / e).abs() < 1e-10);
    }

    #[test]
    fn sqrt_of_sum() {
        // f(x) = √(x+5) at x=4
        // f'(x) = 1/(2√(x+5)) = 1/6
        let (val, deriv) = reverse_diff(|x| (x + 5.0).sqrt(), 4.0);
        assert_eq!(val, 3.0);
        assert!((deriv - 1.0 / 6.0).abs() < 1e-10);
    }

    // === reverse_gradient tests ===

    #[test]
    fn gradient_sum() {
        // f(x, y) = x + y at (3, 4)
        // ∂f/∂x = 1, ∂f/∂y = 1
        let (val, grad) = reverse_gradient(|[x, y]| x + y, [3.0, 4.0]);
        assert_eq!(val, 7.0);
        assert_eq!(grad, [1.0, 1.0]);
    }

    #[test]
    fn gradient_product() {
        // f(x, y) = x * y at (3, 4)
        // ∂f/∂x = y = 4, ∂f/∂y = x = 3
        let (val, grad) = reverse_gradient(|[x, y]| x * y, [3.0, 4.0]);
        assert_eq!(val, 12.0);
        assert_eq!(grad, [4.0, 3.0]);
    }

    #[test]
    fn gradient_mixed() {
        // f(x, y) = x² + x*y at (3, 4)
        // ∂f/∂x = 2x + y = 10, ∂f/∂y = x = 3
        let (val, grad) = reverse_gradient(|[x, y]| x.clone() * x.clone() + x * y, [3.0, 4.0]);
        assert_eq!(val, 21.0);
        assert_eq!(grad, [10.0, 3.0]);
    }

    #[test]
    fn gradient_three_vars() {
        // f(x, y, z) = x*y + y*z + z*x at (1, 2, 3)
        // f = 2 + 6 + 3 = 11
        // ∂f/∂x = y + z = 5
        // ∂f/∂y = x + z = 4
        // ∂f/∂z = y + x = 3
        let (val, grad) = reverse_gradient(
            |[x, y, z]| x.clone() * y.clone() + y.clone() * z.clone() + z * x,
            [1.0, 2.0, 3.0],
        );
        assert_eq!(val, 11.0);
        assert_eq!(grad, [5.0, 4.0, 3.0]);
    }

    #[test]
    fn gradient_with_transcendental() {
        // f(x, y) = sin(x) * cos(y) at (0, 0)
        // f = 0 * 1 = 0
        // ∂f/∂x = cos(x) * cos(y) = 1
        // ∂f/∂y = sin(x) * (-sin(y)) = 0
        let (val, grad) = reverse_gradient(|[x, y]| x.sin() * y.cos(), [0.0, 0.0]);
        assert_eq!(val, 0.0);
        assert_eq!(grad, [1.0, 0.0]);
    }

    // === Reusable function pattern ===

    #[test]
    fn reusable_function() {
        let f = |x: Var<f64>| x.clone() * x.clone() + x * 2.0;

        // Evaluate at multiple points
        let (v1, d1) = reverse_diff(f, 3.0);
        assert_eq!(v1, 15.0); // 9 + 6
        assert_eq!(d1, 8.0); // 2*3 + 2

        let (v2, d2) = reverse_diff(f, 5.0);
        assert_eq!(v2, 35.0); // 25 + 10
        assert_eq!(d2, 12.0); // 2*5 + 2

        let (v3, d3) = reverse_diff(f, 0.0);
        assert_eq!(v3, 0.0);
        assert_eq!(d3, 2.0);
    }

    // === Comparison with forward-mode ===

    #[test]
    fn matches_forward_mode_polynomial() {
        // Verify reverse-mode matches forward-mode for x³ + 2x² - x at x=2
        use crate::dual::Dual;

        let forward = {
            let x = Dual::variable(2.0);
            let y = x * x * x + Dual::constant(2.0) * x * x - x;
            (y.value, y.deriv)
        };

        let reverse = reverse_diff(
            |x| {
                let x2 = x.clone() * x.clone();
                let x3 = x2.clone() * x.clone();
                x3 + x2 * 2.0 - x
            },
            2.0,
        );

        assert_eq!(forward.0, reverse.0);
        assert!((forward.1 - reverse.1).abs() < 1e-10);
    }

    #[test]
    fn matches_forward_mode_transcendental() {
        // Verify reverse-mode matches forward-mode for e^(sin(x)) at x=1
        use crate::dual::Dual;

        let forward = {
            let x = Dual::variable(1.0);
            let y = x.sin().exp();
            (y.value, y.deriv)
        };

        let reverse = reverse_diff(|x| x.sin().exp(), 1.0);

        assert!((forward.0 - reverse.0).abs() < 1e-10);
        assert!((forward.1 - reverse.1).abs() < 1e-10);
    }
}