Simplify the parameters used by sum and sum_keepdim. (#165)

This commit is contained in:
Laurent Mazare
2023-07-14 08:22:08 +01:00
committed by GitHub
parent 2bfa791336
commit a2f72edc0d
13 changed files with 179 additions and 98 deletions

View File

@ -256,6 +256,86 @@ impl Dim for D {
}
}
pub trait Dims: Sized {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let dims = self.to_indexes_internal(shape, op)?;
for (i, &dim) in dims.iter().enumerate() {
if dims[..i].contains(&dim) {
Err(Error::DuplicateDimIndex {
shape: shape.clone(),
dims: dims.clone(),
op,
})?
}
if dim >= shape.rank() {
Err(Error::DimOutOfRange {
shape: shape.clone(),
dim: dim as i32,
op,
})?
}
}
Ok(dims)
}
}
impl Dims for Vec<usize> {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(self)
}
}
impl<const N: usize> Dims for [usize; N] {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(self.to_vec())
}
}
impl Dims for &[usize] {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(self.to_vec())
}
}
impl Dims for () {
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
Ok(vec![])
}
}
impl<D: Dim + Sized> Dims for D {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let dim = self.to_index(shape, op)?;
Ok(vec![dim])
}
}
impl<D: Dim> Dims for (D,) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let dim = self.0.to_index(shape, op)?;
Ok(vec![dim])
}
}
impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
Ok(vec![d0, d1])
}
}
impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
let d0 = self.0.to_index(shape, op)?;
let d1 = self.1.to_index(shape, op)?;
let d2 = self.2.to_index(shape, op)?;
Ok(vec![d0, d1, d2])
}
}
#[cfg(test)]
mod tests {
use super::*;