mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the detach method.
This commit is contained in:
@ -524,6 +524,22 @@ impl Tensor {
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
// TODO: Currently this duplicates the storage, the PyTorch version would share the storage,
|
||||
// maybe we should do the same?
|
||||
/// Returns a new tensor detached from the current graph, gradient are not propagated through
|
||||
/// this new node.
|
||||
pub fn detach(&self) -> Result<Tensor> {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.try_clone()?,
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
op: None,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
||||
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
|
||||
if self.device().same_id(device) {
|
||||
|
Reference in New Issue
Block a user