From a419a9da727ba715e8d8e930763fd3d46472f925 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 20 Jun 2023 20:54:35 +0100
Subject: [PATCH] Add some backprop test.

---
 src/tensor.rs       |  3 ++-
 tests/grad_tests.rs | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)
 create mode 100644 tests/grad_tests.rs

diff --git a/src/tensor.rs b/src/tensor.rs
index 7f3fb07f..69222af8 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -169,7 +169,7 @@ impl Tensor {
     }
 
     pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
-        Self::new_impl(array, device, false)
+        Self::new_impl(array, device, true)
     }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -329,6 +329,7 @@ impl Tensor {
 
     pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
         let sorted_nodes = self.sorted_nodes();
+        println!("{}", sorted_nodes.len());
         let mut grads = HashMap::new();
         grads.insert(self.id, self.ones_like());
         for node in sorted_nodes.iter() {
diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs
new file mode 100644
index 00000000..d8149bea
--- /dev/null
+++ b/tests/grad_tests.rs
@@ -0,0 +1,19 @@
+use anyhow::{Context, Result};
+use candle::{Device, Tensor};
+
+#[test]
+fn simple_grad() -> Result<()> {
+    let x = Tensor::var(&[3f32, 1., 4.], Device::Cpu)?;
+    let five = Tensor::new(&[5f32, 5., 5.], Device::Cpu)?;
+    let x_times_five = x.mul(&five)?;
+    let x2 = x.mul(&x)?;
+    let y = x2.add(&x_times_five)?.add(&five)?;
+    let grads = y.backward()?;
+    let grad_x = grads.get(&x.id()).context("no grad for x")?;
+    assert_eq!(x.to_vec1::<f32>()?, [3., 1., 4.]);
+    // y = x^2 + 5.x + 5
+    assert_eq!(y.to_vec1::<f32>()?, [29., 11., 41.]);
+    // dy/dx = 2.x + 5
+    assert_eq!(grad_x.to_vec1::<f32>()?, [11., 7., 13.]);
+    Ok(())
+}
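
Notes: the expected values in simple_grad follow from elementary calculus.
With y = x^2 + 5x + 5 we have dy/dx = 2x + 5, so for x = [3, 1, 4] the test
expects y = [29, 11, 41] and dy/dx = [11, 7, 13]. The Tensor::var change
flips the boolean passed to new_impl, which, judging from the test, appears
to mark the tensor as a variable so that backward() records a gradient for
it. The sketch below re-derives the asserted numbers in plain Rust; it has
no candle dependency and is illustrative only, not part of the patch:

    // Recompute the values asserted in simple_grad, elementwise:
    // y = x^2 + 5x + 5 and its derivative dy/dx = 2x + 5.
    fn main() {
        let xs = [3f32, 1., 4.];
        let ys: Vec<f32> = xs.iter().map(|&x| x * x + 5. * x + 5.).collect();
        let dys: Vec<f32> = xs.iter().map(|&x| 2. * x + 5.).collect();
        assert_eq!(ys, [29., 11., 41.]);
        assert_eq!(dys, [11., 7., 13.]);
        println!("expected values in simple_grad check out");
    }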