mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Remove the dependency to blas and use mkl directly. (#125)
This commit is contained in:
@ -416,6 +416,36 @@ impl Map2 for MatMul {
|
||||
|
||||
let mut dst = vec![T::zero(); b * m * n];
|
||||
match T::DTYPE {
|
||||
DType::F16 => {
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
let rhs_p = &rhs[step * b_skip..];
|
||||
let dst_p = &mut dst[step * c_skip..];
|
||||
unsafe {
|
||||
let a = rhs_p.as_ptr() as *const f16;
|
||||
let b = lhs_p.as_ptr() as *const f16;
|
||||
let c = dst_p.as_mut_ptr() as *mut f16;
|
||||
let a = std::slice::from_raw_parts(a, a_skip);
|
||||
let b = std::slice::from_raw_parts(b, b_skip);
|
||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||
crate::mkl::hgemm(
|
||||
transa,
|
||||
transb,
|
||||
/* m= */ n as i32,
|
||||
/* n= */ m as i32,
|
||||
/* k= */ k as i32,
|
||||
/* alpha= */ f16::ONE,
|
||||
/* a= */ a,
|
||||
/* lda= */ lda,
|
||||
/* b= */ b,
|
||||
/* ldb= */ ldb,
|
||||
/* beta= */ f16::ZERO,
|
||||
/* c= */ c,
|
||||
/* ldc= */ n as i32,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
DType::F32 => {
|
||||
for step in 0..b {
|
||||
let lhs_p = &lhs[step * a_skip..];
|
||||
@ -428,7 +458,7 @@ impl Map2 for MatMul {
|
||||
let a = std::slice::from_raw_parts(a, a_skip);
|
||||
let b = std::slice::from_raw_parts(b, b_skip);
|
||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||
blas::sgemm(
|
||||
crate::mkl::sgemm(
|
||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||
@ -449,7 +479,7 @@ impl Map2 for MatMul {
|
||||
let a = std::slice::from_raw_parts(a, a_skip);
|
||||
let b = std::slice::from_raw_parts(b, b_skip);
|
||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||
blas::dgemm(
|
||||
crate::mkl::dgemm(
|
||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||
|
Reference in New Issue
Block a user