Sketch the conv1d op.

This commit is contained in:
laurent
2023-07-04 10:52:34 +01:00
parent e6b01d0c18
commit 3aac1047fe
7 changed files with 97 additions and 1 deletions

View File

@ -33,7 +33,12 @@ impl Tensor {
track_grad |= tg;
nodes
}
Op::Add(lhs, rhs)
Op::Conv1D {
arg: lhs,
kernel: rhs,
..
}
| Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
| Op::Div(lhs, rhs)
@ -147,6 +152,7 @@ impl Tensor {
let f_grad = pred.where_cond(&zeros, &grad)?;
*f_sum_grad = f_sum_grad.add(&f_grad)?;
}
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
}