Remove the dependency to blas and use mkl directly. (#125)

This commit is contained in:
Laurent Mazare
2023-07-10 15:52:03 +01:00
committed by GitHub
parent 221b1aff65
commit 548b1df7ea
4 changed files with 190 additions and 4 deletions

View File

@ -416,6 +416,36 @@ impl Map2 for MatMul {
let mut dst = vec![T::zero(); b * m * n];
match T::DTYPE {
DType::F16 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
let a = rhs_p.as_ptr() as *const f16;
let b = lhs_p.as_ptr() as *const f16;
let c = dst_p.as_mut_ptr() as *mut f16;
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
crate::mkl::hgemm(
transa,
transb,
/* m= */ n as i32,
/* n= */ m as i32,
/* k= */ k as i32,
/* alpha= */ f16::ONE,
/* a= */ a,
/* lda= */ lda,
/* b= */ b,
/* ldb= */ ldb,
/* beta= */ f16::ZERO,
/* c= */ c,
/* ldc= */ n as i32,
)
}
}
}
DType::F32 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
@ -428,7 +458,7 @@ impl Map2 for MatMul {
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
blas::sgemm(
crate::mkl::sgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
@ -449,7 +479,7 @@ impl Map2 for MatMul {
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
blas::dgemm(
crate::mkl::dgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,