diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 65632ef3..b9b665d7 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -113,6 +113,15 @@ impl CpuStorage { .collect(); Ok(Self::F64(data)) } + (Self::U32(lhs), Self::U32(rhs)) => { + let lhs_index = StridedIndex::new(shape.dims(), lhs_stride); + let rhs_index = StridedIndex::new(shape.dims(), rhs_stride); + let data = lhs_index + .zip(rhs_index) + .map(|(lhs_i, rhs_i)| B::u32(lhs[lhs_i], rhs[rhs_i])) + .collect(); + Ok(Self::U32(data)) + } _ => { // This should be covered by the dtype check above. Err(Error::DTypeMismatchBinaryOp { diff --git a/src/op.rs b/src/op.rs index ae8fa80d..40be9f4c 100644 --- a/src/op.rs +++ b/src/op.rs @@ -49,6 +49,7 @@ pub(crate) trait BinaryOp { const KERNEL_F64: &'static str; fn f32(v1: f32, v2: f32) -> f32; fn f64(v1: f64, v2: f64) -> f64; + fn u32(v1: u32, v2: u32) -> u32; } pub(crate) struct Add; @@ -70,6 +71,9 @@ impl BinaryOp for Add { fn f64(v1: f64, v2: f64) -> f64 { v1 + v2 } + fn u32(v1: u32, v2: u32) -> u32 { + v1 + v2 + } } impl BinaryOp for Sub { @@ -82,6 +86,9 @@ impl BinaryOp for Sub { fn f64(v1: f64, v2: f64) -> f64 { v1 - v2 } + fn u32(v1: u32, v2: u32) -> u32 { + v1 - v2 + } } impl BinaryOp for Mul { @@ -94,6 +101,9 @@ impl BinaryOp for Mul { fn f64(v1: f64, v2: f64) -> f64 { v1 * v2 } + fn u32(v1: u32, v2: u32) -> u32 { + v1 * v2 + } } impl BinaryOp for Div { @@ -106,6 +116,9 @@ impl BinaryOp for Div { fn f64(v1: f64, v2: f64) -> f64 { v1 / v2 } + fn u32(v1: u32, v2: u32) -> u32 { + v1 / v2 + } } impl UnaryOp for Neg {