Do not backprop through argmin/argmax. (#865)

Laurent Mazare
2023-09-15 23:15:40 +02:00
committed by GitHub
parent 3e49f8fce5
commit 635012d770
2 changed files with 9 additions and 2 deletions
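
Rationale: argmin/argmax return the integer index of the extremal element, and that index is piecewise constant in the input, so there is no meaningful gradient to propagate through these ops. The hunks below stop recording a backprop op for them. A minimal sketch of the resulting behavior (assumed usage, not part of the commit; variable names are illustrative):

use candle_core::{DType, Device, Var};

fn main() -> candle_core::Result<()> {
    let x = Var::new(&[[1f32, 3., 2.]], &Device::Cpu)?;
    // argmax yields u32 indices and, after this change, no backprop op.
    let idx = x.argmax(1)?;
    let grads = idx.to_dtype(DType::F32)?.sum_all()?.backward()?;
    // The chain is cut at argmax, so `x` gets no gradient entry.
    assert!(grads.get(&x).is_none());
    Ok(())
}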

candle-core/src/backprop.rs

@@ -98,7 +98,7 @@ impl Tensor {
                 | Op::Copy(node)
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
-                | Op::Reduce(node, _, _)
+                | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
@@ -112,6 +112,7 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
            }
        } else {
            nodes
@@ -521,6 +522,7 @@ impl Tensor {
        }
    }
+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);
 impl GradStore {
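
Min, Sum, and Max keep their place in the differentiable graph; only the arg-variants fall through to the new catch-all arm. A sketch of the contrast (assumed usage of the candle-core API):

use candle_core::{Device, Var};

fn main() -> candle_core::Result<()> {
    let x = Var::new(&[1f32, 3., 2.], &Device::Cpu)?;
    // max keeps its Reduce backprop op, so a gradient flows back to the
    // selected element: 1 at the argmax position, 0 elsewhere.
    let grads = x.max(0)?.backward()?;
    assert!(grads.get(&x).is_some());
    Ok(())
}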

candle-core/src/tensor.rs

@@ -666,7 +666,12 @@ impl Tensor {
        let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
        let mut dims = self.dims().to_vec();
        dims[dim] = 1;
-        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+        let op = match op {
+            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+            }
+            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+        };
        let res = from_storage(storage, dims, op, false);
        if keepdim {
            Ok(res)
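
BackpropOp::none() leaves the argmin/argmax result with no parent op, which is exactly what the catch-all arm added to the graph walk above expects. If a gradient through the selected value is still wanted, one option is to gather with the detached indices, since gather has its own backward pass; a sketch, assuming the usual candle-core API:

use candle_core::{Device, Var};

fn main() -> candle_core::Result<()> {
    let x = Var::new(&[[1f32, 3., 2.]], &Device::Cpu)?;
    let idx = x.argmax_keepdim(1)?; // indices carry no grad history
    let picked = x.gather(&idx, 1)?; // but gather itself is differentiable
    let grads = picked.sum_all()?.backward()?;
    assert!(grads.get(&x).is_some());
    Ok(())
}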