mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Factor out the gemm bits.
This commit is contained in:
@ -158,6 +158,71 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
}
|
||||
}
|
||||
|
||||
fn matmul_impl<T: 'static + num_traits::Num + Copy>(
|
||||
lhs: &[T],
|
||||
rhs: &[T],
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Vec<T>> {
|
||||
let a_skip: usize = m * k;
|
||||
let b_skip: usize = n * k;
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let rank = lhs_stride.len();
|
||||
let lhs_cs = lhs_stride[rank - 1];
|
||||
let lhs_rs = lhs_stride[rank - 2];
|
||||
|
||||
let rhs_cs = rhs_stride[rank - 1];
|
||||
let rhs_rs = rhs_stride[rank - 2];
|
||||
|
||||
if lhs_stride.len() > 2 {
|
||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||
|
||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||
// Temporary error before we support abitrary striding.
|
||||
return Err(Error::UnexpectedStriding);
|
||||
}
|
||||
}
|
||||
|
||||
let dst_shape: Shape = (m, n).into();
|
||||
let dst_strides = dst_shape.stride_contiguous();
|
||||
let dst_rs = dst_strides[0];
|
||||
let dst_cs = dst_strides[1];
|
||||
|
||||
let mut dst = vec![T::zero(); b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
/* m: usize = */ m,
|
||||
/* n: usize = */ n,
|
||||
/* k: usize = */ k,
|
||||
/* dst: *mut T = */ dst_p.as_mut_ptr(),
|
||||
/* dst_cs: isize = */ dst_cs as isize,
|
||||
/* dst_rs: isize = */ dst_rs as isize,
|
||||
/* read_dst: bool = */ false,
|
||||
/* lhs: *const T = */ lhs_p.as_ptr(),
|
||||
/* lhs_cs: isize = */ lhs_cs as isize,
|
||||
/* lhs_rs: isize = */ lhs_rs as isize,
|
||||
/* rhs: *const T = */ rhs_p.as_ptr(),
|
||||
/* rhs_cs: isize = */ rhs_cs as isize,
|
||||
/* rhs_rs: isize = */ rhs_rs as isize,
|
||||
/* alpha: T = */ T::zero(),
|
||||
/* beta: T = */ T::one(),
|
||||
/* conj_dst: bool = */ false,
|
||||
/* conj_lhs: bool = */ false,
|
||||
/* conj_rhs: bool = */ false,
|
||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||
)
|
||||
}
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
impl CpuStorage {
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self {
|
||||
@ -593,199 +658,28 @@ impl CpuStorage {
|
||||
pub(crate) fn matmul_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
bmnk: (usize, usize, usize, usize),
|
||||
lhs_stride: &[usize],
|
||||
rhs_stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
let a_skip: usize = m * k;
|
||||
let b_skip: usize = n * k;
|
||||
let c_skip: usize = m * n;
|
||||
|
||||
let rank = lhs_stride.len();
|
||||
let lhs_cs = lhs_stride[rank - 1];
|
||||
let lhs_rs = lhs_stride[rank - 2];
|
||||
|
||||
let rhs_cs = rhs_stride[rank - 1];
|
||||
let rhs_rs = rhs_stride[rank - 2];
|
||||
|
||||
if lhs_stride.len() > 2 {
|
||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||
|
||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||
// Temporary error before we support abitrary striding.
|
||||
return Err(Error::UnexpectedStriding);
|
||||
}
|
||||
}
|
||||
|
||||
let dst_shape: Shape = (m, n).into();
|
||||
let dst_strides = dst_shape.stride_contiguous();
|
||||
let dst_rs = dst_strides[0];
|
||||
let dst_cs = dst_strides[1];
|
||||
|
||||
match (self, rhs) {
|
||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||
let mut dst = vec![f16::ZERO; b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
// m: usize,
|
||||
m,
|
||||
// n: usize,
|
||||
n,
|
||||
// k: usize,
|
||||
k,
|
||||
// dst: *mut T,
|
||||
dst_p.as_mut_ptr(),
|
||||
// dst_cs: isize,
|
||||
dst_cs as isize,
|
||||
// dst_rs: isize,
|
||||
dst_rs as isize,
|
||||
// read_dst: bool,
|
||||
false,
|
||||
// lhs: *const T,
|
||||
lhs_p.as_ptr(),
|
||||
// lhs_cs: isize,
|
||||
lhs_cs as isize,
|
||||
// lhs_rs: isize,
|
||||
lhs_rs as isize,
|
||||
// rhs: *const T,
|
||||
rhs_p.as_ptr(),
|
||||
// rhs_cs: isize,
|
||||
rhs_cs as isize,
|
||||
// rhs_rs: isize,
|
||||
rhs_rs as isize,
|
||||
// alpha: T,
|
||||
f16::ONE,
|
||||
// beta: T,
|
||||
f16::ONE,
|
||||
// conj_dst: bool,
|
||||
false,
|
||||
// conj_lhs: bool,
|
||||
false,
|
||||
// conj_rhs: bool,
|
||||
true,
|
||||
// parallelism: Parallelism
|
||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::F16(dst))
|
||||
}
|
||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||
let mut dst = vec![0f32; b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
// m: usize,
|
||||
m,
|
||||
// n: usize,
|
||||
n,
|
||||
// k: usize,
|
||||
k,
|
||||
// dst: *mut T,
|
||||
dst_p.as_mut_ptr(),
|
||||
// dst_cs: isize,
|
||||
dst_cs as isize,
|
||||
// dst_rs: isize,
|
||||
dst_rs as isize,
|
||||
// read_dst: bool,
|
||||
false,
|
||||
// lhs: *const T,
|
||||
lhs_p.as_ptr(),
|
||||
// lhs_cs: isize,
|
||||
lhs_cs as isize,
|
||||
// lhs_rs: isize,
|
||||
lhs_rs as isize,
|
||||
// rhs: *const T,
|
||||
rhs_p.as_ptr(),
|
||||
// rhs_cs: isize,
|
||||
rhs_cs as isize,
|
||||
// rhs_rs: isize,
|
||||
rhs_rs as isize,
|
||||
// alpha: T,
|
||||
1f32,
|
||||
// beta: T,
|
||||
1f32,
|
||||
// conj_dst: bool,
|
||||
false,
|
||||
// conj_lhs: bool,
|
||||
false,
|
||||
// conj_rhs: bool,
|
||||
true,
|
||||
// parallelism: Parallelism
|
||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||
)
|
||||
}
|
||||
}
|
||||
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::F32(dst))
|
||||
}
|
||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||
let mut dst = vec![0f64; b * m * n];
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
gemm(
|
||||
// m: usize,
|
||||
m,
|
||||
// n: usize,
|
||||
n,
|
||||
// k: usize,
|
||||
k,
|
||||
// dst: *mut T,
|
||||
dst_p.as_mut_ptr(),
|
||||
// dst_cs: isize,
|
||||
dst_cs as isize,
|
||||
// dst_rs: isize,
|
||||
dst_rs as isize,
|
||||
// read_dst: bool,
|
||||
false,
|
||||
// lhs: *const T,
|
||||
lhs_p.as_ptr(),
|
||||
// lhs_cs: isize,
|
||||
lhs_cs as isize,
|
||||
// lhs_rs: isize,
|
||||
lhs_rs as isize,
|
||||
// rhs: *const T,
|
||||
rhs_p.as_ptr(),
|
||||
// rhs_cs: isize,
|
||||
rhs_cs as isize,
|
||||
// rhs_rs: isize,
|
||||
rhs_rs as isize,
|
||||
// alpha: T,
|
||||
1f64,
|
||||
// beta: T,
|
||||
1f64,
|
||||
// conj_dst: bool,
|
||||
false,
|
||||
// conj_lhs: bool,
|
||||
false,
|
||||
// conj_rhs: bool,
|
||||
true,
|
||||
// parallelism: Parallelism
|
||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||
)
|
||||
}
|
||||
}
|
||||
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::F64(dst))
|
||||
}
|
||||
_ => {
|
||||
// This should be covered by the dtype check above.
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: rhs.dtype(),
|
||||
op: "matmul",
|
||||
})
|
||||
}
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: rhs.dtype(),
|
||||
op: "matmul",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user