From 615196e7be243e21c96707cfb543ff79de07e461 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 09:59:52 +0100
Subject: [PATCH] Add more gradients.

---
 candle-core/src/backprop.rs | 21 +++++++++++++++------
 candle-core/src/tensor.rs   |  2 ++
 2 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index bc6740cf..ef15e65f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -106,9 +106,8 @@ impl Tensor {
             }
             let grad = grads.remove(node).unwrap();
             // TODO: We should perform all these operations in place (or at least not track the
-            // whole graph).
-            // The only drawback would be if we wanted to support grad of grad but this is out of
-            // scope.
+            // whole graph). The only drawback would be if we wanted to support grad of grad but
+            // this is out of scope.
             if let Some(op) = node.op() {
                 match op {
                     Op::Add(lhs, rhs) => {
@@ -139,8 +138,14 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::WhereCond(_pred, _t, _f) => {
-                        return Err(Error::BackwardNotSupported { op: "where_cond" })
+                    Op::WhereCond(pred, t, f) => {
+                        let zeros = grad.zeros_like()?;
+                        let t_sum_grad = grads.or_insert(t)?;
+                        let t_grad = pred.where_cond(&grad, &zeros)?;
+                        *t_sum_grad = t_sum_grad.add(&t_grad)?;
+                        let f_sum_grad = grads.or_insert(f)?;
+                        let f_grad = pred.where_cond(&zeros, &grad)?;
+                        *f_sum_grad = f_sum_grad.add(&f_grad)?;
                     }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
@@ -209,7 +214,11 @@ impl Tensor {
                     Op::Softmax(_arg, _) => {
                         return Err(Error::BackwardNotSupported { op: "softmax" })
                     }
-                    Op::Reshape(_arg) => return Err(Error::BackwardNotSupported { op: "reshape" }),
+                    Op::Reshape(arg) => {
+                        let arg_grad = grad.reshape(arg.dims())?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                     Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                     Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
                     Op::Sqr(arg) => {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4cff4efc..feb59d3c 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -121,6 +121,7 @@ fn from_storage<S: Into<Shape>>(
     }
 
 impl Tensor {
+    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
@@ -144,6 +145,7 @@ impl Tensor {
         Tensor::ones(self.shape(), self.dtype(), &self.device())
     }
 
+    // TODO: Maybe this should be a broadcast rather than actually creating the full tensor.
     fn zeros_impl<S: Into<Shape>>(
         shape: S,
         dtype: DType,
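
Note on the two new backward rules, for reviewers: the where_cond gradient routes the upstream gradient to the true branch where pred holds and to the false branch elsewhere, and the reshape gradient is just the upstream gradient viewed with the argument's original shape. The sketch below illustrates both rules in plain Rust over Vec<f32>, independent of candle's Tensor and GradStore types; the function names where_cond_backward and reshape_backward are illustrative only and are not part of the candle API.

// Standalone sketch of the two gradient rules added by this patch, written over
// plain Vec<f32> so it runs on its own. Names here are hypothetical, not candle API.

/// Gradients of `where_cond(pred, t, f)` w.r.t. `t` and `f`: the upstream
/// gradient flows to `t` where `pred` is true and to `f` where it is false,
/// mirroring `pred.where_cond(&grad, &zeros)` / `pred.where_cond(&zeros, &grad)`.
fn where_cond_backward(pred: &[bool], grad: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let t_grad = pred
        .iter()
        .zip(grad)
        .map(|(&p, &g)| if p { g } else { 0.0 })
        .collect();
    let f_grad = pred
        .iter()
        .zip(grad)
        .map(|(&p, &g)| if p { 0.0 } else { g })
        .collect();
    (t_grad, f_grad)
}

/// Gradient of `reshape`: element order is unchanged, so the backward pass is
/// the upstream gradient reinterpreted with the argument's original dims
/// (`grad.reshape(arg.dims())` in the patch). With flat storage this is just a
/// copy of the values paired with the original dims.
fn reshape_backward(grad: &[f32], arg_dims: &[usize]) -> (Vec<f32>, Vec<usize>) {
    assert_eq!(grad.len(), arg_dims.iter().product::<usize>());
    (grad.to_vec(), arg_dims.to_vec())
}

fn main() {
    let pred = [true, false, true, false];
    let grad = [1.0f32, 2.0, 3.0, 4.0];

    // where_cond: each branch only receives the gradient where it was selected.
    let (t_grad, f_grad) = where_cond_backward(&pred, &grad);
    assert_eq!(t_grad, vec![1.0, 0.0, 3.0, 0.0]);
    assert_eq!(f_grad, vec![0.0, 2.0, 0.0, 4.0]);

    // reshape: a (2, 2) input flattened to (4,) in the forward pass gets its
    // gradient viewed as (2, 2) again, values untouched.
    let (arg_grad, dims) = reshape_backward(&grad, &[2, 2]);
    assert_eq!(dims, vec![2, 2]);
    assert_eq!(arg_grad, vec![1.0, 2.0, 3.0, 4.0]);
}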