diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f83bb5e6..9f0c8602 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -81,22 +81,23 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
 
 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
-    lhs_layout: &Layout,
-    rhs_layout: &Layout,
+    lhs_l: &Layout,
+    rhs_l: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
 ) -> Vec<T> {
-    let shape = lhs_layout.shape();
-    if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() {
-        (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
-    } else {
-        let lhs_index = lhs_layout.strided_index();
-        let rhs_index = rhs_layout.strided_index();
-        lhs_index
-            .zip(rhs_index)
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
+            .iter()
+            .zip(rhs[o_r1..o_r2].iter())
+            .map(|(&l, &r)| f(l, r))
+            .collect(),
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
             .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-            .collect()
+            .collect(),
     }
 }
 
@@ -151,15 +152,17 @@ fn matmul(
     lhs: &[T],
     rhs: &[T],
     (b, m, n, k): (usize, usize, usize, usize),
-    lhs_layout: &Layout,
-    rhs_layout: &Layout,
+    lhs_l: &Layout,
+    rhs_l: &Layout,
 ) -> Result<Vec<T>> {
+    let lhs = &lhs[lhs_l.start_offset()..];
+    let rhs = &rhs[rhs_l.start_offset()..];
     let a_skip: usize = m * k;
     let b_skip: usize = n * k;
     let c_skip: usize = m * n;
 
-    let lhs_stride = lhs_layout.stride();
-    let rhs_stride = rhs_layout.stride();
+    let lhs_stride = lhs_l.stride();
+    let rhs_stride = rhs_l.stride();
     let rank = lhs_stride.len();
     let lhs_cs = lhs_stride[rank - 1];
     let lhs_rs = lhs_stride[rank - 2];
@@ -509,28 +512,28 @@ impl CpuStorage {
     pub(crate) fn binary_impl<B: BinaryOp>(
         &self,
         rhs: &Self,
-        lhs_layout: &Layout,
-        rhs_layout: &Layout,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
     ) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
-                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::bf16);
+                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
                 Ok(Self::BF16(data))
             }
             (Self::F16(lhs), Self::F16(rhs)) => {
-                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f16);
+                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16);
                 Ok(Self::F16(data))
             }
             (Self::F32(lhs), Self::F32(rhs)) => {
-                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f32);
+                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32);
                 Ok(Self::F32(data))
             }
             (Self::F64(lhs), Self::F64(rhs)) => {
-                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f64);
+                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64);
                 Ok(Self::F64(data))
            }
             (Self::U32(lhs), Self::U32(rhs)) => {
-                let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::u32);
+                let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
                 Ok(Self::U32(data))
             }
             _ => {
@@ -622,20 +625,20 @@ impl CpuStorage {
         &self,
         rhs: &Self,
         bmnk: (usize, usize, usize, usize),
-        lhs_layout: &Layout,
-        rhs_layout: &Layout,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
     ) -> Result<Self> {
         match (self, rhs) {
             (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
+                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
                 Ok(Self::F16(dst))
             }
             (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
+                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
                 Ok(Self::F32(dst))
             }
             (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
+                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
                 Ok(Self::F64(dst))
             }
             _ => Err(Error::DTypeMismatchBinaryOp {
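
For readers skimming the patch, here is a small standalone sketch of the dispatch pattern the new `binary_map` uses: try a contiguous fast path first (both layouts expose a plain `start..end` range, so the operands can be zipped as slices), and otherwise fall back to walking the two strided index sequences. The `SimpleLayout` type, its fields, and the `main` demo below are hypothetical stand-ins, not candle's real `Layout`; only the shape of the `match` mirrors the diff's `contiguous_offsets` / `strided_index` dispatch.

// Minimal sketch, assuming a simplified layout type instead of candle's `Layout`.
struct SimpleLayout {
    start_offset: usize,
    len: usize,
    // `None` means the view is contiguous; `Some(indices)` enumerates the
    // storage position of each logical element (e.g. a transposed view).
    strided: Option<Vec<usize>>,
}

impl SimpleLayout {
    // Plays the role of `Layout::contiguous_offsets()` in the diff:
    // only a contiguous view yields a plain `(start, end)` range.
    fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        match &self.strided {
            None => Some((self.start_offset, self.start_offset + self.len)),
            Some(_) => None,
        }
    }

    // Plays the role of `Layout::strided_index()`: an iterator over the
    // storage index of every logical element.
    fn strided_index(&self) -> Box<dyn Iterator<Item = usize> + '_> {
        match &self.strided {
            None => Box::new(self.start_offset..self.start_offset + self.len),
            Some(idx) => Box::new(idx.iter().copied()),
        }
    }
}

fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
    lhs_l: &SimpleLayout,
    rhs_l: &SimpleLayout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<T> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Fast path: both operands are plain slices, so the zip runs over
        // contiguous memory with the bounds check hoisted out of the loop.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        // Fallback: walk both strided index sequences in lockstep.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}

fn main() {
    let lhs = [1.0f32, 2.0, 3.0, 4.0];
    let rhs = [10.0f32, 20.0, 30.0, 40.0];
    let contiguous = SimpleLayout { start_offset: 0, len: 4, strided: None };
    // A strided view that reads its buffer back to front.
    let reversed = SimpleLayout { start_offset: 0, len: 4, strided: Some(vec![3, 2, 1, 0]) };

    assert_eq!(
        binary_map(&contiguous, &contiguous, &lhs, &rhs, |a, b| a + b),
        vec![11.0, 22.0, 33.0, 44.0]
    );
    assert_eq!(
        binary_map(&contiguous, &reversed, &lhs, &rhs, |a, b| a + b),
        vec![41.0, 32.0, 23.0, 14.0]
    );
}

The same idea explains the `matmul` change: the diff's slicing by `start_offset()` accounts for views that do not begin at element zero of their storage, instead of assuming the buffer starts at the tensor's first element.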