From 6623c227d8779cdc6b5d7f4ee74b521bb5f1b21d Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 18 Jul 2023 18:12:32 +0200 Subject: [PATCH] Allow the compiler to vectorize some broadcasting loops. (#194) * Allow the compiler to vectorize some broadcasting loops. * Improve the symmetrical broadcasting case. --- candle-core/src/cpu_backend.rs | 52 +++++++++++++--------------------- 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 699bb452..d2c727f3 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -344,24 +344,18 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> ys } Some(ob) => { - let mut i_in_block = 0; - let mut i_right_broadcast = 0; let rhs = &rhs[ob.start..]; - lhs[o_l1..o_l2] - .iter() - .map(|&l| { - let r = unsafe { rhs.get_unchecked(i_in_block) }; - i_right_broadcast += 1; - if i_right_broadcast >= ob.right_broadcast { - i_in_block += 1; - i_right_broadcast = 0; + let mut ys = lhs[o_l1..o_l2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &r) in rhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(*v, r) } - if i_in_block >= ob.len { - i_in_block = 0 - } - f(l, *r) - }) - .collect() + } + } + ys } None => lhs_l .strided_index() @@ -390,23 +384,17 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> } Some(ob) => { let lhs = &lhs[ob.start..]; - let mut i_in_block = 0; - let mut i_right_broadcast = 0; - rhs[o_r1..o_r2] - .iter() - .map(|&r| { - let l = unsafe { lhs.get_unchecked(i_in_block) }; - i_right_broadcast += 1; - if i_right_broadcast >= ob.right_broadcast { - i_in_block += 1; - i_right_broadcast = 0; + let mut ys = rhs[o_r1..o_r2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &l) in lhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(l, *v) } - if i_in_block >= ob.len { - i_in_block = 0 - } - f(*l, r) - }) - .collect() + } + } + ys } None => lhs_l .strided_index()