Add the backward pass for transpose.

This commit is contained in:
laurent
2023-06-23 08:43:05 +01:00
parent 3b550a56dc
commit 27d428af1a
2 changed files with 15 additions and 3 deletions

View File

@ -17,6 +17,7 @@ pub(crate) enum Op {
Neg(Tensor),
Sqr(Tensor),
Sqrt(Tensor),
Transpose(Tensor),
// TODO: Support for custom ops.
}

View File

@ -461,6 +461,8 @@ impl Tensor {
self.id
}
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped.
pub fn t(&self) -> Result<Tensor> {
let mut stride = self.stride().to_vec();
let n = stride.len();
@ -474,13 +476,17 @@ impl Tensor {
let mut dims = self.shape().dims().to_vec();
(dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
(stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
let op = if self.track_op() {
Some(Op::Transpose(self.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.try_clone()?,
shape: Shape::from(dims),
stride,
// TODO The op should have a backward
op: None,
op,
is_variable: false,
};
Ok(Tensor(Arc::new(tensor_)))
@ -576,7 +582,7 @@ impl Tensor {
nodes
}
}
Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
Op::Transpose(node) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
@ -672,6 +678,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Transpose(arg) => {
let arg_grad = grad.t()?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
};
}
}