More efficient CPU broadcasting implementation.

This commit is contained in:
laurent
2023-06-23 16:23:12 +01:00
parent 10a5807dff
commit bcfbb1dca1

View File

@ -23,10 +23,12 @@ fn unary_map<T, F: FnMut(usize) -> T>(shape: &Shape, stride: &[usize], f: F) ->
// This function maps over two strided index sequences. It supports broadcasting in case
// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
fn binary_map<T, F: FnMut((usize, usize)) -> T>(
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
lhs: &[T],
rhs: &[T],
mut f: F,
) -> Vec<T> {
let dims = shape.dims();
@ -35,21 +37,26 @@ fn binary_map<T, F: FnMut((usize, usize)) -> T>(
let elem_count = shape.elem_count();
if broadcast_ldims == 0 && broadcast_rdims == 0 {
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
(0..shape.elem_count()).map(|i| f((i, i))).collect()
(0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
} else {
let lhs_index = StridedIndex::new(dims, lhs_stride);
let rhs_index = StridedIndex::new(dims, rhs_stride);
lhs_index.zip(rhs_index).map(f).collect()
lhs_index
.zip(rhs_index)
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
.collect()
}
} else if broadcast_rdims == 0 {
let mut res = Vec::new();
res.reserve(elem_count);
let lhs_index: Vec<_> = StridedIndex::new(dims, lhs_stride).collect();
let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
.map(|i| lhs[i])
.collect();
let mut i = 0;
for rhs_i in StridedIndex::new(dims, rhs_stride) {
res.push(f((lhs_index[i], rhs_i)));
res.push(f(lhs_v[i], rhs[rhs_i]));
i += 1;
if i >= lhs_index.len() {
if i >= lhs_v.len() {
i = 0
}
}
@ -57,12 +64,14 @@ fn binary_map<T, F: FnMut((usize, usize)) -> T>(
} else if broadcast_ldims == 0 {
let mut res = Vec::new();
res.reserve(elem_count);
let rhs_index: Vec<_> = StridedIndex::new(dims, rhs_stride).collect();
let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
.map(|i| rhs[i])
.collect();
let mut i = 0;
for lhs_i in StridedIndex::new(dims, lhs_stride) {
res.push(f((lhs_i, rhs_index[i])));
res.push(f(lhs[lhs_i], rhs_v[i]));
i += 1;
if i >= rhs_index.len() {
if i >= rhs_v.len() {
i = 0
}
}
@ -144,21 +153,15 @@ impl CpuStorage {
) -> Result<Self> {
match (self, rhs) {
(Self::F32(lhs), Self::F32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
B::f32(lhs[l], rhs[r])
});
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
Ok(Self::F32(data))
}
(Self::F64(lhs), Self::F64(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
B::f64(lhs[l], rhs[r])
});
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
Ok(Self::F64(data))
}
(Self::U32(lhs), Self::U32(rhs)) => {
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
B::u32(lhs[l], rhs[r])
});
let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
Ok(Self::U32(data))
}
_ => {