mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Broadcast cpu implementation.
This commit is contained in:
@ -13,6 +13,14 @@ pub enum CpuStorage {
|
|||||||
F64(Vec<f64>),
|
F64(Vec<f64>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies `f` to each index described by `shape` and `stride` and collects the results
/// into a `Vec<T>`.
///
/// When `shape.is_contiguous(stride)` holds, `f` is called with the plain offsets
/// `0..shape.elem_count()`. Otherwise `f` is called with each item produced by
/// `StridedIndex::new(shape.dims(), stride)`.
fn unary_map<T, F: FnMut(usize) -> T>(shape: &Shape, stride: &[usize], f: F) -> Vec<T> {
|
||||||
|
// Contiguous layout: no `StridedIndex` is built, the offsets are just a range.
if shape.is_contiguous(stride) {
|
||||||
|
(0..shape.elem_count()).map(f).collect()
|
||||||
|
} else {
|
||||||
|
// NOTE(review): presumably `StridedIndex` yields flat storage offsets in logical
// element order; confirm against its definition, which is not visible here.
StridedIndex::new(shape.dims(), stride).map(f).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// This function maps over two strided index sequences. It supports broadcasting in case
|
// This function maps over two strided index sequences. It supports broadcasting in case
|
||||||
// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
|
// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
|
||||||
fn binary_map<T, F: FnMut((usize, usize)) -> T>(
|
fn binary_map<T, F: FnMut((usize, usize)) -> T>(
|
||||||
@ -22,7 +30,10 @@ fn binary_map<T, F: FnMut((usize, usize)) -> T>(
|
|||||||
mut f: F,
|
mut f: F,
|
||||||
) -> Vec<T> {
|
) -> Vec<T> {
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
if dims.len() == lhs_stride.len() && dims.len() == rhs_stride.len() {
|
let broadcast_ldims = dims.len() - lhs_stride.len();
|
||||||
|
let broadcast_rdims = dims.len() - rhs_stride.len();
|
||||||
|
let elem_count = shape.elem_count();
|
||||||
|
if broadcast_ldims == 0 && broadcast_rdims == 0 {
|
||||||
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
|
if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
|
||||||
(0..shape.elem_count()).map(|i| f((i, i))).collect()
|
(0..shape.elem_count()).map(|i| f((i, i))).collect()
|
||||||
} else {
|
} else {
|
||||||
@ -30,8 +41,34 @@ fn binary_map<T, F: FnMut((usize, usize)) -> T>(
|
|||||||
let rhs_index = StridedIndex::new(dims, rhs_stride);
|
let rhs_index = StridedIndex::new(dims, rhs_stride);
|
||||||
lhs_index.zip(rhs_index).map(f).collect()
|
lhs_index.zip(rhs_index).map(f).collect()
|
||||||
}
|
}
|
||||||
|
} else if broadcast_rdims == 0 {
|
||||||
|
let mut res = Vec::new();
|
||||||
|
res.reserve(elem_count);
|
||||||
|
let lhs_index: Vec<_> = StridedIndex::new(dims, lhs_stride).collect();
|
||||||
|
let mut i = 0;
|
||||||
|
for rhs_i in StridedIndex::new(dims, rhs_stride) {
|
||||||
|
res.push(f((lhs_index[i], rhs_i)));
|
||||||
|
i += 1;
|
||||||
|
if i >= lhs_index.len() {
|
||||||
|
i = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res
|
||||||
|
} else if broadcast_ldims == 0 {
|
||||||
|
let mut res = Vec::new();
|
||||||
|
res.reserve(elem_count);
|
||||||
|
let rhs_index: Vec<_> = StridedIndex::new(dims, rhs_stride).collect();
|
||||||
|
let mut i = 0;
|
||||||
|
for lhs_i in StridedIndex::new(dims, lhs_stride) {
|
||||||
|
res.push(f((lhs_i, rhs_index[i])));
|
||||||
|
i += 1;
|
||||||
|
if i >= rhs_index.len() {
|
||||||
|
i = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
res
|
||||||
} else {
|
} else {
|
||||||
todo!("implement broadcast")
|
panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -61,22 +98,19 @@ impl CpuStorage {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Self::U32(storage) => {
|
Self::U32(storage) => {
|
||||||
let index = StridedIndex::new(shape.dims(), stride);
|
|
||||||
let mul = mul as u32;
|
let mul = mul as u32;
|
||||||
let add = add as u32;
|
let add = add as u32;
|
||||||
let data = index.map(|i| storage[i] * mul + add).collect();
|
let data = unary_map(shape, stride, |i| storage[i] * mul + add);
|
||||||
Ok(Self::U32(data))
|
Ok(Self::U32(data))
|
||||||
}
|
}
|
||||||
Self::F32(storage) => {
|
Self::F32(storage) => {
|
||||||
let index = StridedIndex::new(shape.dims(), stride);
|
|
||||||
let mul = mul as f32;
|
let mul = mul as f32;
|
||||||
let add = add as f32;
|
let add = add as f32;
|
||||||
let data = index.map(|i| storage[i] * mul + add).collect();
|
let data = unary_map(shape, stride, |i| storage[i] * mul + add);
|
||||||
Ok(Self::F32(data))
|
Ok(Self::F32(data))
|
||||||
}
|
}
|
||||||
Self::F64(storage) => {
|
Self::F64(storage) => {
|
||||||
let index = StridedIndex::new(shape.dims(), stride);
|
let data = unary_map(shape, stride, |i| storage[i] * mul + add);
|
||||||
let data = index.map(|i| storage[i] * mul + add).collect();
|
|
||||||
Ok(Self::F64(data))
|
Ok(Self::F64(data))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user