mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Broadcasting performance optimization (cpu) (#182)
* Avoid recomputing the index from scratch each time. * More performance optimisations.
This commit is contained in:
@ -162,6 +162,64 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
|||||||
.zip(rhs[o_r1..o_r2].iter())
|
.zip(rhs[o_r1..o_r2].iter())
|
||||||
.map(|(&l, &r)| f(l, r))
|
.map(|(&l, &r)| f(l, r))
|
||||||
.collect(),
|
.collect(),
|
||||||
|
(Some((o_l1, o_l2)), None) => {
|
||||||
|
// TODO: Maybe we want to avoid going through the layout twice.
|
||||||
|
match rhs_l.offsets_b() {
|
||||||
|
Some(ob) => {
|
||||||
|
let mut i_in_block = 0;
|
||||||
|
let mut i_right_broadcast = 0;
|
||||||
|
lhs[o_l1..o_l2]
|
||||||
|
.iter()
|
||||||
|
.map(|&l| {
|
||||||
|
let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
|
||||||
|
i_right_broadcast += 1;
|
||||||
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
|
i_in_block += 1;
|
||||||
|
i_right_broadcast = 0;
|
||||||
|
}
|
||||||
|
if i_in_block >= ob.len {
|
||||||
|
i_in_block = 0
|
||||||
|
}
|
||||||
|
f(l, *r)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(None, Some((o_r1, o_r2))) => {
|
||||||
|
// TODO: Maybe we want to avoid going through the layout twice.
|
||||||
|
match lhs_l.offsets_b() {
|
||||||
|
Some(ob) => {
|
||||||
|
let mut i_in_block = 0;
|
||||||
|
let mut i_right_broadcast = 0;
|
||||||
|
rhs[o_r1..o_r2]
|
||||||
|
.iter()
|
||||||
|
.map(|&r| {
|
||||||
|
let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
|
||||||
|
i_right_broadcast += 1;
|
||||||
|
if i_right_broadcast >= ob.right_broadcast {
|
||||||
|
i_in_block += 1;
|
||||||
|
i_right_broadcast = 0;
|
||||||
|
}
|
||||||
|
if i_in_block >= ob.len {
|
||||||
|
i_in_block = 0
|
||||||
|
}
|
||||||
|
f(*l, r)
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
None => lhs_l
|
||||||
|
.strided_index()
|
||||||
|
.zip(rhs_l.strided_index())
|
||||||
|
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => lhs_l
|
_ => lhs_l
|
||||||
.strided_index()
|
.strided_index()
|
||||||
.zip(rhs_l.strided_index())
|
.zip(rhs_l.strided_index())
|
||||||
|
@ -179,4 +179,60 @@ impl Layout {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the contiguous offsets with broadcast if applicable.
|
||||||
|
pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
|
||||||
|
let mut left_broadcast = 1;
|
||||||
|
let mut right_broadcast = 1;
|
||||||
|
let strides = self.stride();
|
||||||
|
let dims = self.dims();
|
||||||
|
let mut start_cont = 0;
|
||||||
|
let mut end_cont = dims.len();
|
||||||
|
for (&s, &d) in strides.iter().zip(dims.iter()) {
|
||||||
|
if s != 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
start_cont += 1;
|
||||||
|
left_broadcast *= d;
|
||||||
|
}
|
||||||
|
if start_cont == dims.len() {
|
||||||
|
return Some(ContiguousOffsetsWithBroadcast {
|
||||||
|
start: self.start_offset,
|
||||||
|
len: 1,
|
||||||
|
left_broadcast,
|
||||||
|
right_broadcast: 1,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
|
||||||
|
if s != 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
end_cont -= 1;
|
||||||
|
right_broadcast *= d;
|
||||||
|
}
|
||||||
|
// Check that the inner dims are contiguous
|
||||||
|
let strides = &strides[start_cont..end_cont];
|
||||||
|
let dims = &dims[start_cont..end_cont];
|
||||||
|
let mut len = 1;
|
||||||
|
for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
|
||||||
|
if stride != len {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
len *= dim;
|
||||||
|
}
|
||||||
|
Some(ContiguousOffsetsWithBroadcast {
|
||||||
|
start: self.start_offset,
|
||||||
|
len,
|
||||||
|
left_broadcast,
|
||||||
|
right_broadcast,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct ContiguousOffsetsWithBroadcast {
|
||||||
|
pub start: usize,
|
||||||
|
pub len: usize,
|
||||||
|
pub left_broadcast: usize,
|
||||||
|
pub right_broadcast: usize,
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ mod dtype;
|
|||||||
mod dummy_cuda_backend;
|
mod dummy_cuda_backend;
|
||||||
mod error;
|
mod error;
|
||||||
mod indexer;
|
mod indexer;
|
||||||
mod layout;
|
pub mod layout;
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
mod mkl;
|
mod mkl;
|
||||||
pub mod npy;
|
pub mod npy;
|
||||||
|
@ -8,7 +8,6 @@ pub struct StridedIndex<'a> {
|
|||||||
multi_index: Vec<usize>,
|
multi_index: Vec<usize>,
|
||||||
dims: &'a [usize],
|
dims: &'a [usize],
|
||||||
stride: &'a [usize],
|
stride: &'a [usize],
|
||||||
start_offset: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> StridedIndex<'a> {
|
impl<'a> StridedIndex<'a> {
|
||||||
@ -25,7 +24,6 @@ impl<'a> StridedIndex<'a> {
|
|||||||
multi_index: vec![0; dims.len()],
|
multi_index: vec![0; dims.len()],
|
||||||
dims,
|
dims,
|
||||||
stride,
|
stride,
|
||||||
start_offset,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,24 +41,26 @@ impl<'a> Iterator for StridedIndex<'a> {
|
|||||||
Some(storage_index) => storage_index,
|
Some(storage_index) => storage_index,
|
||||||
};
|
};
|
||||||
let mut updated = false;
|
let mut updated = false;
|
||||||
for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
|
let mut next_storage_index = storage_index;
|
||||||
|
for ((multi_i, max_i), stride_i) in self
|
||||||
|
.multi_index
|
||||||
|
.iter_mut()
|
||||||
|
.zip(self.dims.iter())
|
||||||
|
.zip(self.stride.iter())
|
||||||
|
.rev()
|
||||||
|
{
|
||||||
let next_i = *multi_i + 1;
|
let next_i = *multi_i + 1;
|
||||||
if next_i < *max_i {
|
if next_i < *max_i {
|
||||||
*multi_i = next_i;
|
*multi_i = next_i;
|
||||||
updated = true;
|
updated = true;
|
||||||
|
next_storage_index += stride_i;
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
|
next_storage_index -= *multi_i * stride_i;
|
||||||
*multi_i = 0
|
*multi_i = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.next_storage_index = if updated {
|
self.next_storage_index = if updated {
|
||||||
let next_storage_index = self
|
|
||||||
.multi_index
|
|
||||||
.iter()
|
|
||||||
.zip(self.stride.iter())
|
|
||||||
.map(|(&x, &y)| x * y)
|
|
||||||
.sum::<usize>()
|
|
||||||
+ self.start_offset;
|
|
||||||
Some(next_storage_index)
|
Some(next_storage_index)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
Reference in New Issue
Block a user