diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 40dd292e..494c41ee 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -13,11 +13,18 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
-fn unary_map<T: Copy, F: Fn(usize) -> T>(shape: &Shape, stride: &[usize], f: F) -> Vec<T> {
+fn unary_map<T: Copy, F: FnMut(T) -> T>(
+    shape: &Shape,
+    stride: &[usize],
+    vs: &[T],
+    mut f: F,
+) -> Vec<T> {
     if shape.is_contiguous(stride) {
-        (0..shape.elem_count()).map(f).collect()
+        vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
     } else {
-        StridedIndex::new(shape.dims(), stride).map(f).collect()
+        StridedIndex::new(shape.dims(), stride)
+            .map(|i| f(vs[i]))
+            .collect()
     }
 }
 
@@ -109,37 +116,35 @@ impl CpuStorage {
             Self::U32(storage) => {
                 let mul = mul as u32;
                 let add = add as u32;
-                let data = unary_map(shape, stride, |i| storage[i] * mul + add);
+                let data = unary_map(shape, stride, storage, |v| v * mul + add);
                 Ok(Self::U32(data))
             }
             Self::F32(storage) => {
                 let mul = mul as f32;
                 let add = add as f32;
-                let data = unary_map(shape, stride, |i| storage[i] * mul + add);
+                let data = unary_map(shape, stride, storage, |v| v * mul + add);
                 Ok(Self::F32(data))
             }
             Self::F64(storage) => {
-                let data = unary_map(shape, stride, |i| storage[i] * mul + add);
+                let data = unary_map(shape, stride, storage, |v| v * mul + add);
                 Ok(Self::F64(data))
             }
         }
     }
 
     pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
-        // TODO: Different code path for the contiguous case?
         match self {
             Self::F32(storage) => {
-                let index = StridedIndex::new(shape.dims(), stride);
-                let data = index.map(|i| B::f32(storage[i])).collect();
+                let data = unary_map(shape, stride, storage, B::f32);
                 Ok(Self::F32(data))
             }
             Self::F64(storage) => {
-                let index = StridedIndex::new(shape.dims(), stride);
-                let data = index.map(|i| B::f64(storage[i])).collect();
+                let data = unary_map(shape, stride, storage, B::f64);
                 Ok(Self::F64(data))
            }
-            Self::U32(_storage) => {
-                todo!("No unary for u32 because of neg, sqrt")
+            Self::U32(storage) => {
+                let data = unary_map(shape, stride, storage, B::u32);
+                Ok(Self::U32(data))
             }
         }
     }
diff --git a/src/op.rs b/src/op.rs
index 0f0f5ee4..b8198ca8 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -39,6 +39,7 @@ pub(crate) trait UnaryOp {
     const KERNEL_F64: &'static str;
     fn f32(v1: f32) -> f32;
     fn f64(v1: f64) -> f64;
+    fn u32(v1: u32) -> u32;
 }
 
 pub(crate) trait BinaryOp {
@@ -134,6 +135,9 @@ impl UnaryOp for Neg {
     fn f64(v1: f64) -> f64 {
         -v1
     }
+    fn u32(_: u32) -> u32 {
+        0
+    }
     const KERNEL_F32: &'static str = "uneg_f32";
     const KERNEL_F64: &'static str = "uneg_f64";
 }
@@ -146,6 +150,9 @@
     fn f64(v1: f64) -> f64 {
         v1 * v1
     }
+    fn u32(v: u32) -> u32 {
+        v * v
+    }
     const KERNEL_F32: &'static str = "usqr_f32";
     const KERNEL_F64: &'static str = "usqr_f64";
 }
@@ -158,6 +165,9 @@
     fn f64(v1: f64) -> f64 {
         v1.sqrt()
     }
+    fn u32(v: u32) -> u32 {
+        (v as f64).sqrt() as u32
+    }
     const KERNEL_F32: &'static str = "usqrt_f32";
     const KERNEL_F64: &'static str = "usqrt_f64";
 }
@@ -184,6 +194,9 @@
     fn f64(v1: f64) -> f64 {
         gelu_f64(v1)
     }
+    fn u32(v1: u32) -> u32 {
+        gelu_f64(v1 as f64) as u32
+    }
     const KERNEL_F32: &'static str = "gelu_f32";
     const KERNEL_F64: &'static str = "gelu_f64";
 }