Add where_cond and properly apply the causal mask.

This commit is contained in:
laurent
2023-06-25 21:08:03 +01:00
parent 25bcad290e
commit 117f014b55
8 changed files with 168 additions and 24 deletions

View File

@ -24,6 +24,15 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t3, nodes, already_seen);
track_grad |= tg;
nodes
}
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
@ -161,6 +170,9 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
}
Op::WhereCond(_pred, _t, _f) => {
return Err(Error::BackwardNotSupported { op: "where_cond" })
}
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
}