mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Boilerplate code for the sum operator.
This commit is contained in:
@ -56,6 +56,7 @@ impl Tensor {
|
||||
}
|
||||
Op::Reshape(node)
|
||||
| Op::Broadcast(node)
|
||||
| Op::Sum(node, _)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
@ -188,6 +189,9 @@ impl Tensor {
|
||||
Op::Broadcast(_arg) => {
|
||||
return Err(Error::BackwardNotSupported { op: "broadcast" })
|
||||
}
|
||||
Op::Sum(_arg, _sum_dims) => {
|
||||
return Err(Error::BackwardNotSupported { op: "sum" })
|
||||
}
|
||||
Op::ToDType(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
|
||||
|
Reference in New Issue
Block a user