From b4aab7b95f1c1e420054d0eaad422b2d71c27755 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 11:37:42 +0100 Subject: [PATCH] Put more requirements on the withdtype trait. --- candle-core/src/cpu_backend.rs | 34 ++++++---------------------------- candle-core/src/dtype.rs | 2 +- 2 files changed, 7 insertions(+), 29 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index a47d7c18..1425d92f 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -15,11 +15,7 @@ pub enum CpuStorage { } trait Map1 { - fn f( - &self, - vs: &[T], - layout: &Layout, - ) -> Result>; + fn f(&self, vs: &[T], layout: &Layout) -> Result>; fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { match vs { @@ -35,13 +31,7 @@ trait Map1 { type C = CpuStorage; trait Map2 { const OP: &'static str; - fn f( - &self, - v1: &[T], - l1: &Layout, - v2: &[T], - l2: &Layout, - ) -> Result>; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; fn map( &self, @@ -101,11 +91,7 @@ struct Sum<'a> { } impl<'a> Map1 for Sum<'a> { - fn f( - &self, - src: &[T], - src_layout: &Layout, - ) -> Result> { + fn f(&self, src: &[T], src_layout: &Layout) -> Result> { let mut dst = vec![T::zero(); self.dst_shape.elem_count()]; for (unstr_index, src_index) in src_layout.strided_index().enumerate() { let mut dst_index = unstr_index; @@ -153,11 +139,7 @@ fn binary_map T>( struct Affine(f64, f64); impl Map1 for Affine { - fn f( - &self, - vs: &[T], - layout: &Layout, - ) -> Result> { + fn f(&self, vs: &[T], layout: &Layout) -> Result> { let mul = T::from_f64(self.0); let add = T::from_f64(self.1); Ok(unary_map(vs, layout, |v| v * mul + add)) @@ -292,11 +274,7 @@ impl Map2 for MatMul { } } -fn divide_by_sum_over_dim( - s: &mut [T], - shape: &Shape, - dim: usize, -) -> Result<()> { +fn divide_by_sum_over_dim(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> { // [self] stores data in a contiguous way starting at offset 0. let dims = shape.dims(); let elem_per_slice = dims[dim]; @@ -332,7 +310,7 @@ impl CpuStorage { } } - pub fn as_slice(&self) -> Result<&[D]> { + pub fn as_slice(&self) -> Result<&[D]> { D::cpu_storage_as_slice(self) } diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 89655324..9a51635d 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -31,7 +31,7 @@ impl DType { } } -pub trait WithDType: Sized + Copy { +pub trait WithDType: Sized + Copy + num_traits::NumAssign + 'static { const DTYPE: DType; fn from_f64(v: f64) -> Self;