Optimize for the unstrided case.

This commit is contained in:
laurent
2023-06-23 15:49:11 +01:00
parent 4c8931d2e4
commit 83e75b3af8

View File

@ -13,6 +13,28 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
// This function maps over two strided index sequences. It supports broadcasting in case
// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
fn binary_map<T, F: FnMut((usize, usize)) -> T>(
shape: &Shape,
lhs_stride: &[usize],
rhs_stride: &[usize],
mut f: F,
) -> Vec<T> {
let dims = shape.dims();
if dims.len() == lhs_stride.len() && dims.len() == rhs_stride.len() {
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
(0..shape.elem_count()).map(|i| f((i, i))).collect()
} else {
let lhs_index = StridedIndex::new(dims, lhs_stride);
let rhs_index = StridedIndex::new(dims, rhs_stride);
lhs_index.zip(rhs_index).map(f).collect()
}
} else {
todo!("implement broadcast")
}
}
impl CpuStorage {
pub fn dtype(&self) -> DType {
match self {
@ -86,40 +108,23 @@ impl CpuStorage {
lhs_stride: &[usize],
rhs_stride: &[usize],
) -> Result<Self> {
let dims = shape.dims();
if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
todo!("implement broadcast");
}
// The ggml implementation has different paths based on whether the rhs is contiguous
// or not, for now we only consider the general case but we should benchmark and do the
// same if it helps.
// https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
match (self, rhs) {
(Self::F32(lhs), Self::F32(rhs)) => {
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
let data = lhs_index
.zip(rhs_index)
.map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
.collect();
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
B::f32(lhs[l], rhs[r])
});
Ok(Self::F32(data))
}
(Self::F64(lhs), Self::F64(rhs)) => {
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
let data = lhs_index
.zip(rhs_index)
.map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
.collect();
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
B::f64(lhs[l], rhs[r])
});
Ok(Self::F64(data))
}
(Self::U32(lhs), Self::U32(rhs)) => {
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
let data = lhs_index
.zip(rhs_index)
.map(|(lhs_i, rhs_i)| B::u32(lhs[lhs_i], rhs[rhs_i]))
.collect();
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
B::u32(lhs[l], rhs[r])
});
Ok(Self::U32(data))
}
_ => {