Add a couple operators.

This commit is contained in:
laurent
2023-06-20 22:32:11 +01:00
parent f1f372b13e
commit 78bac0ed32
3 changed files with 93 additions and 9 deletions

View File

@ -2,13 +2,17 @@ use crate::Tensor;
pub(crate) enum Op { pub(crate) enum Op {
Add(Tensor, Tensor), Add(Tensor, Tensor),
Mul(Tensor, Tensor),
Sub(Tensor, Tensor),
Div(Tensor, Tensor),
#[allow(dead_code)] // add is currently unused. #[allow(dead_code)] // add is currently unused.
Affine { Affine {
arg: Tensor, arg: Tensor,
mul: f64, mul: f64,
add: f64, add: f64,
}, },
Mul(Tensor, Tensor), Neg(Tensor),
Sqr(Tensor), Sqr(Tensor),
Sqrt(Tensor), Sqrt(Tensor),
// TODO: Support for custom ops. // TODO: Support for custom ops.

View File

@ -95,7 +95,10 @@ trait BinaryOp {
} }
struct Add; struct Add;
struct Div;
struct Mul; struct Mul;
struct Sub;
struct Neg;
struct Sqr; struct Sqr;
struct Sqrt; struct Sqrt;
@ -109,6 +112,16 @@ impl BinaryOp for Add {
} }
} }
// CPU kernel for element-wise subtraction.
impl BinaryOp for Sub {
    const NAME: &'static str = "sub";

    fn f64(lhs: f64, rhs: f64) -> f64 {
        lhs - rhs
    }

    fn f32(lhs: f32, rhs: f32) -> f32 {
        lhs - rhs
    }
}
impl BinaryOp for Mul { impl BinaryOp for Mul {
const NAME: &'static str = "mul"; const NAME: &'static str = "mul";
fn f32(v1: f32, v2: f32) -> f32 { fn f32(v1: f32, v2: f32) -> f32 {
@ -119,6 +132,26 @@ impl BinaryOp for Mul {
} }
} }
// CPU kernel for element-wise division. Division by zero follows IEEE-754
// semantics for the underlying float type (inf/NaN), same as the `/` operator.
impl BinaryOp for Div {
    const NAME: &'static str = "div";

    fn f64(num: f64, denom: f64) -> f64 {
        num / denom
    }

    fn f32(num: f32, denom: f32) -> f32 {
        num / denom
    }
}
// CPU kernel for element-wise negation.
impl UnaryOp for Neg {
    const NAME: &'static str = "neg";

    fn f64(x: f64) -> f64 {
        -x
    }

