mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Optimize for the unstrided case.
This commit is contained in:
@ -13,6 +13,28 @@ pub enum CpuStorage {
|
|||||||
F64(Vec<f64>),
|
F64(Vec<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function maps over two strided index sequences. It supports broadcasting in case
|
||||||
|
// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
|
||||||
|
fn binary_map<T, F: FnMut((usize, usize)) -> T>(
|
||||||
|
shape: &Shape,
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
mut f: F,
|
||||||
|
) -> Vec<T> {
|
||||||
|
let dims = shape.dims();
|
||||||
|
if dims.len() == lhs_stride.len() && dims.len() == rhs_stride.len() {
|
||||||
|
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
|
||||||
|
(0..shape.elem_count()).map(|i| f((i, i))).collect()
|
||||||
|
} else {
|
||||||
|
let lhs_index = StridedIndex::new(dims, lhs_stride);
|
||||||
|
let rhs_index = StridedIndex::new(dims, rhs_stride);
|
||||||
|
lhs_index.zip(rhs_index).map(f).collect()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
todo!("implement broadcast")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl CpuStorage {
|
impl CpuStorage {
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
@ -86,40 +108,23 @@ impl CpuStorage {
|
|||||||
lhs_stride: &[usize],
|
lhs_stride: &[usize],
|
||||||
rhs_stride: &[usize],
|
rhs_stride: &[usize],
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let dims = shape.dims();
|
|
||||||
if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
|
|
||||||
todo!("implement broadcast");
|
|
||||||
}
|
|
||||||
// The ggml implementation has different paths based on whether the rhs is contiguous
|
|
||||||
// or not, for now we only consider the general case but we should benchmark and do the
|
|
||||||
// same if it helps.
|
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895
|
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
(Self::F32(lhs), Self::F32(rhs)) => {
|
||||||
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
|
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
|
||||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
B::f32(lhs[l], rhs[r])
|
||||||
let data = lhs_index
|
});
|
||||||
.zip(rhs_index)
|
|
||||||
.map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect();
|
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
(Self::F64(lhs), Self::F64(rhs)) => {
|
(Self::F64(lhs), Self::F64(rhs)) => {
|
||||||
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
|
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
|
||||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
B::f64(lhs[l], rhs[r])
|
||||||
let data = lhs_index
|
});
|
||||||
.zip(rhs_index)
|
|
||||||
.map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect();
|
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
(Self::U32(lhs), Self::U32(rhs)) => {
|
(Self::U32(lhs), Self::U32(rhs)) => {
|
||||||
let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
|
let data = binary_map(shape, lhs_stride, rhs_stride, |(l, r)| {
|
||||||
let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
|
B::u32(lhs[l], rhs[r])
|
||||||
let data = lhs_index
|
});
|
||||||
.zip(rhs_index)
|
|
||||||
.map(|(lhs_i, rhs_i)| B::u32(lhs[lhs_i], rhs[rhs_i]))
|
|
||||||
.collect();
|
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
Reference in New Issue
Block a user