Process unary functions per block (#180)

* Process unary functions per block.

* Add some inline hints.
This commit is contained in:
Laurent Mazare
2023-07-17 10:22:33 +01:00
committed by GitHub
parent 2a74019ec6
commit 28e1c07304
2 changed files with 53 additions and 3 deletions

View File

@ -98,6 +98,7 @@ struct Sum<'a> {
}
impl<'a> Map1 for Sum<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_layout: &Layout) -> Result<Vec<T>> {
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
@ -115,10 +116,35 @@ impl<'a> Map1 for Sum<'a> {
}
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
match layout.contiguous_offsets() {
Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
None => layout.strided_index().map(|i| f(vs[i])).collect(),
let mut result = vec![];
result.reserve(layout.shape().elem_count());
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => {
for &v in vs[start_offset..start_offset + len].iter() {
result.push(f(v))
}
}
crate::StridedBlocks::MultipleBlocks {
block_start_index,
block_len,
} => {
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
}
}
result
}
// This function maps over two strided index sequences.