Add some backprop test.

This commit is contained in:
laurent
2023-06-20 20:54:35 +01:00
parent c4c303b6f1
commit a419a9da72
2 changed files with 21 additions and 1 deletions

View File

@ -169,7 +169,7 @@ impl Tensor {
}
pub fn var<A: crate::device::NdArray>(array: A, device: Device) -> Result<Self> {
Self::new_impl(array, device, false)
Self::new_impl(array, device, true)
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@ -329,6 +329,7 @@ impl Tensor {
pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
let sorted_nodes = self.sorted_nodes();
println!("{}", sorted_nodes.len());
let mut grads = HashMap::new();
grads.insert(self.id, self.ones_like());
for node in sorted_nodes.iter() {