mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Performance improvement. (#181)
This commit is contained in:
@ -66,6 +66,7 @@ struct WCond<'a>(&'a [u32], &'a Layout);
|
|||||||
|
|
||||||
impl<'a> Map2 for WCond<'a> {
|
impl<'a> Map2 for WCond<'a> {
|
||||||
const OP: &'static str = "where";
|
const OP: &'static str = "where";
|
||||||
|
#[inline(always)]
|
||||||
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
|
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
|
||||||
let vs = match (
|
let vs = match (
|
||||||
self.1.contiguous_offsets(),
|
self.1.contiguous_offsets(),
|
||||||
@ -116,18 +117,18 @@ impl<'a> Map1 for Sum<'a> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
||||||
let mut result = vec![];
|
|
||||||
result.reserve(layout.shape().elem_count());
|
|
||||||
match layout.strided_blocks() {
|
match layout.strided_blocks() {
|
||||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
||||||
for &v in vs[start_offset..start_offset + len].iter() {
|
[start_offset..start_offset + len]
|
||||||
result.push(f(v))
|
.iter()
|
||||||
}
|
.map(|&v| f(v))
|
||||||
}
|
.collect(),
|
||||||
crate::StridedBlocks::MultipleBlocks {
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
block_start_index,
|
block_start_index,
|
||||||
block_len,
|
block_len,
|
||||||
} => {
|
} => {
|
||||||
|
let mut result = vec![];
|
||||||
|
result.reserve(layout.shape().elem_count());
|
||||||
// Specialize the case where block_len is one to avoid the second loop.
|
// Specialize the case where block_len is one to avoid the second loop.
|
||||||
if block_len == 1 {
|
if block_len == 1 {
|
||||||
for index in block_start_index {
|
for index in block_start_index {
|
||||||
@ -142,10 +143,10 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// This function maps over two strided index sequences.
|
// This function maps over two strided index sequences.
|
||||||
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||||
|
Reference in New Issue
Block a user