Do not backprop through argmin/argmax. (#865)

This commit is contained in:
Laurent Mazare
2023-09-15 23:15:40 +02:00
committed by GitHub
parent 3e49f8fce5
commit 635012d770
2 changed files with 9 additions and 2 deletions

View File

@ -98,7 +98,7 @@ impl Tensor {
| Op::Copy(node)
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, _, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
@ -112,6 +112,7 @@ impl Tensor {
track_grad |= tg;
nodes
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
nodes
@ -521,6 +522,7 @@ impl Tensor {
}
}
#[derive(Debug)]
pub struct GradStore(HashMap<TensorId, Tensor>);
impl GradStore {