diff --git a/src/op.rs b/src/op.rs
index 7e2740f2..f7a4e576 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -2,6 +2,12 @@ use crate::Tensor;
 
 pub(crate) enum Op {
     Add(Tensor, Tensor),
+    #[allow(dead_code)] // add is currently unused.
+    Affine {
+        arg: Tensor,
+        mul: f64,
+        add: f64,
+    },
     Mul(Tensor, Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
diff --git a/src/storage.rs b/src/storage.rs
index 013397d5..a1b0f300 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -172,6 +172,32 @@ impl Storage {
         }
     }
 
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Storage::Cpu(storage) => match storage {
+                CpuStorage::F32(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let mul = mul as f32;
+                    let add = add as f32;
+                    let data = index.map(|i| storage[i] * mul + add).collect();
+                    Ok(Storage::Cpu(CpuStorage::F32(data)))
+                }
+                CpuStorage::F64(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let data = index.map(|i| storage[i] * mul + add).collect();
+                    Ok(Storage::Cpu(CpuStorage::F64(data)))
+                }
+            },
+        }
+    }
+
     fn unary_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
diff --git a/src/tensor.rs b/src/tensor.rs
index 69222af8..95b9a75c 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -209,6 +209,26 @@ impl Tensor {
         }
     }
 
+    pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
+        let shape = self.shape();
+        let storage = self
+            .storage
+            .affine_impl(self.shape(), self.stride(), mul, add)?;
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage,
+            shape: shape.clone(),
+            stride: shape.stride_contiguous(),
+            op: Some(Op::Affine {
+                arg: self.clone(),
+                mul,
+                add,
+            }),
+            is_variable: false,
+        };
+        Ok(Self(Arc::new(tensor_)))
+    }
+
     pub(crate) fn strided_index(&self) -> crate::storage::StridedIndex {
         crate::storage::StridedIndex::new(self.dims(), self.stride())
     }
@@ -307,6 +327,15 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Affine { arg, mul, .. } => {
+                    if *mul == 0. {
+                        nodes
+                    } else {
+                        let (tg, nodes) = walk(arg, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
+                }
                 Op::Sqr(node) | Op::Sqrt(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
@@ -357,16 +386,19 @@ impl Tensor {
                     let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::Sqr(_arg) => {
-                    todo!()
-                    // TODO: Add scaling by a constant to enable the following.
-                    // let arg_grad = 2 * arg * grad;
-                    // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
-                    // *sum_grad = sum_grad.add(arg_grad)?
+                Op::Affine { arg, mul, .. } => {
+                    let arg_grad = grad.affine(*mul, 0.)?;
+                    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                    *sum_grad = sum_grad.add(&arg_grad)?
+                }
+                Op::Sqr(arg) => {
+                    let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
+                    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                    *sum_grad = sum_grad.add(&arg_grad)?
                 }
                 Op::Sqrt(_arg) => {
                     todo!()
-                    // TODO: Add scaling by a constant and divide to enable the following.
+                    // TODO: Add div to enable the following.
                    // let arg_grad = grad / (2 * arg)
                    // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
                    // *sum_grad = sum_grad.add(arg_grad)?
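
As a side note, here is a standalone sketch of what the new op and the updated backward rules compute, written against plain slices rather than the crate's Tensor type so the math in the diff is easy to check by hand. The helper names below are made up for illustration and are not part of the crate.

// Forward: affine(x, mul, add) = x * mul + add, applied elementwise,
// mirroring the F64 branch of Storage::affine_impl above.
fn affine(xs: &[f64], mul: f64, add: f64) -> Vec<f64> {
    xs.iter().map(|x| x * mul + add).collect()
}

// Backward for Affine: d/dx (x * mul + add) = mul, so the incoming gradient
// is just scaled by `mul` -- which is why the Op::Affine branch uses
// `grad.affine(*mul, 0.)`.
fn affine_backward(grad: &[f64], mul: f64) -> Vec<f64> {
    grad.iter().map(|g| g * mul).collect()
}

// Backward for Sqr: d/dx (x^2) = 2x, matching `arg.mul(&grad)?.affine(2., 0.)?`.
fn sqr_backward(xs: &[f64], grad: &[f64]) -> Vec<f64> {
    xs.iter().zip(grad).map(|(x, g)| 2. * x * g).collect()
}

fn main() {
    let xs = [1.0, 2.0, 3.0];
    let ones = [1.0, 1.0, 1.0];
    println!("{:?}", affine(&xs, 4.0, 0.5));       // [4.5, 8.5, 12.5]
    println!("{:?}", affine_backward(&ones, 4.0)); // [4.0, 4.0, 4.0]
    println!("{:?}", sqr_backward(&xs, &ones));    // [2.0, 4.0, 6.0]
}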