Broadcasting performance optimization (cpu) (#182)

* Avoid recomputing the index from scratch each time.

* More performance optimisations.
This commit is contained in:
Laurent Mazare
2023-07-17 13:41:09 +01:00
committed by GitHub
parent 5b1c0bc9be
commit acb2f90469
4 changed files with 125 additions and 11 deletions

View File

@ -179,4 +179,60 @@ impl Layout {
}
}
}
// Returns the contiguous offsets with broadcast if applicable.
pub(crate) fn offsets_b(&self) -> Option<ContiguousOffsetsWithBroadcast> {
let mut left_broadcast = 1;
let mut right_broadcast = 1;
let strides = self.stride();
let dims = self.dims();
let mut start_cont = 0;
let mut end_cont = dims.len();
for (&s, &d) in strides.iter().zip(dims.iter()) {
if s != 0 {
break;
}
start_cont += 1;
left_broadcast *= d;
}
if start_cont == dims.len() {
return Some(ContiguousOffsetsWithBroadcast {
start: self.start_offset,
len: 1,
left_broadcast,
right_broadcast: 1,
});
}
for (&s, &d) in strides.iter().zip(dims.iter()).rev() {
if s != 0 {
break;
}
end_cont -= 1;
right_broadcast *= d;
}
// Check that the inner dims are contiguous
let strides = &strides[start_cont..end_cont];
let dims = &dims[start_cont..end_cont];
let mut len = 1;
for (&stride, &dim) in strides.iter().zip(dims.iter()).rev() {
if stride != len {
return None;
}
len *= dim;
}
Some(ContiguousOffsetsWithBroadcast {
start: self.start_offset,
len,
left_broadcast,
right_broadcast,
})
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ContiguousOffsetsWithBroadcast {
pub start: usize,
pub len: usize,
pub left_broadcast: usize,
pub right_broadcast: usize,
}