Add the backward pass for transpose.

laurent
2023-06-23 08:43:05 +01:00
parent 3b550a56dc
commit 27d428af1a
2 changed files with 15 additions and 3 deletions
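
Why the backward of a transpose is just another transpose: transposition only reindexes tensor entries, so the chain rule routes the upstream gradient through the same transposition, which is why the new backward arm below can simply apply `grad.t()?`. A short derivation for a scalar loss L (using that ∂Y_kl/∂X_ij is 1 exactly when k = j and l = i):

\[
  Y = X^{\top}
  \quad\Longrightarrow\quad
  \frac{\partial L}{\partial X_{ij}}
    = \sum_{k,l} \frac{\partial L}{\partial Y_{kl}}\,\frac{\partial Y_{kl}}{\partial X_{ij}}
    = \frac{\partial L}{\partial Y_{ji}},
  \qquad\text{i.e.}\qquad
  \nabla_X L = \bigl(\nabla_Y L\bigr)^{\top}.
\]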


@@ -17,6 +17,7 @@ pub(crate) enum Op {
     Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
+    Transpose(Tensor),
     // TODO: Support for custom ops.
 }


@@ -461,6 +461,8 @@ impl Tensor {
         self.id
     }
 
+    /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
+    /// input are swapped.
     pub fn t(&self) -> Result<Tensor> {
         let mut stride = self.stride().to_vec();
         let n = stride.len();
@@ -474,13 +476,17 @@ impl Tensor {
         let mut dims = self.shape().dims().to_vec();
         (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
         (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+        let op = if self.track_op() {
+            Some(Op::Transpose(self.clone()))
+        } else {
+            None
+        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.try_clone()?,
             shape: Shape::from(dims),
             stride,
-            // TODO The op should have a backward op,
-            op: None,
+            op,
             is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
@@ -576,7 +582,7 @@ impl Tensor {
                         nodes
                     }
                 }
-                Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
+                Op::Transpose(node) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
                     nodes
@@ -672,6 +678,11 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
+                    Op::Transpose(arg) => {
+                        let arg_grad = grad.t()?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                 };
             }
         }
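
A side note on what `t()` itself does: as the new doc comment says, it swaps the last two dimensions, and it does so without copying by also swapping the corresponding strides while reusing the same storage. The standalone sketch below (plain Rust with ad hoc names, not candle's types) illustrates why a stride swap is equivalent to a materialized transpose:

// Standalone sketch, not part of the commit: a 2x3 row-major matrix is
// "transposed" by swapping dims and strides, leaving the storage untouched.
fn index(strides: &[usize], coords: &[usize]) -> usize {
    strides.iter().zip(coords).map(|(s, c)| s * c).sum()
}

fn main() {
    // A 2x3 row-major matrix stored contiguously: [[1, 2, 3], [4, 5, 6]].
    let storage = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let (dims, strides) = (vec![2usize, 3], vec![3usize, 1]);

    // Transpose as a view: swap the last two dims and strides.
    let t_dims = vec![dims[1], dims[0]]; // 3x2
    let t_strides = vec![strides[1], strides[0]]; // [1, 3]

    // Element (i, j) of the view equals element (j, i) of the original.
    for i in 0..t_dims[0] {
        for j in 0..t_dims[1] {
            assert_eq!(
                storage[index(&t_strides, &[i, j])],
                storage[index(&strides, &[j, i])]
            );
        }
    }
    println!("stride-swapped view matches the materialized transpose");
}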