mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Do not backprop through argmin/argmax. (#865)
This commit is contained in:
@ -98,7 +98,7 @@ impl Tensor {
|
||||
| Op::Copy(node)
|
||||
| Op::Broadcast(node)
|
||||
| Op::Cmp(node, _)
|
||||
| Op::Reduce(node, _, _)
|
||||
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
@ -112,6 +112,7 @@ impl Tensor {
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
|
||||
}
|
||||
} else {
|
||||
nodes
|
||||
@ -521,6 +522,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GradStore(HashMap<TensorId, Tensor>);
|
||||
|
||||
impl GradStore {
|
||||
|
@ -666,7 +666,12 @@ impl Tensor {
|
||||
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = 1;
|
||||
let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
|
||||
let op = match op {
|
||||
ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
|
||||
BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
|
||||
}
|
||||
ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
|
||||
};
|
||||
let res = from_storage(storage, dims, op, false);
|
||||
if keepdim {
|
||||
Ok(res)
|
||||
|
Reference in New Issue
Block a user