Merge pull request #34 from LaurentMazare/simpler-dtype-trait

Put more requirements on the withdtype trait.
This commit is contained in:
Laurent Mazare
2023-06-29 11:41:17 +01:00
committed by GitHub
2 changed files with 7 additions and 29 deletions

View File

@ -15,11 +15,7 @@ pub enum CpuStorage {
} }
trait Map1 { trait Map1 {
fn f<T: WithDType + Copy + num_traits::NumAssign>( fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
&self,
vs: &[T],
layout: &Layout,
) -> Result<Vec<T>>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> { fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs { match vs {
@ -35,13 +31,7 @@ trait Map1 {
type C = CpuStorage; type C = CpuStorage;
trait Map2 { trait Map2 {
const OP: &'static str; const OP: &'static str;
fn f<T: WithDType + Copy + num_traits::Num + 'static>( fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
&self,
v1: &[T],
l1: &Layout,
v2: &[T],
l2: &Layout,
) -> Result<Vec<T>>;
fn map( fn map(
&self, &self,
@ -101,11 +91,7 @@ struct Sum<'a> {
} }
impl<'a> Map1 for Sum<'a> { impl<'a> Map1 for Sum<'a> {
fn f<T: WithDType + Copy + num_traits::NumAssign>( fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
&self,
src: &[T],
src_layout: &Layout,
) -> Result<Vec<T>> {
let mut dst = vec![T::zero(); self.dst_shape.elem_count()]; let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
for (unstr_index, src_index) in src_layout.strided_index().enumerate() { for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
let mut dst_index = unstr_index; let mut dst_index = unstr_index;
@ -153,11 +139,7 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
struct Affine(f64, f64); struct Affine(f64, f64);
impl Map1 for Affine { impl Map1 for Affine {
fn f<T: WithDType + Copy + num_traits::NumAssign>( fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
&self,
vs: &[T],
layout: &Layout,
) -> Result<Vec<T>> {
let mul = T::from_f64(self.0); let mul = T::from_f64(self.0);
let add = T::from_f64(self.1); let add = T::from_f64(self.1);
Ok(unary_map(vs, layout, |v| v * mul + add)) Ok(unary_map(vs, layout, |v| v * mul + add))
@ -292,11 +274,7 @@ impl Map2 for MatMul {
} }
} }
fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>( fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
s: &mut [T],
shape: &Shape,
dim: usize,
) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0. // [self] stores data in a contiguous way starting at offset 0.
let dims = shape.dims(); let dims = shape.dims();
let elem_per_slice = dims[dim]; let elem_per_slice = dims[dim];
@ -332,7 +310,7 @@ impl CpuStorage {
} }
} }
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> { pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self) D::cpu_storage_as_slice(self)
} }

View File

@ -31,7 +31,7 @@ impl DType {
} }
} }
pub trait WithDType: Sized + Copy { pub trait WithDType: Sized + Copy + num_traits::NumAssign + 'static {
const DTYPE: DType; const DTYPE: DType;
fn from_f64(v: f64) -> Self; fn from_f64(v: f64) -> Self;