mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Handle tensor transfers between devices in the backprop.
This commit is contained in:
@ -17,6 +17,7 @@ pub(crate) enum Op {
|
||||
Neg(Tensor),
|
||||
Sqr(Tensor),
|
||||
Sqrt(Tensor),
|
||||
ToDevice(Tensor),
|
||||
Transpose(Tensor, usize, usize),
|
||||
// TODO: Support for custom ops.
|
||||
}
|
||||
|
@ -542,12 +542,17 @@ impl Tensor {
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||
};
|
||||
let op = if self.track_op() {
|
||||
Some(Op::ToDevice(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage,
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
- op: None, // TODO: Have a proper op here.
|
||||
op,
|
||||
is_variable: self.is_variable,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
@ -596,7 +601,11 @@ impl Tensor {
|
||||
nodes
|
||||
}
|
||||
}
|
||||
- Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
|
||||
Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Sqr(node)
|
||||
| Op::Sqrt(node)
|
||||
| Op::Neg(node) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
@ -692,6 +701,11 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::ToDevice(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let arg_grad = grad.to_device(&sum_grad.device())?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Transpose(arg, dim1, dim2) => {
|
||||
let arg_grad = grad.transpose(*dim1, *dim2)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user