Boilerplate code for the sum operator.

This commit is contained in:
laurent
2023-06-25 09:35:17 +01:00
parent 7ccf27dda2
commit 3852a85af3
7 changed files with 61 additions and 1 deletions

View File

@ -56,6 +56,7 @@ impl Tensor {
}
Op::Reshape(node)
| Op::Broadcast(node)
| Op::Sum(node, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@ -188,6 +189,9 @@ impl Tensor {
Op::Broadcast(_arg) => {
return Err(Error::BackwardNotSupported { op: "broadcast" })
}
Op::Sum(_arg, _sum_dims) => {
return Err(Error::BackwardNotSupported { op: "sum" })
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?