Macroify the repeating bits. (#129)

This commit is contained in:
Laurent Mazare
2023-07-10 19:44:06 +01:00
committed by GitHub
parent 23849cb6e6
commit 2be09dbb1d

View File

@ -129,114 +129,23 @@ where
}
}
impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
fn i(&self, index: (A,)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
self.index(&[idx_a])
}
}
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
fn i(&self, index: (A, B)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
self.index(&[idx_a, idx_b])
}
}
impl<A, B, C> IndexOp<(A, B, C)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
self.index(&[idx_a, idx_b, idx_c])
}
}
impl<A, B, C, D> IndexOp<(A, B, C, D)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
self.index(&[idx_a, idx_b, idx_c, idx_d])
}
}
impl<A, B, C, D, E> IndexOp<(A, B, C, D, E)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
E: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e])
}
}
impl<A, B, C, D, E, F> IndexOp<(A, B, C, D, E, F)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
E: Into<TensorIndexer>,
F: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E, F)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
let idx_f = index.5.into();
self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f])
}
}
impl<A, B, C, D, E, F, G> IndexOp<(A, B, C, D, E, F, G)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
C: Into<TensorIndexer>,
D: Into<TensorIndexer>,
E: Into<TensorIndexer>,
F: Into<TensorIndexer>,
G: Into<TensorIndexer>,
{
fn i(&self, index: (A, B, C, D, E, F, G)) -> Result<Tensor, Error> {
let idx_a = index.0.into();
let idx_b = index.1.into();
let idx_c = index.2.into();
let idx_d = index.3.into();
let idx_e = index.4.into();
let idx_f = index.5.into();
let idx_g = index.6.into();
self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g])
}
macro_rules! index_op_tuple {
($($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);