mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Simplify the parameters used by sum and sum_keepdim. (#165)
This commit is contained in:
@ -256,6 +256,86 @@ impl Dim for D {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Dims: Sized {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>>;
|
||||
|
||||
fn to_indexes(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let dims = self.to_indexes_internal(shape, op)?;
|
||||
for (i, &dim) in dims.iter().enumerate() {
|
||||
if dims[..i].contains(&dim) {
|
||||
Err(Error::DuplicateDimIndex {
|
||||
shape: shape.clone(),
|
||||
dims: dims.clone(),
|
||||
op,
|
||||
})?
|
||||
}
|
||||
if dim >= shape.rank() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: shape.clone(),
|
||||
dim: dim as i32,
|
||||
op,
|
||||
})?
|
||||
}
|
||||
}
|
||||
Ok(dims)
|
||||
}
|
||||
}
|
||||
|
||||
impl Dims for Vec<usize> {
|
||||
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize> Dims for [usize; N] {
|
||||
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
|
||||
Ok(self.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl Dims for &[usize] {
|
||||
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
|
||||
Ok(self.to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
impl Dims for () {
|
||||
fn to_indexes_internal(self, _: &Shape, _: &'static str) -> Result<Vec<usize>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Dim + Sized> Dims for D {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let dim = self.to_index(shape, op)?;
|
||||
Ok(vec![dim])
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: Dim> Dims for (D,) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let dim = self.0.to_index(shape, op)?;
|
||||
Ok(vec![dim])
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Dim, D2: Dim> Dims for (D1, D2) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let d0 = self.0.to_index(shape, op)?;
|
||||
let d1 = self.1.to_index(shape, op)?;
|
||||
Ok(vec![d0, d1])
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let d0 = self.0.to_index(shape, op)?;
|
||||
let d1 = self.1.to_index(shape, op)?;
|
||||
let d2 = self.2.to_index(shape, op)?;
|
||||
Ok(vec![d0, d1, d2])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
Reference in New Issue
Block a user