From 78bac0ed3234beb0d79812ed773439831a0267e9 Mon Sep 17 00:00:00 2001
From: laurent <laurent.mazare@gmail.com>
Date: Tue, 20 Jun 2023 22:32:11 +0100
Subject: [PATCH] Add a couple operators.

---
 src/op.rs      |  6 +++++-
 src/storage.rs | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++
 src/tensor.rs  | 39 +++++++++++++++++++++++++++-------
 3 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/src/op.rs b/src/op.rs
index f7a4e576..240ecba3 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -2,13 +2,17 @@ use crate::Tensor;
 
 pub(crate) enum Op {
     Add(Tensor, Tensor),
+    Mul(Tensor, Tensor),
+    Sub(Tensor, Tensor),
+    Div(Tensor, Tensor),
+    #[allow(dead_code)] // add is currently unused.
     Affine {
         arg: Tensor,
         mul: f64,
         add: f64,
     },
 
-    Mul(Tensor, Tensor),
+    Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
     // TODO: Support for custom ops.
diff --git a/src/storage.rs b/src/storage.rs
index a1b0f300..65f7c549 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -95,7 +95,10 @@ trait BinaryOp {
 }
 
 struct Add;
+struct Div;
 struct Mul;
+struct Sub;
+struct Neg;
 struct Sqr;
 struct Sqrt;
 
@@ -109,6 +112,16 @@ impl BinaryOp for Add {
     }
 }
 
+impl BinaryOp for Sub {
+    const NAME: &'static str = "sub";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 - v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 - v2
+    }
+}
+
 impl BinaryOp for Mul {
     const NAME: &'static str = "mul";
     fn f32(v1: f32, v2: f32) -> f32 {
@@ -119,6 +132,26 @@ impl BinaryOp for Mul {
     }
 }
 
+impl BinaryOp for Div {
+    const NAME: &'static str = "div";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 / v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 / v2
+    }
+}
+
+impl UnaryOp for Neg {
+    const NAME: &'static str = "neg";
+    fn f32(v1: f32) -> f32 {
+        -v1
+    }
+    fn f64(v1: f64) -> f64 {
+        -v1
+    }
+}
+
 impl UnaryOp for Sqr {
     const NAME: &'static str = "sqr";
     fn f32(v1: f32) -> f32 {
@@ -272,6 +305,16 @@ impl Storage {
         self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
     }
 
+    pub(crate) fn sub_impl(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
+    }
+
     pub(crate) fn mul_impl(
         &self,
         rhs: &Self,
@@ -282,6 +325,20 @@ impl Storage {
         self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
     }
 
+    pub(crate) fn div_impl(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
+    }
+
+    pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        self.unary_impl::<Neg>(shape, stride)
+    }
+
     pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
         self.unary_impl::<Sqr>(shape, stride)
     }
diff --git a/src/tensor.rs b/src/tensor.rs
index 95b9a75c..bfe01adf 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -190,7 +190,10 @@ impl Tensor {
     // if this can create cycles in the compute graph.
     binary_op!(add, Add, add_impl);
     binary_op!(mul, Mul, mul_impl);
+    binary_op!(sub, Sub, sub_impl);
+    binary_op!(div, Div, div_impl);
 
+    unary_op!(neg, Neg, neg_impl);
     unary_op!(sqr, Sqr, sqr_impl);
     unary_op!(sqrt, Sqrt, sqrt_impl);
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
@@ -320,7 +323,10 @@ impl Tensor {
                 nodes
             } else if let Some(op) = &node.op {
                 match op {
-                    Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
+                    Op::Add(lhs, rhs)
+                    | Op::Mul(lhs, rhs)
+                    | Op::Sub(lhs, rhs)
+                    | Op::Div(lhs, rhs) => {
                         let (tg, nodes) = walk(lhs, nodes, already_seen);
                         track_grad |= tg;
                         let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -336,7 +342,7 @@ impl Tensor {
                             nodes
                         }
                     }
-                    Op::Sqr(node) | Op::Sqrt(node) => {
+                    Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
                         let (tg, nodes) = walk(node, nodes, already_seen);
                         track_grad |= tg;
                         nodes
@@ -378,6 +384,12 @@ impl Tensor {
                         let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
                         *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
                     }
+                    Op::Sub(lhs, rhs) => {
+                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
+                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        *rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
+                    }
                     Op::Mul(lhs, rhs) => {
                         let lhs_grad = grad.mul(rhs)?;
                         let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
@@ -386,22 +398,33 @@ impl Tensor {
                         let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
+                    Op::Div(lhs, rhs) => {
+                        let lhs_grad = grad.div(rhs)?;
+                        let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
+                        *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
+                        let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
+                        let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
+                        *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+                    }
                     Op::Affine { arg, mul, .. } => {
                         let arg_grad = grad.affine(*mul, 0.)?;
                         let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
+                    Op::Neg(arg) => {
+                        let arg_grad = grad.neg()?;
+                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                     Op::Sqr(arg) => {
                         let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                         let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Sqrt(_arg) => {
-                        todo!()
-                        // TODO: Add div to enable the following.
-                        // let arg_grad = grad / (2 * arg)
-                        // let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
-                        // *sum_grad = sum_grad.add(arg_grad)?
+                    Op::Sqrt(arg) => {
+                        let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
+                        let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
+                        *sum_grad = sum_grad.add(&arg_grad)?
                     }
                 };
             }