mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the gradient for reduce-sum. (#162)
* Add the gradient for reduce-sum. * And add the gradient for the broadcast ops. * Add some backprop tests. * Add some linear regression example.
This commit is contained in:
@ -179,11 +179,33 @@ impl Tensor {
|
||||
start_idx += len;
|
||||
}
|
||||
}
|
||||
Op::Broadcast(_arg) => {
|
||||
return Err(Error::BackwardNotSupported { op: "broadcast" })
|
||||
Op::Broadcast(arg) => {
|
||||
let arg_dims = arg.dims();
|
||||
let node_dims = node.dims();
|
||||
// The number of dims that have been inserted on the left.
|
||||
let left_dims = node_dims.len() - arg_dims.len();
|
||||
let mut sum_dims: Vec<usize> = (0..left_dims).collect();
|
||||
for (dim, (node_dim, arg_dim)) in node_dims[left_dims..]
|
||||
.iter()
|
||||
.zip(arg_dims.iter())
|
||||
.enumerate()
|
||||
{
|
||||
if node_dim != arg_dim {
|
||||
sum_dims.push(dim + left_dims)
|
||||
}
|
||||
}
|
||||
|
||||
let mut arg_grad = grad.sum(sum_dims.as_slice())?;
|
||||
// sum_dims has increasing values.
|
||||
for &dim in sum_dims.iter().rev() {
|
||||
arg_grad = arg_grad.squeeze(dim)?
|
||||
}
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.broadcast_add(&arg_grad)?
|
||||
}
|
||||
Op::Sum(_arg, _sum_dims) => {
|
||||
return Err(Error::BackwardNotSupported { op: "sum" })
|
||||
Op::Sum(arg, _sum_dims) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.broadcast_add(&grad)?
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user