mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the detach method.
This commit is contained in:
@ -524,6 +524,22 @@ impl Tensor {
|
|||||||
Ok(Tensor(Arc::new(tensor_)))
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Currently this duplicates the storage, the PyTorch version would share the storage,
|
||||||
|
// maybe we should do the same?
|
||||||
|
/// Returns a new tensor detached from the current graph, gradient are not propagated through
|
||||||
|
/// this new node.
|
||||||
|
pub fn detach(&self) -> Result<Tensor> {
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage: self.storage.try_clone()?,
|
||||||
|
shape: self.shape.clone(),
|
||||||
|
stride: self.stride.clone(),
|
||||||
|
op: None,
|
||||||
|
is_variable: false,
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
|
|
||||||
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
/// If the target device is the same as the tensor device, only a shallow copy is performed.
|
||||||
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
|
pub fn to_device(&self, device: &Device) -> Result<Tensor> {
|
||||||
if self.device().same_id(device) {
|
if self.device().same_id(device) {
|
||||||
|
Reference in New Issue
Block a user