Add a couple operators.

This commit is contained in:
laurent
2023-06-20 22:32:11 +01:00
parent f1f372b13e
commit 78bac0ed32
3 changed files with 93 additions and 9 deletions

View File

@ -190,7 +190,10 @@ impl Tensor {
// if this can create cycles in the compute graph.
binary_op!(add, Add, add_impl);
binary_op!(mul, Mul, mul_impl);
binary_op!(sub, Sub, sub_impl);
binary_op!(div, Div, div_impl);
unary_op!(neg, Neg, neg_impl);
unary_op!(sqr, Sqr, sqr_impl);
unary_op!(sqrt, Sqrt, sqrt_impl);
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
@ -320,7 +323,10 @@ impl Tensor {
nodes
} else if let Some(op) = &node.op {
match op {
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs) => {
let (tg, nodes) = walk(lhs, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(rhs, nodes, already_seen);
@ -336,7 +342,7 @@ impl Tensor {
nodes
}
}
Op::Sqr(node) | Op::Sqrt(node) => {
Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@ -378,6 +384,12 @@ impl Tensor {
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
}
// Backward step for `lhs - rhs`: the incoming gradient is accumulated unchanged
// into the lhs entry and negated into the rhs entry.
Op::Sub(lhs, rhs) => {
// Entries are created lazily as zero tensors shaped like the operand.
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
// rhs enters the difference with a minus sign, so its contribution is `-grad`.
*rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
}
Op::Mul(lhs, rhs) => {
let lhs_grad = grad.mul(rhs)?;
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
@ -386,22 +398,33 @@ impl Tensor {
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
}
// Backward step for `lhs / rhs`.
Op::Div(lhs, rhs) => {
    // d(lhs / rhs) / d(lhs) = 1 / rhs
    let lhs_grad = grad.div(rhs)?;
    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
    // d(lhs / rhs) / d(rhs) = -lhs / rhs^2. `rhs_grad` holds the magnitude
    // `grad * lhs / rhs^2`, so it has to be subtracted from (not added to)
    // the accumulated gradient of rhs.
    let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
    *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
// Backward step for the affine op: only the `mul` field is read here, the
// remaining fields of the variant are ignored by the `..` pattern.
Op::Affine { arg, mul, .. } => {
// Scales the incoming gradient by `mul` with a zero second argument.
// NOTE(review): presumably `affine(m, a)` computes `x * m + a` - confirm.
let arg_grad = grad.affine(*mul, 0.)?;
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)?
}
// Backward step for `-arg`: the negated incoming gradient is accumulated into
// the entry for `arg`.
Op::Neg(arg) => {
let arg_grad = grad.neg()?;
// The entry is created lazily as a zero tensor shaped like `arg`.
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)?
}
// Backward step for `arg^2`: d(x^2)/dx = 2x, so the contribution is
// `2 * arg * grad`.
Op::Sqr(arg) => {
// NOTE(review): `affine(2., 0.)` is presumably `x * 2 + 0`, i.e. the factor 2 - confirm.
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Sqrt(_arg) => {
todo!()
// TODO: Add div to enable the following.
// let arg_grad = grad / (2 * arg)
// let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
// *sum_grad = sum_grad.add(arg_grad)?
// Backward step for `sqrt(arg)`.
Op::Sqrt(arg) => {
    // d sqrt(x) / dx = 1 / (2 * sqrt(x)): the incoming gradient must be divided
    // by the square root of the input, not by the input itself.
    // NOTE(review): this recomputes sqrt(arg); if the forward output of this
    // node is available in the enclosing scope, dividing by it avoids the
    // extra computation.
    let arg_grad = grad.div(&arg.sqrt()?)?.affine(0.5, 0.)?;
    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
    *sum_grad = sum_grad.add(&arg_grad)?
}
};
}