mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Properly use the offset when broadcasting on a narrow slice. (#193)
This commit is contained in:
@ -326,6 +326,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
|
|||||||
}
|
}
|
||||||
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
(Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
|
||||||
Some(ob) if ob.right_broadcast == 1 => {
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let rhs = &rhs[ob.start..];
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
@ -345,10 +346,11 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
|
|||||||
Some(ob) => {
|
Some(ob) => {
|
||||||
let mut i_in_block = 0;
|
let mut i_in_block = 0;
|
||||||
let mut i_right_broadcast = 0;
|
let mut i_right_broadcast = 0;
|
||||||
|
let rhs = &rhs[ob.start..];
|
||||||
lhs[o_l1..o_l2]
|
lhs[o_l1..o_l2]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&l| {
|
.map(|&l| {
|
||||||
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
let r = unsafe { rhs.get_unchecked(i_in_block) };
|
||||||
i_right_broadcast += 1;
|
i_right_broadcast += 1;
|
||||||
if i_right_broadcast >= ob.right_broadcast {
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
i_in_block += 1;
|
i_in_block += 1;
|
||||||
@ -369,6 +371,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
|
|||||||
},
|
},
|
||||||
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
(None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
|
||||||
Some(ob) if ob.right_broadcast == 1 => {
|
Some(ob) if ob.right_broadcast == 1 => {
|
||||||
|
let lhs = &lhs[ob.start..];
|
||||||
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
let mut ys: Vec<T> = Vec::with_capacity(el_count);
|
||||||
let ys_to_set = ys.spare_capacity_mut();
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
|
||||||
@ -386,12 +389,13 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
|
|||||||
ys
|
ys
|
||||||
}
|
}
|
||||||
Some(ob) => {
|
Some(ob) => {
|
||||||
|
let lhs = &lhs[ob.start..];
|
||||||
let mut i_in_block = 0;
|
let mut i_in_block = 0;
|
||||||
let mut i_right_broadcast = 0;
|
let mut i_right_broadcast = 0;
|
||||||
rhs[o_r1..o_r2]
|
rhs[o_r1..o_r2]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|&r| {
|
.map(|&r| {
|
||||||
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
let l = unsafe { lhs.get_unchecked(i_in_block) };
|
||||||
i_right_broadcast += 1;
|
i_right_broadcast += 1;
|
||||||
if i_right_broadcast >= ob.right_broadcast {
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
i_in_block += 1;
|
i_in_block += 1;
|
||||||
|
Reference in New Issue
Block a user