Detach the tensors on batch-norm eval. (#1702)

* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
This commit is contained in:
Laurent Mazare
2024-02-13 14:26:32 +01:00
committed by GitHub
parent 13c67226e6
commit ad73e93da2
14 changed files with 117 additions and 27 deletions

View File

@ -262,9 +262,19 @@ impl BatchNorm {
let target_shape = target_shape.as_slice();
let x = x
.broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
.broadcast_sub(
&self
.running_mean
.as_detached_tensor()
.reshape(target_shape)?,
)?
.broadcast_div(
&(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
&(self
.running_var
.as_detached_tensor()
.reshape(target_shape)?
+ self.eps)?
.sqrt()?,
)?;
match &self.weight_and_bias {