Add transposition around arbitrary axis.

This commit is contained in:
laurent
2023-06-23 08:51:13 +01:00
parent 27d428af1a
commit 3f79d81b6f
2 changed files with 25 additions and 11 deletions

View File

@ -17,7 +17,7 @@ pub(crate) enum Op {
Neg(Tensor),
Sqr(Tensor),
Sqrt(Tensor),
Transpose(Tensor),
Transpose(Tensor, usize, usize),
// TODO: Support for custom ops.
}

View File

@ -464,20 +464,34 @@ impl Tensor {
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped.
pub fn t(&self) -> Result<Tensor> {
let mut stride = self.stride().to_vec();
let n = stride.len();
if n < 2 {
let rank = self.rank();
if rank < 2 {
return Err(Error::UnexpectedNumberOfDims {
expected: 2,
got: n,
got: rank,
shape: self.shape().clone(),
});
}
self.transpose(rank - 2, rank - 1)
}
/// Returns a tensor that is a transposed version of the input, the given dimensions are
/// swapped.
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
let rank = self.rank();
if rank <= dim1 || rank <= dim2 {
return Err(Error::UnexpectedNumberOfDims {
expected: usize::max(dim1, dim2),
got: rank,
shape: self.shape().clone(),
});
}
let mut stride = self.stride().to_vec();
let mut dims = self.shape().dims().to_vec();
(dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
dims.swap(dim1, dim2);
stride.swap(dim1, dim2);
let op = if self.track_op() {
Some(Op::Transpose(self.clone()))
Some(Op::Transpose(self.clone(), dim1, dim2))
} else {
None
};
@ -582,7 +596,7 @@ impl Tensor {
nodes
}
}
Op::Transpose(node) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@ -678,8 +692,8 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Transpose(arg) => {
let arg_grad = grad.t()?;
Op::Transpose(arg, dim1, dim2) => {
let arg_grad = grad.transpose(*dim1, *dim2)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}