Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
Do not backprop through argmin/argmax. (#865)
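Rationale: argmin/argmax return integer indices, and the chosen index is piecewise constant in the input, so its gradient is zero almost everywhere and undefined at ties. With this change the two ops are treated as non-differentiable: the backward graph walk no longer recurses through them, and the tensors they produce record no parent op. A rough, untested usage sketch of the resulting behaviour, assuming the usual candle-core API surface (Var::new, argmax, sum_all, backward, GradStore::get):

    use candle_core::{Device, Result, Var};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let x = Var::new(&[[1f32, 3., 2.], [4., 0., 5.]], &dev)?;
        // Integer-valued and piecewise constant in x: nothing useful to
        // backprop, so the result is simply detached from the graph.
        let idx = x.argmax(1)?;
        // A differentiable path through x still records ops and backprops.
        let loss = x.sum_all()?;
        let grads = loss.backward()?;
        assert!(grads.get(&x).is_some());
        println!("{idx}");
        Ok(())
    }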
@@ -98,7 +98,7 @@ impl Tensor {
                     | Op::Copy(node)
                     | Op::Broadcast(node)
                     | Op::Cmp(node, _)
-                    | Op::Reduce(node, _, _)
+                    | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
                     | Op::ToDType(node)
                     | Op::ToDevice(node)
                     | Op::Transpose(node, _, _)
@@ -112,6 +112,7 @@ impl Tensor {
                         track_grad |= tg;
                         nodes
                     }
+                    Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
                 }
             } else {
                 nodes
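Both hunks above hinge on nested or-patterns: a single match arm in the graph walk keeps recursing through Min/Sum/Max reductions, while a new arm returns the node list untouched for ArgMin/ArgMax. A self-contained sketch of that idiom, with candle's Op/ReduceOp replaced by simplified stand-ins (the node payload here is just an id):

    enum ReduceOp { Sum, Min, Max, ArgMin, ArgMax }

    enum Op {
        Copy(u32),
        Reduce(u32, ReduceOp),
    }

    // Returns the input node to keep traversing, or None when the op should
    // stop the backward walk (mirrors the graph-walk change above).
    fn differentiable_input(op: &Op) -> Option<u32> {
        match op {
            Op::Copy(node)
            | Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max) => Some(*node),
            Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax) => None,
        }
    }

    fn main() {
        assert_eq!(differentiable_input(&Op::Copy(1)), Some(1));
        assert_eq!(differentiable_input(&Op::Reduce(2, ReduceOp::Sum)), Some(2));
        assert_eq!(differentiable_input(&Op::Reduce(3, ReduceOp::Min)), Some(3));
        assert_eq!(differentiable_input(&Op::Reduce(4, ReduceOp::Max)), Some(4));
        assert_eq!(differentiable_input(&Op::Reduce(5, ReduceOp::ArgMin)), None);
        assert_eq!(differentiable_input(&Op::Reduce(6, ReduceOp::ArgMax)), None);
        println!("traversal filter ok");
    }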
@@ -521,6 +522,7 @@ impl Tensor {
     }
 }

+#[derive(Debug)]
 pub struct GradStore(HashMap<TensorId, Tensor>);

 impl GradStore {
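The third hunk is a small drive-by: GradStore, the newtype around the gradient map, gains #[derive(Debug)] so a whole store can be printed with {:?}. A trivial stand-in sketch (TensorId and the gradient Tensor replaced by placeholder types):

    use std::collections::HashMap;

    type TensorId = usize; // placeholder for candle's TensorId
    type Grad = f32; // placeholder for a gradient Tensor

    #[derive(Debug)] // the derive added above is what enables {:?} on the store
    struct GradStore(HashMap<TensorId, Grad>);

    fn main() {
        let mut grads = GradStore(HashMap::new());
        grads.0.insert(42, 0.5);
        println!("{grads:?}"); // e.g. GradStore({42: 0.5})
    }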
@@ -666,7 +666,12 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
+        let op = match op {
+            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => {
+                BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()))
+            }
+            ReduceOp::ArgMin | ReduceOp::ArgMax => BackpropOp::none(),
+        };
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
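Finally, the reduce_impl change decides per-op whether the new tensor records a backprop parent: Sum/Min/Max go through BackpropOp::new1 as before, while ArgMin/ArgMax get BackpropOp::none(). A simplified stand-in that models BackpropOp as an Option (the real BackpropOp also captures the argument tensor):

    enum ReduceOp { Sum, Min, Max, ArgMin, ArgMax }

    enum Op { Reduce(ReduceOp, Vec<usize>) }

    // Stand-in for candle's BackpropOp: None means "no parent op recorded".
    type BackpropOp = Option<Op>;

    fn backprop_op_for(op: ReduceOp, reduced_dims: Vec<usize>) -> BackpropOp {
        match op {
            ReduceOp::Sum | ReduceOp::Min | ReduceOp::Max => Some(Op::Reduce(op, reduced_dims)),
            // Piecewise-constant index reductions: nothing to differentiate.
            ReduceOp::ArgMin | ReduceOp::ArgMax => None,
        }
    }

    fn main() {
        match backprop_op_for(ReduceOp::Sum, vec![1]) {
            Some(Op::Reduce(ReduceOp::Sum, dims)) => assert_eq!(dims, vec![1]),
            _ => panic!("sum should record a reduce op"),
        }
        assert!(backprop_op_for(ReduceOp::Min, vec![1]).is_some());
        assert!(backprop_op_for(ReduceOp::Max, vec![1]).is_some());
        assert!(backprop_op_for(ReduceOp::ArgMin, vec![1]).is_none());
        assert!(backprop_op_for(ReduceOp::ArgMax, vec![1]).is_none());
        println!("backprop op selection ok");
    }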