Performance improvement. (#181)

This commit is contained in:
Laurent Mazare
2023-07-17 11:07:14 +01:00
committed by GitHub
parent 28e1c07304
commit 5b1c0bc9be

View File

@@ -66,6 +66,7 @@ struct WCond<'a>(&'a [u32], &'a Layout);
impl<'a> Map2 for WCond<'a> { impl<'a> Map2 for WCond<'a> {
const OP: &'static str = "where"; const OP: &'static str = "where";
#[inline(always)]
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> { fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
let vs = match ( let vs = match (
self.1.contiguous_offsets(), self.1.contiguous_offsets(),
@@ -116,18 +117,18 @@ impl<'a> Map1 for Sum<'a> {
} }
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> { fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
let mut result = vec![];
result.reserve(layout.shape().elem_count());
match layout.strided_blocks() { match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => { crate::StridedBlocks::SingleBlock { start_offset, len } => vs
for &v in vs[start_offset..start_offset + len].iter() { [start_offset..start_offset + len]
result.push(f(v)) .iter()
} .map(|&v| f(v))
} .collect(),
crate::StridedBlocks::MultipleBlocks { crate::StridedBlocks::MultipleBlocks {
block_start_index, block_start_index,
block_len, block_len,
} => { } => {
let mut result = vec![];
result.reserve(layout.shape().elem_count());
// Specialize the case where block_len is one to avoid the second loop. // Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 { if block_len == 1 {
for index in block_start_index { for index in block_start_index {
@@ -142,10 +143,10 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
} }
} }
} }
}
}
result result
} }
}
}
// This function maps over two strided index sequences. // This function maps over two strided index sequences.
fn binary_map<T: Copy, F: FnMut(T, T) -> T>( fn binary_map<T: Copy, F: FnMut(T, T) -> T>(