diff --git a/src/op.rs b/src/op.rs
index 34d7ae76..6e909a35 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -17,6 +17,7 @@ pub(crate) enum Op {
     Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
+    ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
     // TODO: Support for custom ops.
 }
diff --git a/src/tensor.rs b/src/tensor.rs
index 2cef4cd5..11d69dec 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -542,12 +542,17 @@ impl Tensor {
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
         };
+        let op = if self.track_op() {
+            Some(Op::ToDevice(self.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage,
             shape: self.shape.clone(),
             stride: self.stride.clone(),
-            op: None, // TODO: Have a proper op here.
+            op,
             is_variable: self.is_variable,
         };
         Ok(Tensor(Arc::new(tensor_)))
@@ -596,7 +601,11 @@ impl Tensor {
                             nodes
                         }
                     }
-                    Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
+                    Op::ToDevice(node)
+                    | Op::Transpose(node, _, _)
+                    | Op::Sqr(node)
+                    | Op::Sqrt(node)
+                    | Op::Neg(node) => {
                         let (tg, nodes) = walk(node, nodes, already_seen);
                         track_grad |= tg;
                         nodes
@@ -692,6 +701,11 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
+                    Op::ToDevice(arg) => {
+                        let sum_grad = grads.or_insert(arg)?;
+                        let arg_grad = grad.to_device(&sum_grad.device())?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                    Op::Transpose(arg, dim1, dim2) => {
                         let arg_grad = grad.transpose(*dim1, *dim2)?;
                         let sum_grad = grads.or_insert(arg)?;
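
Before this change, `to_device` set `op: None`, silently detaching the autograd graph at every device boundary. The backward rule added here is the device-transfer identity: moving a tensor does not change its values, so the gradient only needs to be moved back to the argument's device before being accumulated. A minimal sketch of the usage this enables, assuming a candle-style API (`Tensor::var`, `Device::new_cuda`, `GradStore::get`, and `to_vec1` are assumptions not shown in this diff):

```rust
// A minimal sketch of what this change enables, not code from the PR.
// Assumptions: `Tensor::var` creates a tracked variable, `Device::new_cuda`
// opens a GPU device, and `backward` returns a gradient store queried
// with `get`.
use candle::{Device, Result, Tensor};

fn grad_across_devices() -> Result<()> {
    let cpu = Device::Cpu;
    let gpu = Device::new_cuda(0)?; // requires a CUDA-enabled build

    // A variable living on the CPU.
    let x = Tensor::var(&[3f32, 1., 4.], &cpu)?;

    // `to_device` now records Op::ToDevice, so the graph stays
    // connected when the computation continues on the GPU.
    let y = x.to_device(&gpu)?.sqr()?;

    // The ToDevice backward rule moves the incoming gradient back to
    // the CPU before accumulating, so x's gradient lands on x's device.
    let grads = y.backward()?;
    let dx = grads.get(&x).expect("no gradient for x"); // d(x^2)/dx = 2x
    println!("dx on {:?}: {:?}", dx.device(), dx.to_vec1::<f32>()?);
    Ok(())
}
```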