diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 06241115..c336dfef 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -162,6 +162,64 @@ fn binary_map T>( .zip(rhs[o_r1..o_r2].iter()) .map(|(&l, &r)| f(l, r)) .collect(), + (Some((o_l1, o_l2)), None) => { + // TODO: Maybe we want to avoid going through the layout twice. + match rhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + lhs[o_l1..o_l2] + .iter() + .map(|&l| { + let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(l, *r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + (None, Some((o_r1, o_r2))) => { + // TODO: Maybe we want to avoid going through the layout twice. + match lhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + rhs[o_r1..o_r2] + .iter() + .map(|&r| { + let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(*l, r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } _ => lhs_l .strided_index() .zip(rhs_l.strided_index()) diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 22ad53cf..95dc9667 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -179,4 +179,60 @@ impl Layout { } } } + + // Returns the contiguous offsets with broadcast if applicable. 
+    /// Describes this layout as one contiguous block of storage that is repeated by
+    /// broadcasting, when it has that shape.
+    ///
+    /// Leading dimensions with stride 0 are folded into `left_broadcast` and trailing
+    /// dimensions with stride 0 into `right_broadcast` (each is the product of the sizes of
+    /// those dimensions). The dimensions left in between must be contiguous: walking them
+    /// from last to first, every stride has to equal the product of the sizes of the
+    /// dimensions after it. Returns `None` as soon as one of them does not.
+    ///
+    /// When every stride is 0 the block is the single element at `start_offset`, and the
+    /// whole shape is reported through `left_broadcast`.
+    pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
+        let mut left_broadcast = 1;
+        let mut right_broadcast = 1;
+        let strides = self.stride();
+        let dims = self.dims();
+        let mut start_cont = 0;
+        let mut end_cont = dims.len();
+        // Consume the leading stride-0 dimensions.
+        for (&s, &d) in strides.iter().zip(dims.iter()) {
+            if s != 0 {
+                break;
+            }
+            start_cont += 1;
+            left_broadcast *= d;
+        }
+        // Every dimension has stride 0: a single element covers the whole shape.
+        if start_cont == dims.len() {
+            return Some(ContiguousOffsetsWithBroadcast {
+                start: self.start_offset,
+                len: 1,
+                left_broadcast,
+                right_broadcast: 1,
+            });
+        }
+        // Consume the trailing stride-0 dimensions. This cannot run past `start_cont`
+        // because at least one stride is known to be non-zero at this point.
+        for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
+            if s != 0 {
+                break;
+            }
+            end_cont -= 1;
+            right_broadcast *= d;
+        }
+        // Check that the inner dims are contiguous
+        let strides = &strides[start_cont..end_cont];
+        let dims = &dims[start_cont..end_cont];
+        let mut len = 1;
+        for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
+            if stride != len {
+                return None;
+            }
+            len *= dim;
+        }
+        Some(ContiguousOffsetsWithBroadcast {
+            start: self.start_offset,
+            len,
+            left_broadcast,
+            right_broadcast,
+        })
+    }
+}
+
+/// A contiguous block of storage together with the broadcast factors found around it by
+/// `Layout::offsets_b`.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ContiguousOffsetsWithBroadcast {
+    /// Storage offset of the first element of the block (the layout's `start_offset`).
+    pub start: usize,
+    /// Number of elements in the block.
+    pub len: usize,
+    /// Product of the sizes of the leading stride-0 dimensions.
+    pub left_broadcast: usize,
+    /// Product of the sizes of the trailing stride-0 dimensions.
+    pub right_broadcast: usize,
 }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index bb5ecb01..a1f942d4 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -46,7 +46,7 @@ mod dtype;
 mod dummy_cuda_backend;
 mod error;
 mod indexer;
-mod layout;
+pub mod layout;
 #[cfg(feature = "mkl")]
 mod mkl;
 pub mod npy;
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs
index 455b903c..eb6a736f 100644
--- a/candle-core/src/strided_index.rs
+++ b/candle-core/src/strided_index.rs
@@ -8,7 +8,6 @@ pub struct StridedIndex<'a> {
     multi_index: Vec<usize>,
     dims: &'a [usize],
     stride: &'a [usize],
-    start_offset: usize,
 }
 
 impl<'a> StridedIndex<'a> {
@@ -25,7 +24,6 @@ impl<'a> StridedIndex<'a> {
             multi_index: vec![0; dims.len()],
             dims,
             stride,
-            start_offset,
         }
     }
 
@@ -43,24 +41,26 @@ impl<'a> Iterator for StridedIndex<'a> { Some(storage_index) => storage_index, }; let mut updated = false; - for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + let mut next_storage_index = storage_index; + for ((multi_i, max_i), stride_i) in self + .multi_index + .iter_mut() + .zip(self.dims.iter()) + .zip(self.stride.iter()) + .rev() + { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; updated = true; + next_storage_index += stride_i; break; } else { + next_storage_index -= *multi_i * stride_i; *multi_i = 0 } } self.next_storage_index = if updated { - let next_storage_index = self - .multi_index - .iter() - .zip(self.stride.iter()) - .map(|(&x, &y)| x * y) - .sum::() - + self.start_offset; Some(next_storage_index) } else { None