Put more requirements on the withdtype trait.

This commit is contained in:
laurent
2023-06-29 11:37:42 +01:00
parent c8fc9da737
commit b4aab7b95f
2 changed files with 7 additions and 29 deletions

View File

@ -15,11 +15,7 @@ pub enum CpuStorage {
}
trait Map1 {
fn f<T: WithDType + Copy + num_traits::NumAssign>(
&self,
vs: &[T],
layout: &Layout,
) -> Result<Vec<T>>;
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
match vs {
@ -35,13 +31,7 @@ trait Map1 {
type C = CpuStorage;
trait Map2 {
const OP: &'static str;
fn f<T: WithDType + Copy + num_traits::Num + 'static>(
&self,
v1: &[T],
l1: &Layout,
v2: &[T],
l2: &Layout,
) -> Result<Vec<T>>;
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
fn map(
&self,
@ -101,11 +91,7 @@ struct Sum<'a> {
}
impl<'a> Map1 for Sum<'a> {
fn f<T: WithDType + Copy + num_traits::NumAssign>(
&self,
src: &[T],
src_layout: &Layout,
) -> Result<Vec<T>> {
fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
let mut dst_index = unstr_index;
@ -153,11 +139,7 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
struct Affine(f64, f64);
impl Map1 for Affine {
fn f<T: WithDType + Copy + num_traits::NumAssign>(
&self,
vs: &[T],
layout: &Layout,
) -> Result<Vec<T>> {
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
let mul = T::from_f64(self.0);
let add = T::from_f64(self.1);
Ok(unary_map(vs, layout, |v| v * mul + add))
@ -292,11 +274,7 @@ impl Map2 for MatMul {
}
}
fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>(
s: &mut [T],
shape: &Shape,
dim: usize,
) -> Result<()> {
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
// [self] stores data in a contiguous way starting at offset 0.
let dims = shape.dims();
let elem_per_slice = dims[dim];
@ -332,7 +310,7 @@ impl CpuStorage {
}
}
pub fn as_slice<D: crate::WithDType>(&self) -> Result<&[D]> {
pub fn as_slice<D: WithDType>(&self) -> Result<&[D]> {
D::cpu_storage_as_slice(self)
}