Add the permute op (similar to pytorch). (#504)

* Add the permute op (similar to pytorch).

* Add the backprop for dimension permutation.
This commit is contained in:
Laurent Mazare
2023-08-18 16:30:53 +01:00
committed by GitHub
parent 4f1541526c
commit cb069d6063
7 changed files with 85 additions and 4 deletions

View File

@ -345,6 +345,16 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
let d3 = self.3.to_index(shape, op)?;
Ok(vec![d0, d1, d2, d3])
}
}
extract_dims!(dims0, 0, |_: &[usize]| (), ());
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));