mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval. * Fix pyo3 bindings. * Black tweak. * Formatting. * Also update the pyo3-onnx formatting. * Apply black.
This commit is contained in:
@ -262,9 +262,19 @@ impl BatchNorm {
|
||||
let target_shape = target_shape.as_slice();
|
||||
|
||||
let x = x
|
||||
.broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
|
||||
.broadcast_sub(
|
||||
&self
|
||||
.running_mean
|
||||
.as_detached_tensor()
|
||||
.reshape(target_shape)?,
|
||||
)?
|
||||
.broadcast_div(
|
||||
&(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
|
||||
&(self
|
||||
.running_var
|
||||
.as_detached_tensor()
|
||||
.reshape(target_shape)?
|
||||
+ self.eps)?
|
||||
.sqrt()?,
|
||||
)?;
|
||||
|
||||
match &self.weight_and_bias {
|
||||
|
Reference in New Issue
Block a user