mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Broadcasting performance optimization (cpu) (#182)
* Avoid recomputing the index from scratch each time.
* More performance optimisations.
This commit is contained in:
@ -179,4 +179,60 @@ impl Layout {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the contiguous offsets with broadcast if applicable.
|
||||
pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
|
||||
let mut left_broadcast = 1;
|
||||
let mut right_broadcast = 1;
|
||||
let strides = self.stride();
|
||||
let dims = self.dims();
|
||||
let mut start_cont = 0;
|
||||
let mut end_cont = dims.len();
|
||||
for (&s, &d) in strides.iter().zip(dims.iter()) {
|
||||
if s != 0 {
|
||||
break;
|
||||
}
|
||||
start_cont += 1;
|
||||
left_broadcast *= d;
|
||||
}
|
||||
if start_cont == dims.len() {
|
||||
return Some(ContiguousOffsetsWithBroadcast {
|
||||
start: self.start_offset,
|
||||
len: 1,
|
||||
left_broadcast,
|
||||
right_broadcast: 1,
|
||||
});
|
||||
}
|
||||
for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
|
||||
if s != 0 {
|
||||
break;
|
||||
}
|
||||
end_cont -= 1;
|
||||
right_broadcast *= d;
|
||||
}
|
||||
// Check that the inner dims are contiguous
|
||||
let strides = &strides[start_cont..end_cont];
|
||||
let dims = &dims[start_cont..end_cont];
|
||||
let mut len = 1;
|
||||
for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
|
||||
if stride != len {
|
||||
return None;
|
||||
}
|
||||
len *= dim;
|
||||
}
|
||||
Some(ContiguousOffsetsWithBroadcast {
|
||||
start: self.start_offset,
|
||||
len,
|
||||
left_broadcast,
|
||||
right_broadcast,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Description of a layout as a contiguous block with broadcast factors on
/// either side, as produced by `Layout::offsets_b`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContiguousOffsetsWithBroadcast {
    /// Start offset of the layout (copied from the layout's `start_offset`).
    pub start: usize,
    /// Number of elements in the contiguous inner group of dimensions
    /// (`1` when every dimension is broadcast).
    pub len: usize,
    /// Product of the sizes of the leading dimensions whose stride is `0`.
    pub left_broadcast: usize,
    /// Product of the sizes of the trailing dimensions whose stride is `0`.
    pub right_broadcast: usize,
}
|
||||
|
Reference in New Issue
Block a user