1use num_traits::Float;
50use std::cell::RefCell;
51use std::ops::{Add, Div, Mul, Neg, Sub};
52use std::rc::Rc;
53
/// Cloneable handle to a shared recording of values and operations.
///
/// The recording is held in an `Rc<RefCell<..>>`, so every clone of a `Tape`
/// (and every [`Var`] created from it) refers to the same underlying storage.
#[derive(Clone)]
pub struct Tape<T> {
    /// Shared, interior-mutable storage for values, gradients and operations.
    inner: Rc<RefCell<TapeInner<T>>>,
}
76
77impl<T> Tape<T> {
78 pub fn new() -> Self {
80 Self {
81 inner: Rc::new(RefCell::new(TapeInner::new())),
82 }
83 }
84}
85
86impl<T: Float> Tape<T> {
87 pub fn var(&self, value: T) -> Var<T> {
89 let idx = self.inner.borrow_mut().push_value(value);
90 Var {
91 tape: self.clone(),
92 idx,
93 }
94 }
95
96 fn constant(&self, value: T) -> Var<T> {
97 let idx = self.inner.borrow_mut().push_value(value);
98 Var {
99 tape: self.clone(),
100 idx,
101 }
102 }
103}
104
/// Handle to a single value recorded on a [`Tape`].
///
/// Cloning copies the handle (tape reference and index), not the recorded value.
#[derive(Clone)]
pub struct Var<T> {
    /// Tape that stores this value.
    tape: Tape<T>,
    /// Position of this value in the tape's value and gradient vectors.
    idx: usize,
}
138
139impl<T: Copy> Var<T> {
140 pub fn value(&self) -> T {
142 self.tape.inner.borrow().vals[self.idx]
143 }
144}
145
146impl<T: Float> Var<T> {
147 pub fn backward(&self) -> Gradients<T> {
165 self.tape.inner.borrow_mut().backward_from(self.idx);
166 Gradients {
167 tape: self.tape.clone(),
168 }
169 }
170
171 pub fn recip(self) -> Self {
173 unary(self, OpKind::Recip, |a| T::one() / a)
174 }
175
176 pub fn exp(self) -> Self {
178 unary(self, OpKind::Exp, |a| a.exp())
179 }
180
181 pub fn sin(self) -> Self {
183 unary(self, OpKind::Sin, |a| a.sin())
184 }
185
186 pub fn cos(self) -> Self {
188 unary(self, OpKind::Cos, |a| a.cos())
189 }
190
191 pub fn ln(self) -> Self {
193 unary(self, OpKind::Ln, |a| a.ln())
194 }
195
196 pub fn sqrt(self) -> Self {
198 unary(self, OpKind::Sqrt, |a| a.sqrt())
199 }
200}
201
202pub struct Gradients<T> {
221 tape: Tape<T>,
222}
223
224impl<T: Copy> Gradients<T> {
225 pub fn get(&self, var: &Var<T>) -> T {
227 self.tape.inner.borrow().grads[var.idx]
228 }
229}
230
/// Storage behind a [`Tape`]: recorded values, their gradient slots, and the
/// operations that produced them.
struct TapeInner<T> {
    /// Value of every recorded entry, in recording order.
    vals: Vec<T>,
    /// One gradient slot per entry in `vals`; overwritten by each reverse pass.
    grads: Vec<T>,
    /// Recorded operations, in the order they were performed.
    ops: Vec<Op>,
}
236
237impl<T> TapeInner<T> {
238 fn new() -> Self {
239 Self {
240 vals: Vec::new(),
241 grads: Vec::new(),
242 ops: Vec::new(),
243 }
244 }
245
246 fn push_op(&mut self, op: Op) {
247 self.ops.push(op)
248 }
249}
250
251impl<T: Float> TapeInner<T> {
252 fn push_value(&mut self, v: T) -> usize {
253 let idx = self.vals.len();
254 self.vals.push(v);
255 self.grads.push(T::zero());
256 idx
257 }
258
259 fn backward_from(&mut self, out: usize) {
260 for g in &mut self.grads {
261 *g = T::zero();
262 }
263 self.grads[out] = T::one();
264
265 for op in self.ops.iter().rev() {
266 let go = self.grads[op.out];
267
268 match op.kind {
269 OpKind::Add => {
270 self.grads[op.a] = self.grads[op.a] + go;
271 self.grads[op.b] = self.grads[op.b] + go;
272 }
273 OpKind::Sub => {
274 self.grads[op.a] = self.grads[op.a] + go;
275 self.grads[op.b] = self.grads[op.b] - go;
276 }
277 OpKind::Mul => {
278 let a = self.vals[op.a];
279 let b = self.vals[op.b];
280 self.grads[op.a] = self.grads[op.a] + go * b;
281 self.grads[op.b] = self.grads[op.b] + go * a;
282 }
283 OpKind::Div => {
284 let a = self.vals[op.a];
285 let b = self.vals[op.b];
286 self.grads[op.a] = self.grads[op.a] + go / b;
287 self.grads[op.b] = self.grads[op.b] + go * (-(a) / (b * b));
288 }
289 OpKind::Neg => {
290 self.grads[op.a] = self.grads[op.a] - go;
291 }
292 OpKind::Recip => {
293 let a = self.vals[op.a];
294 self.grads[op.a] = self.grads[op.a] + go * (-(T::one()) / (a * a));
295 }
296 OpKind::Exp => {
297 let out = self.vals[op.out];
298 self.grads[op.a] = self.grads[op.a] + go * out;
299 }
300 OpKind::Sin => {
301 let a = self.vals[op.a];
302 self.grads[op.a] = self.grads[op.a] + go * a.cos();
303 }
304 OpKind::Cos => {
305 let a = self.vals[op.a];
306 self.grads[op.a] = self.grads[op.a] - go * a.sin();
307 }
308 OpKind::Ln => {
309 let a = self.vals[op.a];
310 self.grads[op.a] = self.grads[op.a] + go / a;
311 }
312 OpKind::Sqrt => {
313 let out = self.vals[op.out];
314 self.grads[op.a] = self.grads[op.a] + go / (out + out);
315 }
316 }
317 }
318 }
319}
320
/// One recorded operation: which entries it read and which entry it produced.
struct Op {
    /// Which operation was performed.
    kind: OpKind,
    /// Index of the entry holding the result.
    out: usize,
    /// Index of the first (or only) operand.
    a: usize,
    /// Index of the second operand. Unary operations are recorded with `0`
    /// here and the reverse pass does not read it for them.
    b: usize,
}
327
/// Kind of a recorded operation; selects the derivative rule applied in the
/// reverse pass.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OpKind {
    /// `a + b`
    Add,
    /// `a - b`
    Sub,
    /// `a * b`
    Mul,
    /// `a / b`
    Div,
    /// `-a`
    Neg,
    /// `1 / a`
    Recip,
    /// `exp(a)`
    Exp,
    /// `sin(a)`
    Sin,
    /// `cos(a)`
    Cos,
    /// `ln(a)`
    Ln,
    /// `sqrt(a)`
    Sqrt,
}
341
342fn unary<T: Float>(x: Var<T>, kind: OpKind, f: impl FnOnce(T) -> T) -> Var<T> {
343 let tape = x.tape.clone();
344 let outv = {
345 let t = tape.inner.borrow();
346 f(t.vals[x.idx])
347 };
348
349 let out = {
350 let mut t = tape.inner.borrow_mut();
351 let out = t.push_value(outv);
352 t.push_op(Op {
353 kind,
354 out,
355 a: x.idx,
356 b: 0,
357 });
358 out
359 };
360
361 Var { tape, idx: out }
362}
363
364fn binary<T: Float>(lhs: Var<T>, rhs: Var<T>, kind: OpKind, f: impl FnOnce(T, T) -> T) -> Var<T> {
365 assert!(
366 Rc::ptr_eq(&lhs.tape.inner, &rhs.tape.inner),
367 "Vars must share a tape"
368 );
369 let tape = lhs.tape.clone();
370
371 let (a, b, outv) = {
372 let t = tape.inner.borrow();
373 let av = t.vals[lhs.idx];
374 let bv = t.vals[rhs.idx];
375 (lhs.idx, rhs.idx, f(av, bv))
376 };
377
378 let out = {
379 let mut t = tape.inner.borrow_mut();
380 let out = t.push_value(outv);
381 t.push_op(Op { kind, out, a, b });
382 out
383 };
384
385 Var { tape, idx: out }
386}
387
388impl<T: Float> Add for Var<T> {
389 type Output = Var<T>;
390 fn add(self, rhs: Self) -> Self::Output {
391 binary(self, rhs, OpKind::Add, |a, b| a + b)
392 }
393}
394
395impl<T: Float> Sub for Var<T> {
396 type Output = Var<T>;
397 fn sub(self, rhs: Self) -> Self::Output {
398 binary(self, rhs, OpKind::Sub, |a, b| a - b)
399 }
400}
401
402impl<T: Float> Mul for Var<T> {
403 type Output = Var<T>;
404 fn mul(self, rhs: Self) -> Self::Output {
405 binary(self, rhs, OpKind::Mul, |a, b| a * b)
406 }
407}
408
409impl<T: Float> Div for Var<T> {
410 type Output = Var<T>;
411 fn div(self, rhs: Self) -> Self::Output {
412 binary(self, rhs, OpKind::Div, |a, b| a / b)
413 }
414}
415
416impl<T: Float> Neg for Var<T> {
417 type Output = Var<T>;
418 fn neg(self) -> Self::Output {
419 unary(self, OpKind::Neg, |a| -a)
420 }
421}
422
423impl<T: Float> Add<T> for Var<T> {
424 type Output = Var<T>;
425 fn add(self, c: T) -> Self::Output {
426 let cvar = self.tape.constant(c);
427 self + cvar
428 }
429}
430
431impl<T: Float> Sub<T> for Var<T> {
432 type Output = Var<T>;
433 fn sub(self, c: T) -> Self::Output {
434 let cvar = self.tape.constant(c);
435 self - cvar
436 }
437}
438
439impl<T: Float> Mul<T> for Var<T> {
440 type Output = Var<T>;
441 fn mul(self, c: T) -> Self::Output {
442 let cvar = self.tape.constant(c);
443 self * cvar
444 }
445}
446
447impl<T: Float> Div<T> for Var<T> {
448 type Output = Var<T>;
449 fn div(self, c: T) -> Self::Output {
450 let cvar = self.tape.constant(c);
451 self / cvar
452 }
453}
454
455pub fn reverse_diff<T, F>(f: F, x: T) -> (T, T)
487where
488 T: Float,
489 F: FnOnce(Var<T>) -> Var<T>,
490{
491 let tape = Tape::new();
492 let var = tape.var(x);
493 let var_clone = var.clone();
494 let result = f(var);
495 let grads = result.backward();
496 (result.value(), grads.get(&var_clone))
497}
498
499pub fn reverse_gradient<T, F, const N: usize>(f: F, point: [T; N]) -> (T, [T; N])
519where
520 T: Float,
521 F: FnOnce([Var<T>; N]) -> Var<T>,
522{
523 let tape = Tape::new();
524 let vars: [Var<T>; N] = std::array::from_fn(|i| tape.var(point[i]));
525 let vars_clone = vars.clone();
526 let result = f(vars);
527 let grads = result.backward();
528 (
529 result.value(),
530 std::array::from_fn(|i| grads.get(&vars_clone[i])),
531 )
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for results that pass through inexact floating-point steps.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    // ---- Basic arithmetic -------------------------------------------------

    #[test]
    fn addition_works() {
        let (val, deriv) = reverse_diff(|x| x + 5.0, 3.0);
        assert_eq!(val, 8.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn subtraction_works() {
        let (val, deriv) = reverse_diff(|x| x - 2.0, 5.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn multiplication_works() {
        // d/dx x^2 = 2x
        let (val, deriv) = reverse_diff(|x| x.clone() * x, 3.0);
        assert_eq!(val, 9.0);
        assert_eq!(deriv, 6.0);
    }

    #[test]
    fn division_works() {
        // d/dx x/2 = 1/2
        let (val, deriv) = reverse_diff(|x| x / 2.0, 6.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 0.5);
    }

    #[test]
    fn negation_works() {
        let (val, deriv) = reverse_diff(|x| -x, 3.0);
        assert_eq!(val, -3.0);
        assert_eq!(deriv, -1.0);
    }

    #[test]
    fn identity_has_unit_derivative() {
        let (val, deriv) = reverse_diff(|x| x, 3.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 1.0);
    }

    // ---- Var-with-Var operators -------------------------------------------

    #[test]
    fn var_var_addition() {
        let (val, deriv) = reverse_diff(|x| x.clone() + x, 3.0);
        assert_eq!(val, 6.0);
        assert_eq!(deriv, 2.0);
    }

    #[test]
    fn var_var_subtraction() {
        let (val, deriv) = reverse_diff(|x| x.clone() - x, 3.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 0.0);
    }

    #[test]
    fn var_var_division() {
        // x / x is constant, so the two contributions must cancel.
        let (val, deriv) = reverse_diff(|x| x.clone() / x, 3.0);
        assert_eq!(val, 1.0);
        assert_close(deriv, 0.0);
    }

    // ---- Var-with-scalar operators ----------------------------------------

    #[test]
    fn var_add_scalar() {
        let (val, deriv) = reverse_diff(|x| x + 10.0, 5.0);
        assert_eq!(val, 15.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn var_sub_scalar() {
        let (val, deriv) = reverse_diff(|x| x - 3.0, 10.0);
        assert_eq!(val, 7.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn var_mul_scalar() {
        let (val, deriv) = reverse_diff(|x| x * 3.0, 4.0);
        assert_eq!(val, 12.0);
        assert_eq!(deriv, 3.0);
    }

    #[test]
    fn var_div_scalar() {
        let (val, deriv) = reverse_diff(|x| x / 4.0, 12.0);
        assert_eq!(val, 3.0);
        assert_eq!(deriv, 0.25);
    }

    // ---- Elementary functions ---------------------------------------------

    #[test]
    fn recip_works() {
        // d/dx 1/x = -1/x^2
        let (val, deriv) = reverse_diff(|x| x.recip(), 2.0);
        assert_eq!(val, 0.5);
        assert_eq!(deriv, -0.25);
    }

    #[test]
    fn exp_works() {
        let (val, deriv) = reverse_diff(|x| x.exp(), 0.0);
        assert_eq!(val, 1.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn exp_at_one() {
        let (val, deriv) = reverse_diff(|x| x.exp(), 1.0);
        let e = 1.0_f64.exp();
        assert_close(val, e);
        assert_close(deriv, e);
    }

    #[test]
    fn sin_works() {
        // d/dx sin(x) = cos(x); cos(0) = 1
        let (val, deriv) = reverse_diff(|x| x.sin(), 0.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn cos_works() {
        // d/dx cos(x) = -sin(x); sin(0) = 0
        let (val, deriv) = reverse_diff(|x| x.cos(), 0.0);
        assert_eq!(val, 1.0);
        assert_eq!(deriv, 0.0);
    }

    #[test]
    fn ln_works() {
        // d/dx ln(x) = 1/x
        let (val, deriv) = reverse_diff(|x| x.ln(), 1.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 1.0);
    }

    #[test]
    fn ln_at_e() {
        let e = 1.0_f64.exp();
        let (val, deriv) = reverse_diff(|x| x.ln(), e);
        assert_close(val, 1.0);
        assert_close(deriv, 1.0 / e);
    }

    #[test]
    fn sqrt_works() {
        // d/dx sqrt(x) = 1 / (2 sqrt(x))
        let (val, deriv) = reverse_diff(|x| x.sqrt(), 4.0);
        assert_eq!(val, 2.0);
        assert_eq!(deriv, 0.25);
    }

    // ---- Fan-out: one input used several times ----------------------------

    #[test]
    fn fan_out_addition() {
        let (val, deriv) = reverse_diff(|x| x.clone() + x.clone() + x, 2.0);
        assert_eq!(val, 6.0);
        assert_eq!(deriv, 3.0);
    }

    #[test]
    fn fan_out_mixed() {
        // d/dx (x^2 + x) = 2x + 1
        let (val, deriv) = reverse_diff(|x| x.clone() * x.clone() + x, 3.0);
        assert_eq!(val, 12.0);
        assert_eq!(deriv, 7.0);
    }

    #[test]
    fn fan_out_product() {
        // d/dx x^3 = 3x^2
        let (val, deriv) = reverse_diff(|x| x.clone() * x.clone() * x, 2.0);
        assert_eq!(val, 8.0);
        assert_eq!(deriv, 12.0);
    }

    // ---- Composite expressions --------------------------------------------

    #[test]
    fn polynomial() {
        // f(x) = x^3 - 2x + 1, f'(x) = 3x^2 - 2
        let (val, deriv) = reverse_diff(
            |x| {
                let x2 = x.clone() * x.clone();
                let x3 = x2 * x.clone();
                x3 - x * 2.0 + 1.0
            },
            2.0,
        );
        assert_eq!(val, 5.0);
        assert_eq!(deriv, 10.0);
    }

    #[test]
    fn quotient_rule() {
        // f(x) = (x + 1) / (x + 2), f'(x) = 1 / (x + 2)^2
        let (val, deriv) = reverse_diff(
            |x| {
                let num = x.clone() + 1.0;
                let den = x + 2.0;
                num / den
            },
            3.0,
        );
        assert_eq!(val, 0.8);
        assert_close(deriv, 0.04);
    }

    #[test]
    fn chain_rule_product() {
        // f(x) = (x + 1)(x - 1) = x^2 - 1, f'(x) = 2x
        let (val, deriv) = reverse_diff(|x| (x.clone() + 1.0) * (x - 1.0), 3.0);
        assert_eq!(val, 8.0);
        assert_eq!(deriv, 6.0);
    }

    #[test]
    fn chain_rule_transcendental() {
        // f(x) = sin(2x), f'(x) = 2 cos(2x)
        let (val, deriv) = reverse_diff(|x| (x * 2.0).sin(), 0.0);
        assert_eq!(val, 0.0);
        assert_eq!(deriv, 2.0);
    }

    #[test]
    fn exp_of_square() {
        // f(x) = exp(x^2), f'(x) = 2x exp(x^2)
        let (val, deriv) = reverse_diff(|x| (x.clone() * x).exp(), 1.0);
        let e = 1.0_f64.exp();
        assert_close(val, e);
        assert_close(deriv, 2.0 * e);
    }

    #[test]
    fn ln_of_square() {
        // f(x) = ln(x^2), f'(x) = 2 / x
        let e = 1.0_f64.exp();
        let (val, deriv) = reverse_diff(|x| (x.clone() * x).ln(), e);
        assert_close(val, 2.0);
        assert_close(deriv, 2.0 / e);
    }

    #[test]
    fn sqrt_of_sum() {
        // f(x) = sqrt(x + 5), f'(x) = 1 / (2 sqrt(x + 5))
        let (val, deriv) = reverse_diff(|x| (x + 5.0).sqrt(), 4.0);
        assert_eq!(val, 3.0);
        assert_close(deriv, 1.0 / 6.0);
    }

    // ---- Multi-variable gradients -----------------------------------------

    #[test]
    fn gradient_sum() {
        let (val, grad) = reverse_gradient(|[x, y]| x + y, [3.0, 4.0]);
        assert_eq!(val, 7.0);
        assert_eq!(grad, [1.0, 1.0]);
    }

    #[test]
    fn gradient_product() {
        // d/dx xy = y, d/dy xy = x
        let (val, grad) = reverse_gradient(|[x, y]| x * y, [3.0, 4.0]);
        assert_eq!(val, 12.0);
        assert_eq!(grad, [4.0, 3.0]);
    }

    #[test]
    fn gradient_mixed() {
        // f = x^2 + xy: df/dx = 2x + y, df/dy = x
        let (val, grad) = reverse_gradient(|[x, y]| x.clone() * x.clone() + x * y, [3.0, 4.0]);
        assert_eq!(val, 21.0);
        assert_eq!(grad, [10.0, 3.0]);
    }

    #[test]
    fn gradient_three_vars() {
        // f = xy + yz + zx: df/dx = y + z, df/dy = x + z, df/dz = y + x
        let (val, grad) = reverse_gradient(
            |[x, y, z]| x.clone() * y.clone() + y.clone() * z.clone() + z * x,
            [1.0, 2.0, 3.0],
        );
        assert_eq!(val, 11.0);
        assert_eq!(grad, [5.0, 4.0, 3.0]);
    }

    #[test]
    fn gradient_with_transcendental() {
        // f = sin(x) cos(y): df/dx = cos(x) cos(y), df/dy = -sin(x) sin(y)
        let (val, grad) = reverse_gradient(|[x, y]| x.sin() * y.cos(), [0.0, 0.0]);
        assert_eq!(val, 0.0);
        assert_eq!(grad, [1.0, 0.0]);
    }

    #[test]
    fn unused_input_has_zero_gradient() {
        let (val, grad) = reverse_gradient(|[x, _y]| x * 2.0, [3.0, 4.0]);
        assert_eq!(val, 6.0);
        assert_eq!(grad, [2.0, 0.0]);
    }

    // ---- Tape behaviour ---------------------------------------------------

    #[test]
    fn reusable_function() {
        // f(x) = x^2 + 2x, f'(x) = 2x + 2
        let f = |x: Var<f64>| x.clone() * x.clone() + x * 2.0;

        let (v1, d1) = reverse_diff(f, 3.0);
        assert_eq!(v1, 15.0);
        assert_eq!(d1, 8.0);

        let (v2, d2) = reverse_diff(f, 5.0);
        assert_eq!(v2, 35.0);
        assert_eq!(d2, 12.0);

        let (v3, d3) = reverse_diff(f, 0.0);
        assert_eq!(v3, 0.0);
        assert_eq!(d3, 2.0);
    }

    #[test]
    fn backward_is_repeatable() {
        // A second reverse pass from the same output must not accumulate on
        // top of the first one.
        let tape = Tape::new();
        let x = tape.var(3.0_f64);
        let y = x.clone() * x.clone();

        let first = y.backward().get(&x);
        let second = y.backward().get(&x);
        assert_eq!(first, 6.0);
        assert_eq!(second, 6.0);
    }

    #[test]
    #[should_panic(expected = "Vars must share a tape")]
    fn mixing_tapes_panics() {
        let a = Tape::new().var(1.0_f64);
        let b = Tape::new().var(2.0_f64);
        let _ = a + b;
    }

    // ---- Agreement with forward mode --------------------------------------

    #[test]
    fn matches_forward_mode_polynomial() {
        use crate::dual::Dual;

        // f(x) = x^3 + 2x^2 - x, evaluated with dual numbers.
        let forward = {
            let x = Dual::variable(2.0);
            let y = x * x * x + Dual::constant(2.0) * x * x - x;
            (y.value, y.deriv)
        };

        let reverse = reverse_diff(
            |x| {
                let x2 = x.clone() * x.clone();
                let x3 = x2.clone() * x.clone();
                x3 + x2 * 2.0 - x
            },
            2.0,
        );

        assert_eq!(forward.0, reverse.0);
        assert!((forward.1 - reverse.1).abs() < 1e-10);
    }

    #[test]
    fn matches_forward_mode_transcendental() {
        use crate::dual::Dual;

        // f(x) = exp(sin(x)), evaluated with dual numbers.
        let forward = {
            let x = Dual::variable(1.0);
            let y = x.sin().exp();
            (y.value, y.deriv)
        };

        let reverse = reverse_diff(|x| x.sin().exp(), 1.0);

        assert!((forward.0 - reverse.0).abs() < 1e-10);
        assert!((forward.1 - reverse.1).abs() < 1e-10);
    }
}