    fn f32(x: f32) -> f32 {
        -x
    }
}
impl UnaryOp for Sqr { impl UnaryOp for Sqr {
const NAME: &'static str = "sqr"; const NAME: &'static str = "sqr";
fn f32(v1: f32) -> f32 { fn f32(v1: f32) -> f32 {
@ -272,6 +305,16 @@ impl Storage {
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride) self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
} }
/// Element-wise subtraction of `rhs` from `self`, walking each input with
/// its own stride over the common `shape`. Thin wrapper delegating to the
/// generic `binary_impl` with the `Sub` op.
///
/// NOTE(review): presumably shape/stride consistency between the two
/// storages is validated by the caller or by `binary_impl` — confirm.
pub(crate) fn sub_impl(
    &self,
    rhs: &Self,
    shape: &Shape,
    lhs_stride: &[usize],
    rhs_stride: &[usize],
) -> Result<Self> {
    self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
}
pub(crate) fn mul_impl( pub(crate) fn mul_impl(
&self, &self,
rhs: &Self, rhs: &Self,
@ -282,6 +325,20 @@ impl Storage {
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride) self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
} }
/// Element-wise division of `self` by `rhs`, walking each input with its
/// own stride over the common `shape`. Thin wrapper delegating to the
/// generic `binary_impl` with the `Div` op.
pub(crate) fn div_impl(
    &self,
    rhs: &Self,
    shape: &Shape,
    lhs_stride: &[usize],
    rhs_stride: &[usize],
) -> Result<Self> {
    self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
}
/// Element-wise negation; delegates to the generic `unary_impl` with the `Neg` op.
pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
    self.unary_impl::<Neg>(shape, stride)
}
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> { pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
self.unary_impl::<Sqr>(shape, stride) self.unary_impl::<Sqr>(shape, stride)
} }

View File

@ -190,7 +190,10 @@ impl Tensor {
// if this can create cycles in the compute graph. // if this can create cycles in the compute graph.
binary_op!(add, Add, add_impl); binary_op!(add, Add, add_impl);
binary_op!(mul, Mul, mul_impl); binary_op!(mul, Mul, mul_impl);
binary_op!(sub, Sub, sub_impl);
binary_op!(div, Div, div_impl);
unary_op!(neg, Neg, neg_impl);
unary_op!(sqr, Sqr, sqr_impl); unary_op!(sqr, Sqr, sqr_impl);
unary_op!(sqrt, Sqrt, sqrt_impl); unary_op!(sqrt, Sqrt, sqrt_impl);
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> { pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
@ -320,7 +323,10 @@ impl Tensor {
nodes nodes
} else if let Some(op) = &node.op { } else if let Some(op) = &node.op {
match op { match op {
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => { Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen); let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg; track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen); let (tg, nodes) = walk(rhs, nodes, already_seen);
@ -336,7 +342,7 @@ impl Tensor {
nodes nodes
} }
} }
Op::Sqr(node) | Op::Sqrt(node) => { Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
let (tg, nodes) = walk(node, nodes, already_seen); let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg; track_grad |= tg;
nodes nodes
@ -378,6 +384,12 @@ impl Tensor {
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like()); let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&grad)?; *rhs_sum_grad = rhs_sum_grad.add(&grad)?;
} }
// Backprop for z = lhs - rhs: dz/dlhs = 1 and dz/drhs = -1, so the
// incoming gradient accumulates unchanged into lhs and negated into rhs.
Op::Sub(lhs, rhs) => {
    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
    *lhs_sum_grad = lhs_sum_grad.add(&grad)?;
    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
    *rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
}
Op::Mul(lhs, rhs) => { Op::Mul(lhs, rhs) => {
let lhs_grad = grad.mul(rhs)?; let lhs_grad = grad.mul(rhs)?;
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like()); let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
@ -386,22 +398,33 @@ impl Tensor {
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like()); let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
} }
// Backprop for z = lhs / rhs (quotient rule):
//   dz/dlhs = 1/rhs        -> lhs gets  grad / rhs
//   dz/drhs = -lhs / rhs^2 -> rhs gets -grad * lhs / rhs^2
Op::Div(lhs, rhs) => {
    let lhs_grad = grad.div(rhs)?;
    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
    let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
    // BUG FIX: the quotient-rule term for rhs is negative; the original
    // code added it, flipping the sign of the rhs gradient. Subtract it.
    *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::Affine { arg, mul, .. } => { Op::Affine { arg, mul, .. } => {
let arg_grad = grad.affine(*mul, 0.)?; let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like()); let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
// Backprop for z = -arg: dz/darg = -1, so negate the incoming gradient
// before accumulating it into arg's running gradient.
Op::Neg(arg) => {
    let arg_grad = grad.neg()?;
    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
    *sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqr(arg) => { Op::Sqr(arg) => {
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?; let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like()); let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)? *sum_grad = sum_grad.add(&arg_grad)?
} }
Op::Sqrt(_arg) => { Op::Sqrt(arg) => {
todo!() let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
// TODO: Add div to enable the following. let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
// let arg_grad = grad / (2 * arg) *sum_grad = sum_grad.add(&arg_grad)?
// let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
// *sum_grad = sum_grad.add(arg_grad)?
} }
}; };
} }