Allow the compiler to vectorize some broadcasting loops. (#194)

* Allow the compiler to vectorize some broadcasting loops.

* Improve the symmetrical broadcasting case.
This commit is contained in:
Laurent Mazare
2023-07-18 18:12:32 +02:00
committed by GitHub
parent 79a5b686d0
commit 6623c227d8

View File

@ -344,24 +344,18 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
ys ys
} }
Some(ob) => { Some(ob) => {
let mut i_in_block = 0;
let mut i_right_broadcast = 0;
let rhs = &rhs[ob.start..]; let rhs = &rhs[ob.start..];
lhs[o_l1..o_l2] let mut ys = lhs[o_l1..o_l2].to_vec();
.iter() for idx_l in 0..ob.left_broadcast {
.map(|&l| { let start = idx_l * ob.len * ob.right_broadcast;
let r = unsafe { rhs.get_unchecked(i_in_block) }; for (i, &r) in rhs.iter().enumerate() {
i_right_broadcast += 1; let start = start + i * ob.right_broadcast;
if i_right_broadcast >= ob.right_broadcast { for v in ys[start..start + ob.right_broadcast].iter_mut() {
i_in_block += 1; *v = f(*v, r)
i_right_broadcast = 0;
} }
if i_in_block >= ob.len { }
i_in_block = 0 }
} ys
f(l, *r)
})
.collect()
} }
None => lhs_l None => lhs_l
.strided_index() .strided_index()
@ -390,23 +384,17 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
} }
Some(ob) => { Some(ob) => {
let lhs = &lhs[ob.start..]; let lhs = &lhs[ob.start..];
let mut i_in_block = 0; let mut ys = rhs[o_r1..o_r2].to_vec();
let mut i_right_broadcast = 0; for idx_l in 0..ob.left_broadcast {
rhs[o_r1..o_r2] let start = idx_l * ob.len * ob.right_broadcast;
.iter() for (i, &l) in lhs.iter().enumerate() {
.map(|&r| { let start = start + i * ob.right_broadcast;
let l = unsafe { lhs.get_unchecked(i_in_block) }; for v in ys[start..start + ob.right_broadcast].iter_mut() {
i_right_broadcast += 1; *v = f(l, *v)
if i_right_broadcast >= ob.right_broadcast {
i_in_block += 1;
i_right_broadcast = 0;
} }
if i_in_block >= ob.len { }
i_in_block = 0 }
} ys
f(*l, r)
})
.collect()
} }
None => lhs_l None => lhs_l
.strided_index() .strided_index()