diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 0d37774c..699bb452 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -326,6 +326,7 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> } (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { Some(ob) if ob.right_broadcast == 1 => { + let rhs = &rhs[ob.start..]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; @@ -345,10 +346,11 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> Some(ob) => { let mut i_in_block = 0; let mut i_right_broadcast = 0; + let rhs = &rhs[ob.start..]; lhs[o_l1..o_l2] .iter() .map(|&l| { - let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; + let r = unsafe { rhs.get_unchecked(i_in_block) }; i_right_broadcast += 1; if i_right_broadcast >= ob.right_broadcast { i_in_block += 1; @@ -369,6 +371,7 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> }, (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { Some(ob) if ob.right_broadcast == 1 => { + let lhs = &lhs[ob.start..]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; @@ -386,12 +389,13 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> ys } Some(ob) => { + let lhs = &lhs[ob.start..]; let mut i_in_block = 0; let mut i_right_broadcast = 0; rhs[o_r1..o_r2] .iter() .map(|&r| { - let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; + let l = unsafe { lhs.get_unchecked(i_in_block) }; i_right_broadcast += 1; if i_right_broadcast >= ob.right_broadcast { i_in_block += 1;