From 635012d770a75033081008a22044804d277fafa8 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 15 Sep 2023 23:15:40 +0200
Subject: [PATCH] Do not backprop through argmin/argmax. (#865)

---
 candle-core/src/backprop.rs | 4 +++-
 candle-core/src/tensor.rs   | 7 ++++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs
index b930a9f4..9c8f685f 100644
--- a/candle-core/src/backprop.rs
+++ b/candle-core/src/backprop.rs
@@ -98,7 +98,7 @@ impl Tensor {
                 | Op::Copy(node)
                 | Op::Broadcast(node)
                 | Op::Cmp(node, _)
-                | Op::Reduce(node, _, _)
+                | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
@@ -112,6 +112,7 @@ impl Tensor {
                     track_grad |= tg;
                     nodes
                 }
+                Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
             }
         } else {
             nodes
@@ -521,6 +522,7 @@ impl Tensor {
     }
 }
 
+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);
 
 impl GradStore {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 4388bf77..61f576cf 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -666,7 +666,12 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+        let op = match op {
+            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+            }
+            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+        };
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
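
Note: below is a minimal sketch (not part of the patch) of the user-visible
behaviour, assuming candle-core's public Var/backward/argmax API; the tensor
values and the main function are illustrative only. Differentiable reductions
such as sum still flow gradients back to the input, while argmin/argmax are
now recorded with BackpropOp::none(), so the backward pass skips them instead
of walking into a non-differentiable node.

    use candle_core::{Device, Result, Var};

    fn main() -> Result<()> {
        let x = Var::new(&[[1f32, 3., 2.], [4., 0., 5.]], &Device::Cpu)?;

        // Differentiable reduction: the gradient store has an entry for `x`.
        let loss = x.sum_all()?;
        let grads = loss.backward()?;
        assert!(grads.get(&x).is_some());

        // Non-differentiable reduction: argmax yields integer indices and,
        // with this patch, carries no backprop op, so it never enters the
        // gradient graph at all.
        let idx = x.argmax(1)?;
        println!("argmax along dim 1: {idx}");
        Ok(())
    }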