mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Start adding index-add.
This commit is contained in:
@ -38,7 +38,9 @@ impl Tensor {
|
||||
nodes
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
|
||||
Op::IndexAdd(t1, t2, t3, _)
|
||||
| Op::CustomOp3(t1, t2, t3, _)
|
||||
| Op::WhereCond(t1, t2, t3) => {
|
||||
let (tg, nodes) = walk(t1, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(t2, nodes, already_seen);
|
||||
@ -160,6 +162,7 @@ impl Tensor {
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
|
||||
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
|
||||
Op::IndexSelect(arg, indexes, dim) => {
|
||||
let dim = *dim;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user