Add mkl support for matrix multiply. (#86)

* Fix some rebase issues.

* Use mkl instead.

* Use mkl in bert.

* Add the optional mkl feature.

* Conditional compilation based on the mkl feature.

* Add more mkl support.
Laurent Mazare, 2023-07-06 11:05:05 +01:00, committed by GitHub
parent cd230d26fe, commit c297a50960
9 changed files with 118 additions and 3 deletions
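As context for the "optional mkl feature" and "conditional compilation" points in the commit message, here is a minimal sketch of the Cargo feature-gating pattern the commit relies on. The function name is illustrative, not taken from the commit; exactly one of the two bodies is compiled depending on whether the crate is built with `--features mkl`:

#[cfg(not(feature = "mkl"))]
fn matmul_backend() -> &'static str {
    // Default build: the pure-Rust `gemm` crate performs the matrix multiply.
    "gemm"
}

#[cfg(feature = "mkl")]
fn matmul_backend() -> &'static str {
    // Built with `--features mkl`: calls are routed to MKL through the BLAS
    // interface instead.
    "mkl"
}

fn main() {
    println!("matmul backend: {}", matmul_backend());
}

This mirrors the structure of the diff below, where `MatMul::f` has one body under `#[cfg(not(feature = "mkl"))]` and another under `#[cfg(feature = "mkl")]`.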

@@ -1,6 +1,5 @@
use crate::op::{BinaryOp, UnaryOp};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use gemm::{gemm, Parallelism};
use half::{bf16, f16};
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -264,6 +263,8 @@ struct MatMul((usize, usize, usize, usize));
impl Map2 for MatMul {
    const OP: &'static str = "mat_mul";
    #[cfg(not(feature = "mkl"))]
    fn f<T: 'static + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
@@ -271,6 +272,7 @@ impl Map2 for MatMul {
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        use gemm::{gemm, Parallelism};
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];
@@ -346,6 +348,98 @@ impl Map2 for MatMul {
        }
        Ok(dst)
    }
    #[cfg(feature = "mkl")]
    fn f<T: 'static + num_traits::Num + Copy>(
        &self,
        lhs: &[T],
        lhs_l: &Layout,
        rhs: &[T],
        rhs_l: &Layout,
    ) -> Result<Vec<T>> {
        let (b, m, n, k) = self.0;
        let lhs = &lhs[lhs_l.start_offset()..];
        let rhs = &rhs[rhs_l.start_offset()..];
        // Reuse the cublas parameter computation to obtain the transpose flags
        // and leading dimensions for the BLAS call below.
        let cfg = crate::cuda_backend::gemm_config(1f32, 0f32, (b, m, n, k), lhs_l, rhs_l)?;
        let lhs_stride = lhs_l.stride();
        let rhs_stride = rhs_l.stride();
        let rank = lhs_stride.len();
        // Per-matrix stride (batch skip) for lhs: two contiguous batch
        // dimensions collapsed into one, a single batch dimension, or none.
        let a_skip: usize = match lhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => m * k,
            _ => Err(Error::UnexpectedStriding {
                lhs_stride: lhs_stride.to_vec(),
                rhs_stride: rhs_stride.to_vec(),
            })?,
        };
        let b_skip: usize = match rhs_stride[..rank - 2] {
            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
            [stride] => stride,
            [] => n * k,
            _ => Err(Error::UnexpectedStriding {
                lhs_stride: lhs_stride.to_vec(),
                rhs_stride: rhs_stride.to_vec(),
            })?,
        };
        let c_skip: usize = m * n;
        let mut dst = vec![T::zero(); b * m * n];
        for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
            let dst_p = &mut dst[step * c_skip..];
            unsafe {
                let gemm = cfg.gemm;
                // BLAS is column-major while these buffers are row-major, so
                // the product is computed as rhs^T * lhs^T: rhs goes in as the
                // `a` operand and lhs as `b` (see the note after this diff).
                let a = rhs_p.as_ptr() as *const f32;
                let b = lhs_p.as_ptr() as *const f32;
                let c = dst_p.as_mut_ptr() as *mut f32;
                // `a` holds the rhs data so its length is the rhs batch skip,
                // and symmetrically for `b`.
                let a = std::slice::from_raw_parts(a, b_skip);
                let b = std::slice::from_raw_parts(b, a_skip);
                let c = std::slice::from_raw_parts_mut(c, c_skip);
                let transa = match gemm.transa {
                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
                    _ => todo!(),
                };
                let transb = match gemm.transb {
                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_N => b'N',
                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_T => b'T',
                    cudarc::cublas::sys::cublasOperation_t::CUBLAS_OP_C => b'C',
                    _ => todo!(),
                };
                blas::sgemm(
                    transa, transb, gemm.m, gemm.n, gemm.k, gemm.alpha, a, gemm.lda, b,
                    gemm.ldb, gemm.beta, c, gemm.ldc,
                );
                // gemm(
                //     /* m: usize = */ m,
                //     /* n: usize = */ n,
                //     /* k: usize = */ k,
                //     /* dst: *mut T = */ dst_p.as_mut_ptr(),
                //     /* dst_cs: isize = */ dst_cs as isize,
                //     /* dst_rs: isize = */ dst_rs as isize,
                //     /* read_dst: bool = */ false,
                //     /* lhs: *const T = */ lhs_p.as_ptr(),
                //     /* lhs_cs: isize = */ lhs_cs as isize,
                //     /* lhs_rs: isize = */ lhs_rs as isize,
                //     /* rhs: *const T = */ rhs_p.as_ptr(),
                //     /* rhs_cs: isize = */ rhs_cs as isize,
                //     /* rhs_rs: isize = */ rhs_rs as isize,
                //     /* alpha: T = */ T::zero(),
                //     /* beta: T = */ T::one(),
                //     /* conj_dst: bool = */ false,
                //     /* conj_lhs: bool = */ false,
                //     /* conj_rhs: bool = */ false,
                //     parallelism,
                // )
            }
        }
        Ok(dst)
    }
}
fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize) -> Result<()> {
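A note on the operand order in the sgemm call above: BLAS routines are column-major, while these tensors are row-major. A row-major buffer for an m x k matrix A is exactly the column-major buffer for its transpose A^T, so the row-major product C = A * B can be obtained as the column-major product C^T = B^T * A^T, which is why rhs is passed as the first operand. Below is a self-contained sketch of the same trick against the `blas` crate's sgemm (the crate is the one the diff calls into; the function name here is illustrative, and a BLAS provider such as MKL must be linked for it to build):

// Row-major C (m x n) = A (m x k) * B (k x n), computed by a column-major
// BLAS as C^T (n x m) = B^T (n x k) * A^T (k x m).
fn row_major_sgemm(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = vec![0f32; m * n];
    unsafe {
        blas::sgemm(
            b'N', b'N',
            n as i32, m as i32, k as i32, // swapped dims: rows of C^T come first
            1.0,
            b, n as i32, // B^T with leading dimension n
            a, k as i32, // A^T with leading dimension k
            0.0,
            &mut c, n as i32, // C^T with leading dimension n
        );
    }
    c
}

fn main() {
    // [[1, 2], [3, 4]] * [[5, 6], [7, 8]] = [[19, 22], [43, 50]]
    let a = [1f32, 2., 3., 4.];
    let b = [5f32, 6., 7., 8.];
    assert_eq!(row_major_sgemm(&a, &b, 2, 2, 2), vec![19., 22., 43., 50.]);
    println!("ok");
}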