diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index d85407f6..ca0352d8 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -13,6 +13,14 @@ pub enum CpuStorage { F64(Vec), } +fn unary_map T>(shape: &Shape, stride: &[usize], f: F) -> Vec { + if shape.is_contiguous(stride) { + (0..shape.elem_count()).map(f).collect() + } else { + StridedIndex::new(shape.dims(), stride).map(f).collect() + } +} + // This function maps over two strided index sequences. It supports broadcasting in case // `lhs_stride` or `rhs_stride` has a length shorter than `shape`. fn binary_map T>( @@ -22,7 +30,10 @@ fn binary_map T>( mut f: F, ) -> Vec { let dims = shape.dims(); - if dims.len() == lhs_stride.len() && dims.len() == rhs_stride.len() { + let broadcast_ldims = dims.len() - lhs_stride.len(); + let broadcast_rdims = dims.len() - rhs_stride.len(); + let elem_count = shape.elem_count(); + if broadcast_ldims == 0 && broadcast_rdims == 0 { if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) { (0..shape.elem_count()).map(|i| f((i, i))).collect() } else { @@ -30,8 +41,34 @@ fn binary_map T>( let rhs_index = StridedIndex::new(dims, rhs_stride); lhs_index.zip(rhs_index).map(f).collect() } + } else if broadcast_rdims == 0 { + let mut res = Vec::new(); + res.reserve(elem_count); + let lhs_index: Vec<_> = StridedIndex::new(dims, lhs_stride).collect(); + let mut i = 0; + for rhs_i in StridedIndex::new(dims, rhs_stride) { + res.push(f((lhs_index[i], rhs_i))); + i += 1; + if i >= lhs_index.len() { + i = 0 + } + } + res + } else if broadcast_ldims == 0 { + let mut res = Vec::new(); + res.reserve(elem_count); + let rhs_index: Vec<_> = StridedIndex::new(dims, rhs_stride).collect(); + let mut i = 0; + for lhs_i in StridedIndex::new(dims, lhs_stride) { + res.push(f((lhs_i, rhs_index[i]))); + i += 1; + if i >= rhs_index.len() { + i = 0 + } + } + res } else { - todo!("implement broadcast") + panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}") } } @@ -61,22 +98,19 @@ impl CpuStorage { ) -> Result { match self { Self::U32(storage) => { - let index = StridedIndex::new(shape.dims(), stride); let mul = mul as u32; let add = add as u32; - let data = index.map(|i| storage[i] * mul + add).collect(); + let data = unary_map(shape, stride, |i| storage[i] * mul + add); Ok(Self::U32(data)) } Self::F32(storage) => { - let index = StridedIndex::new(shape.dims(), stride); let mul = mul as f32; let add = add as f32; - let data = index.map(|i| storage[i] * mul + add).collect(); + let data = unary_map(shape, stride, |i| storage[i] * mul + add); Ok(Self::F32(data)) } Self::F64(storage) => { - let index = StridedIndex::new(shape.dims(), stride); - let data = index.map(|i| storage[i] * mul + add).collect(); + let data = unary_map(shape, stride, |i| storage[i] * mul + add); Ok(Self::F64(data)) } }