mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch). * Add the backprop for dimension permutation.
This commit is contained in:
@ -96,6 +96,7 @@ impl Tensor {
|
||||
| Op::ToDType(node)
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Permute(node, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _)
|
||||
@ -403,6 +404,15 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Permute(arg, dims) => {
|
||||
let mut inv_dims = vec![0; dims.len()];
|
||||
for (i, &dim_idx) in dims.iter().enumerate() {
|
||||
inv_dims[dim_idx] = i
|
||||
}
|
||||
let arg_grad = grad.permute(inv_dims)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user