mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add where_cond and properly apply the causal mask.
This commit is contained in:
@ -24,6 +24,15 @@ impl Tensor {
|
||||
nodes
|
||||
} else if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::WhereCond(t1, t2, t3) => {
|
||||
let (tg, nodes) = walk(t1, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(t2, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(t3, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::Add(lhs, rhs)
|
||||
| Op::Mul(lhs, rhs)
|
||||
| Op::Sub(lhs, rhs)
|
||||
@ -161,6 +170,9 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
|
||||
}
|
||||
Op::WhereCond(_pred, _t, _f) => {
|
||||
return Err(Error::BackwardNotSupported { op: "where_cond" })
|
||||
}
|
||||
Op::Embedding(_lhs, _rhs) => {
|
||||
return Err(Error::BackwardNotSupported { op: "embedding" })
|
||||
}
|
||||
|
Reference in New Issue
Block a user