mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Factor out the gemm bits.
This commit is contained in:
@ -158,6 +158,71 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matmul_impl<T: 'static + num_traits::Num + Copy>(
|
||||||
|
lhs: &[T],
|
||||||
|
rhs: &[T],
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
) -> Result<Vec<T>> {
|
||||||
|
let a_skip: usize = m * k;
|
||||||
|
let b_skip: usize = n * k;
|
||||||
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
|
let rank = lhs_stride.len();
|
||||||
|
let lhs_cs = lhs_stride[rank - 1];
|
||||||
|
let lhs_rs = lhs_stride[rank - 2];
|
||||||
|
|
||||||
|
let rhs_cs = rhs_stride[rank - 1];
|
||||||
|
let rhs_rs = rhs_stride[rank - 2];
|
||||||
|
|
||||||
|
if lhs_stride.len() > 2 {
|
||||||
|
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||||
|
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||||
|
|
||||||
|
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||||
|
// Temporary error before we support abitrary striding.
|
||||||
|
return Err(Error::UnexpectedStriding);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dst_shape: Shape = (m, n).into();
|
||||||
|
let dst_strides = dst_shape.stride_contiguous();
|
||||||
|
let dst_rs = dst_strides[0];
|
||||||
|
let dst_cs = dst_strides[1];
|
||||||
|
|
||||||
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
|
unsafe {
|
||||||
|
gemm(
|
||||||
|
/* m: usize = */ m,
|
||||||
|
/* n: usize = */ n,
|
||||||
|
/* k: usize = */ k,
|
||||||
|
/* dst: *mut T = */ dst_p.as_mut_ptr(),
|
||||||
|
/* dst_cs: isize = */ dst_cs as isize,
|
||||||
|
/* dst_rs: isize = */ dst_rs as isize,
|
||||||
|
/* read_dst: bool = */ false,
|
||||||
|
/* lhs: *const T = */ lhs_p.as_ptr(),
|
||||||
|
/* lhs_cs: isize = */ lhs_cs as isize,
|
||||||
|
/* lhs_rs: isize = */ lhs_rs as isize,
|
||||||
|
/* rhs: *const T = */ rhs_p.as_ptr(),
|
||||||
|
/* rhs_cs: isize = */ rhs_cs as isize,
|
||||||
|
/* rhs_rs: isize = */ rhs_rs as isize,
|
||||||
|
/* alpha: T = */ T::zero(),
|
||||||
|
/* beta: T = */ T::one(),
|
||||||
|
/* conj_dst: bool = */ false,
|
||||||
|
/* conj_lhs: bool = */ false,
|
||||||
|
/* conj_rhs: bool = */ false,
|
||||||
|
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
|
||||||
impl CpuStorage {
|
impl CpuStorage {
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
@ -593,199 +658,28 @@ impl CpuStorage {
|
|||||||
pub(crate) fn matmul_impl(
|
pub(crate) fn matmul_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
bmnk: (usize, usize, usize, usize),
|
||||||
lhs_stride: &[usize],
|
lhs_stride: &[usize],
|
||||||
rhs_stride: &[usize],
|
rhs_stride: &[usize],
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let a_skip: usize = m * k;
|
|
||||||
let b_skip: usize = n * k;
|
|
||||||
let c_skip: usize = m * n;
|
|
||||||
|
|
||||||
let rank = lhs_stride.len();
|
|
||||||
let lhs_cs = lhs_stride[rank - 1];
|
|
||||||
let lhs_rs = lhs_stride[rank - 2];
|
|
||||||
|
|
||||||
let rhs_cs = rhs_stride[rank - 1];
|
|
||||||
let rhs_rs = rhs_stride[rank - 2];
|
|
||||||
|
|
||||||
if lhs_stride.len() > 2 {
|
|
||||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
|
||||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
|
||||||
|
|
||||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
|
||||||
// Temporary error before we support abitrary striding.
|
|
||||||
return Err(Error::UnexpectedStriding);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
|
||||||
let dst_strides = dst_shape.stride_contiguous();
|
|
||||||
let dst_rs = dst_strides[0];
|
|
||||||
let dst_cs = dst_strides[1];
|
|
||||||
|
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||||
let mut dst = vec![f16::ZERO; b * m * n];
|
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
unsafe {
|
|
||||||
gemm(
|
|
||||||
// m: usize,
|
|
||||||
m,
|
|
||||||
// n: usize,
|
|
||||||
n,
|
|
||||||
// k: usize,
|
|
||||||
k,
|
|
||||||
// dst: *mut T,
|
|
||||||
dst_p.as_mut_ptr(),
|
|
||||||
// dst_cs: isize,
|
|
||||||
dst_cs as isize,
|
|
||||||
// dst_rs: isize,
|
|
||||||
dst_rs as isize,
|
|
||||||
// read_dst: bool,
|
|
||||||
false,
|
|
||||||
// lhs: *const T,
|
|
||||||
lhs_p.as_ptr(),
|
|
||||||
// lhs_cs: isize,
|
|
||||||
lhs_cs as isize,
|
|
||||||
// lhs_rs: isize,
|
|
||||||
lhs_rs as isize,
|
|
||||||
// rhs: *const T,
|
|
||||||
rhs_p.as_ptr(),
|
|
||||||
// rhs_cs: isize,
|
|
||||||
rhs_cs as isize,
|
|
||||||
// rhs_rs: isize,
|
|
||||||
rhs_rs as isize,
|
|
||||||
// alpha: T,
|
|
||||||
f16::ONE,
|
|
||||||
// beta: T,
|
|
||||||
f16::ONE,
|
|
||||||
// conj_dst: bool,
|
|
||||||
false,
|
|
||||||
// conj_lhs: bool,
|
|
||||||
false,
|
|
||||||
// conj_rhs: bool,
|
|
||||||
true,
|
|
||||||
// parallelism: Parallelism
|
|
||||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self::F16(dst))
|
Ok(Self::F16(dst))
|
||||||
}
|
}
|
||||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||||
let mut dst = vec![0f32; b * m * n];
|
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
unsafe {
|
|
||||||
gemm(
|
|
||||||
// m: usize,
|
|
||||||
m,
|
|
||||||
// n: usize,
|
|
||||||
n,
|
|
||||||
// k: usize,
|
|
||||||
k,
|
|
||||||
// dst: *mut T,
|
|
||||||
dst_p.as_mut_ptr(),
|
|
||||||
// dst_cs: isize,
|
|
||||||
dst_cs as isize,
|
|
||||||
// dst_rs: isize,
|
|
||||||
dst_rs as isize,
|
|
||||||
// read_dst: bool,
|
|
||||||
false,
|
|
||||||
// lhs: *const T,
|
|
||||||
lhs_p.as_ptr(),
|
|
||||||
// lhs_cs: isize,
|
|
||||||
lhs_cs as isize,
|
|
||||||
// lhs_rs: isize,
|
|
||||||
lhs_rs as isize,
|
|
||||||
// rhs: *const T,
|
|
||||||
rhs_p.as_ptr(),
|
|
||||||
// rhs_cs: isize,
|
|
||||||
rhs_cs as isize,
|
|
||||||
// rhs_rs: isize,
|
|
||||||
rhs_rs as isize,
|
|
||||||
// alpha: T,
|
|
||||||
1f32,
|
|
||||||
// beta: T,
|
|
||||||
1f32,
|
|
||||||
// conj_dst: bool,
|
|
||||||
false,
|
|
||||||
// conj_lhs: bool,
|
|
||||||
false,
|
|
||||||
// conj_rhs: bool,
|
|
||||||
true,
|
|
||||||
// parallelism: Parallelism
|
|
||||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Self::F32(dst))
|
Ok(Self::F32(dst))
|
||||||
}
|
}
|
||||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||||
let mut dst = vec![0f64; b * m * n];
|
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
unsafe {
|
|
||||||
gemm(
|
|
||||||
// m: usize,
|
|
||||||
m,
|
|
||||||
// n: usize,
|
|
||||||
n,
|
|
||||||
// k: usize,
|
|
||||||
k,
|
|
||||||
// dst: *mut T,
|
|
||||||
dst_p.as_mut_ptr(),
|
|
||||||
// dst_cs: isize,
|
|
||||||
dst_cs as isize,
|
|
||||||
// dst_rs: isize,
|
|
||||||
dst_rs as isize,
|
|
||||||
// read_dst: bool,
|
|
||||||
false,
|
|
||||||
// lhs: *const T,
|
|
||||||
lhs_p.as_ptr(),
|
|
||||||
// lhs_cs: isize,
|
|
||||||
lhs_cs as isize,
|
|
||||||
// lhs_rs: isize,
|
|
||||||
lhs_rs as isize,
|
|
||||||
// rhs: *const T,
|
|
||||||
rhs_p.as_ptr(),
|
|
||||||
// rhs_cs: isize,
|
|
||||||
rhs_cs as isize,
|
|
||||||
// rhs_rs: isize,
|
|
||||||
rhs_rs as isize,
|
|
||||||
// alpha: T,
|
|
||||||
1f64,
|
|
||||||
// beta: T,
|
|
||||||
1f64,
|
|
||||||
// conj_dst: bool,
|
|
||||||
false,
|
|
||||||
// conj_lhs: bool,
|
|
||||||
false,
|
|
||||||
// conj_rhs: bool,
|
|
||||||
true,
|
|
||||||
// parallelism: Parallelism
|
|
||||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Self::F64(dst))
|
Ok(Self::F64(dst))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
// This should be covered by the dtype check above.
|
lhs: self.dtype(),
|
||||||
Err(Error::DTypeMismatchBinaryOp {
|
rhs: rhs.dtype(),
|
||||||
lhs: self.dtype(),
|
op: "matmul",
|
||||||
rhs: rhs.dtype(),
|
}),
|
||||||
op: "matmul",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user