Mirror of https://github.com/huggingface/candle.git
Add transposition around arbitrary axis.
@@ -17,7 +17,7 @@ pub(crate) enum Op {
     Neg(Tensor),
     Sqr(Tensor),
     Sqrt(Tensor),
-    Transpose(Tensor),
+    Transpose(Tensor, usize, usize),
     // TODO: Support for custom ops.
 }
 
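The extra `usize` fields on `Op::Transpose` record which pair of axes was swapped, because the backward pass has to re-apply that same swap to the incoming gradient. A standalone sketch of the idea (the `Op` enum and `backprop_dims` helper below are illustrative only, not candle's actual types):

// Illustrative mini-op: the two usize fields remember which axes were swapped
// so the backward pass can apply the same swap to the gradient's layout.
enum Op {
    Transpose(usize, usize),
}

// Hypothetical helper: map the gradient's dims back through the op.
fn backprop_dims(op: &Op, mut grad_dims: Vec<usize>) -> Vec<usize> {
    match op {
        Op::Transpose(d1, d2) => grad_dims.swap(*d1, *d2),
    }
    grad_dims
}

fn main() {
    // The forward pass swapped axes 0 and 2, turning (2, 3, 4) into (4, 3, 2);
    // re-applying the same swap restores the original ordering.
    assert_eq!(backprop_dims(&Op::Transpose(0, 2), vec![4, 3, 2]), vec![2, 3, 4]);
}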
@@ -464,20 +464,34 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the two last dimensions of the
     /// input are swapped.
     pub fn t(&self) -> Result<Tensor> {
-        let mut stride = self.stride().to_vec();
-        let n = stride.len();
-        if n < 2 {
+        let rank = self.rank();
+        if rank < 2 {
             return Err(Error::UnexpectedNumberOfDims {
                 expected: 2,
-                got: n,
+                got: rank,
                 shape: self.shape().clone(),
             });
         }
+        self.transpose(rank - 2, rank - 1)
+    }
+
+    /// Returns a tensor that is a transposed version of the input, the given dimensions are
+    /// swapped.
+    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
+        let rank = self.rank();
+        if rank <= dim1 || rank <= dim2 {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: usize::max(dim1, dim2),
+                got: rank,
+                shape: self.shape().clone(),
+            });
+        }
+        let mut stride = self.stride().to_vec();
         let mut dims = self.shape().dims().to_vec();
-        (dims[n - 2], dims[n - 1]) = (dims[n - 1], dims[n - 2]);
-        (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]);
+        dims.swap(dim1, dim2);
+        stride.swap(dim1, dim2);
         let op = if self.track_op() {
-            Some(Op::Transpose(self.clone()))
+            Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
             None
         };
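With this change, `t()` becomes a thin wrapper that swaps the two last dimensions, while `transpose(dim1, dim2)` swaps an arbitrary pair and validates the indices. A minimal usage sketch follows; it assumes the crate is imported as `candle` and that `Tensor::zeros(shape, dtype, device)`, `DType::F32` and `Device::Cpu` are available as in later versions of the crate (these constructors are assumptions and may differ at this commit):

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // t() swaps the two last dimensions: (2, 3, 4) -> (2, 4, 3).
    assert_eq!(t.t()?.shape().dims(), &[2, 4, 3]);
    // transpose(dim1, dim2) swaps any pair of dimensions: (2, 3, 4) -> (4, 3, 2).
    assert_eq!(t.transpose(0, 2)?.shape().dims(), &[4, 3, 2]);
    // Out-of-range dimensions are rejected with Error::UnexpectedNumberOfDims.
    assert!(t.transpose(0, 5).is_err());
    Ok(())
}

Since only the `dims` and `stride` entries are swapped, the result appears to be a stride-swapped view over the same storage rather than a copy of the data.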
@@ -582,7 +596,7 @@ impl Tensor {
                             nodes
                         }
                     }
-                    Op::Transpose(node) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
+                    Op::Transpose(node, _, _) | Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
                         let (tg, nodes) = walk(node, nodes, already_seen);
                         track_grad |= tg;
                         nodes
@@ -678,8 +692,8 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
-                    Op::Transpose(arg) => {
-                        let arg_grad = grad.t()?;
+                    Op::Transpose(arg, dim1, dim2) => {
+                        let arg_grad = grad.transpose(*dim1, *dim2)?;
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
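In the backward pass, the gradient of a transpose is the upstream gradient transposed with the same pair of axes, since swapping two axes is its own inverse. A standalone 2-D illustration of that rule, using plain Vec-of-Vec matrices rather than candle types:

// For y = transpose(x), dL/dx[i][j] == dL/dy[j][i]: the upstream gradient is
// carried back through the same axis swap the forward pass applied.
fn transpose2(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let (rows, cols) = (m.len(), m[0].len());
    (0..cols).map(|j| (0..rows).map(|i| m[i][j]).collect()).collect()
}

fn main() {
    let dl_dy = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; // shape (2, 3)
    let dl_dx = transpose2(&dl_dy); // shape (3, 2), matching x
    assert_eq!(dl_dx[2][1], dl_dy[1][2]);
}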