diff --git a/src/op.rs b/src/op.rs
index a19e257f..34d7ae76 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -17,7 +17,7 @@ pub(crate) enum Op {
     Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
-    Transpose(Tensor),
+    Transpose(Tensor, usize, usize),
     // TODO: Support for custom ops.
 }
 
diff --git a/src/tensor.rs b/src/tensor.rs
index 7a2fbb98..2cef4cd5 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -464,20 +464,40 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
     /// input are swapped.
     pub fn t(&self) -> Result {
-        let mut stride = self.stride().to_vec();
-        let n = stride.len();
-        if n < 2 {
+        let rank = self.rank();
+        if rank < 2 {
             return Err(Error::UnexpectedNumberOfDims {
                 expected: 2,
-                got: n,
+                got: rank,
                 shape: self.shape().clone(),
             });
         }
+        self.transpose(rank - 2, rank - 1)
+    }
+
+    /// Returns a tensor that is a transposed version of the input, the given dimensions are
+    /// swapped.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Error::UnexpectedNumberOfDims` when `dim1` or `dim2` is not a valid dimension
+    /// index for this tensor, i.e. when the rank is not strictly greater than both of them.
+    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result {
+        let rank = self.rank();
+        if rank <= dim1 || rank <= dim2 {
+            return Err(Error::UnexpectedNumberOfDims {
+                // Indexing dimension `d` requires a rank of at least `d + 1`.
+                expected: usize::max(dim1, dim2).saturating_add(1),
+                got: rank,
+                shape: self.shape().clone(),
+            });
+        }
+        let mut stride = self.stride().to_vec();
         let mut dims = self.shape().dims().to_vec();
-        (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
-        (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+        dims.swap(dim1, dim2);
+        stride.swap(dim1, dim2);
         let op = if self.track_op() {
-            Some(Op::Transpose(self.clone()))
+            Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
             None
         };
@@ -582,7 +602,7 @@ impl Tensor {
                         nodes
                     }
                 }
-                Op::Transpose(node) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
+                Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
                     nodes
@@ -678,8 +698,8 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
-                Op::Transpose(arg) => {
-                    let arg_grad = grad.t()?;
+                Op::Transpose(arg, dim1, dim2) => {
+                    let arg_grad = grad.transpose(*dim1, *dim2)?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }