mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Handle tensor transfers between devices in the backprop.
This commit is contained in:
@ -17,6 +17,7 @@ pub(crate) enum Op {
|
|||||||
Neg(Tensor),
|
Neg(Tensor),
|
||||||
Sqr(Tensor),
|
Sqr(Tensor),
|
||||||
Sqrt(Tensor),
|
Sqrt(Tensor),
|
||||||
|
ToDevice(Tensor),
|
||||||
Transpose(Tensor, usize, usize),
|
Transpose(Tensor, usize, usize),
|
||||||
// TODO: Support for custom ops.
|
// TODO: Support for custom ops.
|
||||||
}
|
}
|
||||||
|
@ -542,12 +542,17 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||||
};
|
};
|
||||||
|
let op = if self.track_op() {
|
||||||
|
Some(Op::ToDevice(self.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let tensor_ = Tensor_ {
|
let tensor_ = Tensor_ {
|
||||||
id: TensorId::new(),
|
id: TensorId::new(),
|
||||||
storage,
|
storage,
|
||||||
shape: self.shape.clone(),
|
shape: self.shape.clone(),
|
||||||
stride: self.stride.clone(),
|
stride: self.stride.clone(),
|
||||||
op: None, // TODO: Have a proper op here.
|
op,
|
||||||
is_variable: self.is_variable,
|
is_variable: self.is_variable,
|
||||||
};
|
};
|
||||||
Ok(Tensor(Arc::new(tensor_)))
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
@ -596,7 +601,11 @@ impl Tensor {
|
|||||||
nodes
|
nodes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
|
Op::ToDevice(node)
|
||||||
|
| Op::Transpose(node, _, _)
|
||||||
|
| Op::Sqr(node)
|
||||||
|
| Op::Sqrt(node)
|
||||||
|
| Op::Neg(node) => {
|
||||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
nodes
|
nodes
|
||||||
@ -692,6 +701,11 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
|
Op::ToDevice(arg) => {
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
let arg_grad = grad.to_device(&sum_grad.device())?;
|
||||||
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
|
}
|
||||||
Op::Transpose(arg, dim1, dim2) => {
|
Op::Transpose(arg, dim1, dim2) => {
|
||||||
let arg_grad = grad.transpose(*dim1, *dim2)?;
|
let arg_grad = grad.transpose(*dim1, *dim2)?;
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
Reference in New Issue
Block a user