Support dgemm in mkl matmul. (#122)

This commit is contained in:
Laurent Mazare
2023-07-10 15:02:37 +01:00
committed by GitHub
parent 71cd3745a9
commit 221b1aff65

View File

@@ -265,7 +265,7 @@ impl Map2 for MatMul {
const OP: &'static str = "mat_mul"; const OP: &'static str = "mat_mul";
#[cfg(not(feature = "mkl"))] #[cfg(not(feature = "mkl"))]
fn f<T: 'static + num_traits::Num + Copy>( fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self, &self,
lhs: &[T], lhs: &[T],
lhs_l: &Layout, lhs_l: &Layout,
@@ -350,7 +350,7 @@ impl Map2 for MatMul {
} }
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
fn f<T: 'static + num_traits::Num + Copy>( fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self, &self,
lhs: &[T], lhs: &[T],
lhs_l: &Layout, lhs_l: &Layout,
@@ -415,24 +415,50 @@ impl Map2 for MatMul {
}; };
let mut dst = vec![T::zero(); b * m * n]; let mut dst = vec![T::zero(); b * m * n];
for step in 0..b { match T::DTYPE {
let lhs_p = &lhs[step * a_skip..]; DType::F32 => {
let rhs_p = &rhs[step * b_skip..]; for step in 0..b {
let dst_p = &mut dst[step * c_skip..]; let lhs_p = &lhs[step * a_skip..];
unsafe { let rhs_p = &rhs[step * b_skip..];
let a = rhs_p.as_ptr() as *const f32; let dst_p = &mut dst[step * c_skip..];
let b = lhs_p.as_ptr() as *const f32; unsafe {
let c = dst_p.as_mut_ptr() as *mut f32; let a = rhs_p.as_ptr() as *const f32;
let a = std::slice::from_raw_parts(a, a_skip); let b = lhs_p.as_ptr() as *const f32;
let b = std::slice::from_raw_parts(b, b_skip); let c = dst_p.as_mut_ptr() as *mut f32;
let c = std::slice::from_raw_parts_mut(c, c_skip); let a = std::slice::from_raw_parts(a, a_skip);
blas::sgemm( let b = std::slice::from_raw_parts(b, b_skip);
transa, transb, /* m= */ n as i32, /* n= */ m as i32, let c = std::slice::from_raw_parts_mut(c, c_skip);
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a, blas::sgemm(
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0., transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* c= */ c, /* ldc= */ n as i32, /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
) /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
)
}
}
} }
DType::F64 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
let a = rhs_p.as_ptr() as *const f64;
let b = lhs_p.as_ptr() as *const f64;
let c = dst_p.as_mut_ptr() as *mut f64;
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
blas::dgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
)
}
}
}
dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
} }
Ok(dst) Ok(dst)
} }