From c4c303b6f1638cbffd15400cecca2c2e04573488 Mon Sep 17 00:00:00 2001
From: laurent
Date: Tue, 20 Jun 2023 20:33:44 +0100
Subject: [PATCH] Add some very basic backprop.

---
 src/device.rs | 19 +++++++++++
 src/op.rs     |  1 -
 src/shape.rs  |  6 ++++
 src/tensor.rs | 91 ++++++++++++++++++++++++++++++++++++++++++++++++---
 4 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/src/device.rs b/src/device.rs
index 0964e83f..d7b724d1 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -56,6 +56,25 @@ impl<S: crate::WithDType, const N: usize, const M: usize> NdArray for &[[S; N];
 }
 
 impl Device {
+    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage {
+        match self {
+            Device::Cpu => {
+                let elem_count = shape.elem_count();
+                let storage = match dtype {
+                    DType::F32 => {
+                        let data = vec![1f32; elem_count];
+                        CpuStorage::F32(data)
+                    }
+                    DType::F64 => {
+                        let data = vec![1f64; elem_count];
+                        CpuStorage::F64(data)
+                    }
+                };
+                Storage::Cpu(storage)
+            }
+        }
+    }
+
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage {
         match self {
             Device::Cpu => {
diff --git a/src/op.rs b/src/op.rs
index e991c120..7e2740f2 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -1,6 +1,5 @@
 use crate::Tensor;
 
-#[allow(dead_code)]
 pub(crate) enum Op {
     Add(Tensor, Tensor),
     Mul(Tensor, Tensor),
diff --git a/src/shape.rs b/src/shape.rs
index 36a48276..d626aee6 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -33,6 +33,12 @@ impl From<&[usize]> for Shape {
     }
 }
 
+impl From<&Shape> for Shape {
+    fn from(shape: &Shape) -> Self {
+        Self(shape.0.to_vec())
+    }
+}
+
 impl From<()> for Shape {
     fn from(_: ()) -> Self {
         Self(vec![])
diff --git a/src/tensor.rs b/src/tensor.rs
index 8a827783..7f3fb07f 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -1,4 +1,5 @@
 use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
+use std::collections::HashMap;
 use std::sync::Arc;
 
 /// Unique identifier for tensors.
@@ -14,7 +15,6 @@ impl TensorId {
     }
 }
 
-#[allow(dead_code)]
 pub struct Tensor_ {
     id: TensorId,
     storage: Storage,
@@ -81,6 +81,38 @@ macro_rules! binary_op {
 }
 
 impl Tensor {
+    fn ones_impl<S: Into<Shape>>(
+        shape: S,
+        dtype: DType,
+        device: Device,
+        is_variable: bool,
+    ) -> Self {
+        let shape = shape.into();
+        let storage = device.ones(&shape, dtype);
+        let stride = shape.stride_contiguous();
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape,
+            stride,
+            op: None,
+            is_variable,
+        };
+        Self(Arc::new(tensor_))
+    }
+
+    pub fn ones<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+        Self::ones_impl(shape, dtype, device, false)
+    }
+
+    pub fn ones_var<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
+        Self::ones_impl(shape, dtype, device, true)
+    }
+
+    pub fn ones_like(&self) -> Self {
+        Tensor::ones(self.shape(), self.dtype(), self.device())
+    }
+
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
@@ -109,6 +141,10 @@ impl Tensor {
         Self::zeros_impl(shape, dtype, device, true)
     }
 
+    pub fn zeros_like(&self) -> Self {
+        Tensor::zeros(self.shape(), self.dtype(), self.device())
+    }
+
     pub fn new_impl<A: crate::device::NdArray>(
         array: A,
         device: Device,
@@ -246,9 +282,7 @@ impl Tensor {
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
     /// argument.
     /// This assumes that the op graph is a DAG.
-    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
-        use std::collections::HashMap;
-
+    fn sorted_nodes(&self) -> Vec<&Tensor> {
         // The vec of sorted nodes is passed as an owned value rather than a mutable reference
         // to get around some lifetime limitations.
         fn walk<'a>(
@@ -292,4 +326,53 @@ impl Tensor {
         nodes.reverse();
         nodes
     }
+
+    pub fn backward(&self) -> Result<HashMap<TensorId, Tensor>> {
+        let sorted_nodes = self.sorted_nodes();
+        let mut grads = HashMap::new();
+        grads.insert(self.id, self.ones_like());
+        for node in sorted_nodes.iter() {
+            if node.is_variable {
+                continue;
+            }
+            let grad = grads.remove(&node.id).unwrap();
+            // TODO: We should perform all these operations in place (or at least not track the
+            // whole graph).
+            // The only drawback would be if we wanted to support grad of grad but this is out of
+            // scope.
+            if let Some(op) = &node.op {
+                match op {
+                    Op::Add(lhs, rhs) => {
+                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
+                    }
+                    Op::Mul(lhs, rhs) => {
+                        let lhs_grad = grad.mul(rhs)?;
+                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        let rhs_grad = grad.mul(lhs)?;
+                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    }
+                    Op::Sqr(_arg) => {
+                        todo!()
+                        // TODO: Add scaling by a constant to enable the following.
+                        // let arg_grad = 2 * arg * grad;
+                        // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        // *sum_grad = sum_grad.add(arg_grad)?
+                    }
+                    Op::Sqrt(_arg) => {
+                        todo!()
+                        // TODO: Add scaling by a constant and divide to enable the following.
+                        // let arg_grad = grad / (2 * arg)
+                        // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        // *sum_grad = sum_grad.add(arg_grad)?
+                    }
+                };
+            }
+        }
+        Ok(grads)
+    }
 }
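
Not part of the patch itself: a minimal sketch of how the new backward API could be exercised once this lands. It assumes the test module sits at the bottom of src/tensor.rs (so the private id field stays visible) and relies only on items shown in the diff: ones_var / zeros_var, the add / mul binary ops, backward, and the From<&[usize]> shape conversion. The module and test names are illustrative.

    // Hypothetical test module appended to src/tensor.rs.
    #[cfg(test)]
    mod backprop_tests {
        use crate::{DType, Device, Tensor};

        #[test]
        fn add_mul_grads_are_tracked() {
            let shape: &[usize] = &[2, 2];
            // Variables are the only nodes whose gradients survive in the map
            // returned by backward(): non-variable nodes are drained as they
            // are processed.
            let x = Tensor::ones_var(shape, DType::F32, Device::Cpu);
            let y = Tensor::zeros_var(shape, DType::F32, Device::Cpu);
            // z = x * y + x, built only from ops that already have a backward rule.
            let z = x.mul(&y).unwrap().add(&x).unwrap();
            let grads = z.backward().unwrap();
            // dz/dx = y + 1 and dz/dy = x, so both variables should have an entry.
            assert!(grads.contains_key(&x.id));
            assert!(grads.contains_key(&y.id));
        }
    }

The assertions lean on the way backward drains gradients for non-variable nodes as it walks the sorted graph, leaving only the variables' accumulated gradients in the returned map.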