Detach the tensors on batch-norm eval. (#1702)

* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
This commit is contained in:
Laurent Mazare
2024-02-13 14:26:32 +01:00
committed by GitHub
parent 13c67226e6
commit ad73e93da2
14 changed files with 117 additions and 27 deletions

View File

@ -175,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
let grad = if do_not_detach { grad } else { grad.detach() };
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {

View File

@ -1882,9 +1882,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> {
pub fn detach(&self) -> Tensor {
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
self.clone()
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
@ -1895,7 +1895,7 @@ impl Tensor {
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
Tensor(Arc::new(tensor_))
}
}

View File

@ -107,6 +107,10 @@ impl Var {
Ok(Self(inner))
}
pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}
pub fn as_tensor(&self) -> &Tensor {
&self.0
}