mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Sketch the conv1d op.
This commit is contained in:
@ -33,7 +33,12 @@ impl Tensor {
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::Add(lhs, rhs)
|
||||
Op::Conv1D {
|
||||
arg: lhs,
|
||||
kernel: rhs,
|
||||
..
|
||||
}
|
||||
| Op::Add(lhs, rhs)
|
||||
| Op::Mul(lhs, rhs)
|
||||
| Op::Sub(lhs, rhs)
|
||||
| Op::Div(lhs, rhs)
|
||||
@ -147,6 +152,7 @@ impl Tensor {
|
||||
let f_grad = pred.where_cond(&zeros, &grad)?;
|
||||
*f_sum_grad = f_sum_grad.add(&f_grad)?;
|
||||
}
|
||||
Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }),
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
}
|
||||
|
Reference in New Issue
Block a user