mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Detach the tensors on batch-norm eval. (#1702)
* Detach the tensors on batch-norm eval. * Fix pyo3 bindings. * Black tweak. * Formatting. * Also update the pyo3-onnx formatting. * Apply black.
This commit is contained in:
@ -175,7 +175,7 @@ impl Tensor {
|
||||
// the backprop graph of the backprop itself. This would be an issue for second order
|
||||
// derivatives but these are out of scope at the moment.
|
||||
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
|
||||
let grad = if do_not_detach { grad } else { grad.detach()? };
|
||||
let grad = if do_not_detach { grad } else { grad.detach() };
|
||||
if let Some(op) = node.op() {
|
||||
match op {
|
||||
Op::Binary(lhs, rhs, BinaryOp::Add) => {
|
||||
|
@ -1882,9 +1882,9 @@ impl Tensor {
|
||||
/// this new node. The storage of this tensor is shared with the initial tensor.
|
||||
///
|
||||
/// If the tensor is already detached from the computation graph, the same tensor is returned.
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
pub fn detach(&self) -> Tensor {
|
||||
if self.op.is_none() && !self.is_variable {
|
||||
Ok(self.clone())
|
||||
self.clone()
|
||||
} else {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -1895,7 +1895,7 @@ impl Tensor {
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -107,6 +107,10 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn as_detached_tensor(&self) -> Tensor {
|
||||
self.0.detach()
|
||||
}
|
||||
|
||||
pub fn as_tensor(&self) -> &Tensor {
|
||||
&self.0
|
||||
}
|
||||
|
Reference in New Issue
Block a user