Start adding index-add.

This commit is contained in:
laurent
2023-07-21 20:12:48 +01:00
parent 5cc843550d
commit 27174a82aa
8 changed files with 97 additions and 3 deletions

View File

@ -38,7 +38,9 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
Op::IndexAdd(t1, t2, t3, _)
| Op::CustomOp3(t1, t2, t3, _)
| Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
@ -160,6 +162,7 @@ impl Tensor {
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => Err(Error::BackwardNotSupported { op: "conv1d" })?,
Op::IndexAdd { .. } => Err(Error::BackwardNotSupported { op: "index-add" })?,
Op::IndexSelect(arg, indexes, dim) => {
let dim = *dim;
let sum_grad = grads.or_insert(arg)?;