Handle tensor transfers between devices in the backprop.

laurent
2023-06-23 08:55:34 +01:00
parent 3f79d81b6f
commit 3e7cb18d7f
2 changed files with 17 additions and 2 deletions
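In short: `Tensor::to_device` previously dropped its provenance (`op: None`), so any device transfer cut the result out of the autograd graph. After this change the transfer is recorded as `Op::ToDevice`, the graph walk follows it, and the backward pass moves the incoming gradient back onto the source device before accumulating it. Below is a minimal sketch of what this enables; it assumes the candle API of this era, and names such as `Tensor::var`, `Device::new_cuda`, `GradStore::get`, and `to_vec1` are assumptions rather than anything shown in the diff:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let cpu = Device::Cpu;
    let gpu = Device::new_cuda(0)?; // assumed CUDA device constructor
    // A tracked variable living on the CPU.
    let x = Tensor::var(&[3f32, 1., 4.], &cpu)?;
    // The transfer is now recorded as Op::ToDevice, so gradients can
    // flow back through it instead of stopping at the device boundary.
    let y = x.to_device(&gpu)?.sqr()?;
    let grads = y.backward()?;
    // d/dx x^2 = 2x, computed on the GPU but delivered on x's device.
    let gx = grads.get(&x).expect("no grad for x");
    println!("{:?}", gx.to_vec1::<f32>()?); // [6.0, 2.0, 8.0]
    Ok(())
}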


@@ -17,6 +17,7 @@ pub(crate) enum Op {
     Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
+    ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
     // TODO: Support for custom ops.
 }
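The new variant follows the shape of the other unary ops: it carries the source tensor so the backward walk can keep descending past a device transfer. Conceptually `to_device` is the identity on values, so the variant needs no payload beyond its input; the target device can be recovered from the result tensor itself.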


@@ -542,12 +542,17 @@ impl Tensor {
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
         };
+        let op = if self.track_op() {
+            Some(Op::ToDevice(self.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
             shape: self.shape.clone(),
             stride: self.stride.clone(),
-            op: None, // TODO: Have a proper op here.
+            op,
             is_variable: self.is_variable,
         };
         Ok(Tensor(Arc::new(tensor_)))
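The `track_op()` gate (its implementation is not part of this diff) presumably keeps inference-only transfers cheap: an op is recorded only when gradients could actually flow through the tensor. A hedged, self-contained sketch of the assumed predicate, not the actual candle implementation:

// Assumed semantics: track an op only if the tensor is a variable or
// already carries an op in its history, i.e. gradients can reach it.
fn track_op<T>(is_variable: bool, op: &Option<T>) -> bool {
    is_variable || op.is_some()
}

fn main() {
    assert!(track_op::<()>(true, &None));   // variables are tracked
    assert!(!track_op::<()>(false, &None)); // plain data is not
}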
@@ -596,7 +601,11 @@ impl Tensor {
                         nodes
                     }
                 }
-                Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
+                Op::ToDevice(node)
+                | Op::Transpose(node, _, _)
+                | Op::Sqr(node)
+                | Op::Sqrt(node)
+                | Op::Neg(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
                     nodes
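`walk` is the topological pass that collects the nodes whose gradients will be needed; adding `Op::ToDevice(node)` to the unary group means the traversal no longer stops at a transfer. A toy, dependency-free version of that pattern (names are illustrative, not candle's):

struct Node {
    name: &'static str,
    requires_grad: bool,
    input: Option<Box<Node>>, // a single unary predecessor, like ToDevice
}

// Depth-first walk: visit inputs first, keep a node whenever anything
// upstream of it requires gradients. Returns whether this subtree does.
fn walk(n: &Node, order: &mut Vec<&'static str>) -> bool {
    let mut track = n.requires_grad;
    if let Some(input) = &n.input {
        track |= walk(input, order);
    }
    if track {
        order.push(n.name);
    }
    track
}

fn main() {
    let x = Node { name: "x", requires_grad: true, input: None };
    let y = Node { name: "to_device(x)", requires_grad: false, input: Some(Box::new(x)) };
    let z = Node { name: "sqr(y)", requires_grad: false, input: Some(Box::new(y)) };
    let mut order = Vec::new();
    walk(&z, &mut order);
    assert_eq!(order, ["x", "to_device(x)", "sqr(y)"]); // inputs before users
}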
@@ -692,6 +701,11 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
+                    Op::ToDevice(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let arg_grad = grad.to_device(&sum_grad.device())?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                     Op::Transpose(arg, dim1, dim2) => {
                         let arg_grad = grad.transpose(*dim1, *dim2)?;
                         let sum_grad = grads.or_insert(arg)?;
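Since the transfer is the identity on values, its Jacobian is the identity: the incoming `grad` passes through numerically unchanged, and the only real work in the new `Op::ToDevice` arm is moving it back. Note the device it moves to: assuming `or_insert` materializes its zero accumulator with `arg`'s layout and device (the usual `zeros_like`-style pattern), `sum_grad.device()` is exactly where the gradient has to land before the `add` can run without a cross-device mismatch.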