diff --git a/candle-core/src/indexer.rs b/candle-core/src/indexer.rs index 22719527..725ba732 100644 --- a/candle-core/src/indexer.rs +++ b/candle-core/src/indexer.rs @@ -129,114 +129,23 @@ where } } -impl IndexOp<(A,)> for Tensor -where - A: Into, -{ - fn i(&self, index: (A,)) -> Result { - let idx_a = index.0.into(); - self.index(&[idx_a]) - } -} - -impl IndexOp<(A, B)> for Tensor -where - A: Into, - B: Into, -{ - fn i(&self, index: (A, B)) -> Result { - let idx_a = index.0.into(); - let idx_b = index.1.into(); - self.index(&[idx_a, idx_b]) - } -} - -impl IndexOp<(A, B, C)> for Tensor -where - A: Into, - B: Into, - C: Into, -{ - fn i(&self, index: (A, B, C)) -> Result { - let idx_a = index.0.into(); - let idx_b = index.1.into(); - let idx_c = index.2.into(); - self.index(&[idx_a, idx_b, idx_c]) - } -} - -impl IndexOp<(A, B, C, D)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, -{ - fn i(&self, index: (A, B, C, D)) -> Result { - let idx_a = index.0.into(); - let idx_b = index.1.into(); - let idx_c = index.2.into(); - let idx_d = index.3.into(); - self.index(&[idx_a, idx_b, idx_c, idx_d]) - } -} - -impl IndexOp<(A, B, C, D, E)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, -{ - fn i(&self, index: (A, B, C, D, E)) -> Result { - let idx_a = index.0.into(); - let idx_b = index.1.into(); - let idx_c = index.2.into(); - let idx_d = index.3.into(); - let idx_e = index.4.into(); - self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e]) - } -} - -impl IndexOp<(A, B, C, D, E, F)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, - F: Into, -{ - fn i(&self, index: (A, B, C, D, E, F)) -> Result { - let idx_a = index.0.into(); - let idx_b = index.1.into(); - let idx_c = index.2.into(); - let idx_d = index.3.into(); - let idx_e = index.4.into(); - let idx_f = index.5.into(); - self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f]) - } -} - -impl IndexOp<(A, B, C, D, E, F, G)> for Tensor -where - A: Into, - B: Into, - C: Into, - D: Into, - E: Into, - F: Into, - G: Into, -{ - fn i(&self, index: (A, B, C, D, E, F, G)) -> Result { - let idx_a = index.0.into(); - let idx_b = index.1.into(); - let idx_c = index.2.into(); - let idx_d = index.3.into(); - let idx_e = index.4.into(); - let idx_f = index.5.into(); - let idx_g = index.6.into(); - self.index(&[idx_a, idx_b, idx_c, idx_d, idx_e, idx_f, idx_g]) - } +macro_rules! index_op_tuple { + ($($t:ident),+) => { + #[allow(non_snake_case)] + impl<$($t),*> IndexOp<($($t,)*)> for Tensor + where + $($t: Into,)* + { + fn i(&self, ($($t,)*): ($($t,)*)) -> Result { + self.index(&[$($t.into(),)*]) + } + } + }; } +index_op_tuple!(A); +index_op_tuple!(A, B); +index_op_tuple!(A, B, C); +index_op_tuple!(A, B, C, D); +index_op_tuple!(A, B, C, D, E); +index_op_tuple!(A, B, C, D, E, F); +index_op_tuple!(A, B, C, D, E, F, G);