mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Fix some cpu issue.
This commit is contained in:
@ -81,22 +81,23 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut
|
||||
|
||||
// This function maps over two strided index sequences.
|
||||
fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
mut f: F,
|
||||
) -> Vec<T> {
|
||||
let shape = lhs_layout.shape();
|
||||
if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() {
|
||||
(0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
|
||||
} else {
|
||||
let lhs_index = lhs_layout.strided_index();
|
||||
let rhs_index = rhs_layout.strided_index();
|
||||
lhs_index
|
||||
.zip(rhs_index)
|
||||
match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
|
||||
(Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
|
||||
.iter()
|
||||
.zip(rhs[o_r1..o_r2].iter())
|
||||
.map(|(&l, &r)| f(l, r))
|
||||
.collect(),
|
||||
_ => lhs_l
|
||||
.strided_index()
|
||||
.zip(rhs_l.strided_index())
|
||||
.map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
|
||||
.collect()
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -151,15 +152,17 @@ fn matmul<T: 'static + num_traits::Num + Copy>(
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
let lhs = &lhs[lhs_l.start_offset()..];
|
||||
let rhs = &rhs[rhs_l.start_offset()..];
|
||||
let a_skip: usize = m * k;
|
||||
let b_skip: usize = n * k;
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let lhs_stride = lhs_layout.stride();
|
||||
let rhs_stride = rhs_layout.stride();
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
let rank = lhs_stride.len();
|
||||
let lhs_cs = lhs_stride[rank - 1];
|
||||
let lhs_rs = lhs_stride[rank - 2];
|
||||
@ -509,28 +512,28 @@ impl CpuStorage {
|
||||
pub(crate) fn binary_impl<B: BinaryOp>(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(Self::BF16(lhs), Self::BF16(rhs)) => {
|
||||
let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::bf16);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
(Self::F16(lhs), Self::F16(rhs)) => {
|
||||
let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f16);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
(Self::F32(lhs), Self::F32(rhs)) => {
|
||||
let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f32);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(lhs), Self::F64(rhs)) => {
|
||||
let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f64);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::U32(lhs), Self::U32(rhs)) => {
|
||||
let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::u32);
|
||||
let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
_ => {
|
||||
@ -622,20 +625,20 @@ impl CpuStorage {
|
||||
&self,
|
||||
rhs: &Self,
|
||||
bmnk: (usize, usize, usize, usize),
|
||||
lhs_layout: &Layout,
|
||||
rhs_layout: &Layout,
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
match (self, rhs) {
|
||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
||||
Ok(Self::F16(dst))
|
||||
}
|
||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
||||
Ok(Self::F32(dst))
|
||||
}
|
||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?;
|
||||
let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
|
||||
Ok(Self::F64(dst))
|
||||
}
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
|
Reference in New Issue
Block a user