diff --git a/src/op.rs b/src/op.rs
index f330844e..a19e257f 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -17,6 +17,7 @@ pub(crate) enum Op {
     Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
+    Transpose(Tensor),
     // TODO: Support for custom ops.
 }
 
diff --git a/src/tensor.rs b/src/tensor.rs
index 549dafbe..7a2fbb98 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -461,6 +461,8 @@ impl Tensor {
         self.id
     }
 
+    /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
+    /// input are swapped.
     pub fn t(&self) -> Result<Self> {
         let mut stride = self.stride().to_vec();
         let n = stride.len();
@@ -474,13 +476,17 @@
         let mut dims = self.shape().dims().to_vec();
         (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
         (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+        let op = if self.track_op() {
+            Some(Op::Transpose(self.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.try_clone()?,
             shape: Shape::from(dims),
             stride,
-            // TODO The op should have a backward
-            op: None,
+            op,
             is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
@@ -576,7 +582,7 @@
                         nodes
                     }
                 }
-                Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
+                Op::Transpose(node) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
                     nodes
@@ -672,6 +678,11 @@
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
+                Op::Transpose(arg) => {
+                    let arg_grad = grad.t()?;
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&arg_grad)?
+                }
             };
         }
     }
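
Not part of the patch, just a reviewer-side sanity check: the new Op::Transpose backward arm propagates the incoming gradient by transposing it once more (grad.t()?). The standalone sketch below verifies that rule numerically with plain Vec<f64> matrices and a finite-difference comparison; it does not touch this crate's API, and the shapes, loss, and helper names are made up for illustration.

// Standalone sketch: checks that the backward rule used above -- the gradient
// of a transpose is the transpose of the incoming gradient -- is correct for a
// small 2x3 matrix. Matrices are stored row-major in plain Vec<f64>.

fn transpose(m: &[f64], rows: usize, cols: usize) -> Vec<f64> {
    let mut out = vec![0.0; rows * cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = m[r * cols + c];
        }
    }
    out
}

// Scalar loss: elementwise product of transpose(x) with a fixed weight matrix w, summed.
fn loss(x: &[f64], w: &[f64], rows: usize, cols: usize) -> f64 {
    let xt = transpose(x, rows, cols);
    xt.iter().zip(w.iter()).map(|(a, b)| a * b).sum()
}

fn main() {
    let (rows, cols) = (2, 3);
    let x: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3 input
    let w: Vec<f64> = vec![0.5, -1.0, 2.0, 1.5, 0.0, -0.5]; // 3x2, same shape as x.t()

    // Analytic gradient per the rule in the diff: dL/d(x.t()) = w, and the
    // backward of a transpose is another transpose, so dL/dx = w transposed.
    let analytic = transpose(&w, cols, rows);

    // Finite-difference gradient as a cross-check.
    let eps = 1e-6;
    for i in 0..x.len() {
        let mut xp = x.clone();
        xp[i] += eps;
        let numeric = (loss(&xp, &w, rows, cols) - loss(&x, &w, rows, cols)) / eps;
        assert!((numeric - analytic[i]).abs() < 1e-4);
    }
    println!("transpose backward rule checks out: dL/dx == transpose of dL/d(x.t())");
}

The check passes because a transpose only permutes elements, so its Jacobian is that same permutation; applying it to the upstream gradient is exactly what the Op::Transpose arm does with grad.t()?.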