From 221b1aff6594acd6d030c5131dba388590d1917f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 10 Jul 2023 15:02:37 +0100 Subject: [PATCH] Support dgemm in mkl matmul. (#122) --- candle-core/src/cpu_backend.rs | 64 ++++++++++++++++++++++++---------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 15982040..dd9dabc1 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -265,7 +265,7 @@ impl Map2 for MatMul { const OP: &'static str = "mat_mul"; #[cfg(not(feature = "mkl"))] - fn f( + fn f( &self, lhs: &[T], lhs_l: &Layout, @@ -350,7 +350,7 @@ impl Map2 for MatMul { } #[cfg(feature = "mkl")] - fn f( + fn f( &self, lhs: &[T], lhs_l: &Layout, @@ -415,24 +415,50 @@ impl Map2 for MatMul { }; let mut dst = vec![T::zero(); b * m * n]; - for step in 0..b { - let lhs_p = &lhs[step * a_skip..]; - let rhs_p = &rhs[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; - unsafe { - let a = rhs_p.as_ptr() as *const f32; - let b = lhs_p.as_ptr() as *const f32; - let c = dst_p.as_mut_ptr() as *mut f32; - let a = std::slice::from_raw_parts(a, a_skip); - let b = std::slice::from_raw_parts(b, b_skip); - let c = std::slice::from_raw_parts_mut(c, c_skip); - blas::sgemm( - transa, transb, /* m= */ n as i32, /* n= */ m as i32, - /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, - /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0., - /* c= */ c, /* ldc= */ n as i32, - ) + match T::DTYPE { + DType::F32 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f32; + let b = lhs_p.as_ptr() as *const f32; + let c = dst_p.as_mut_ptr() as *mut f32; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + blas::sgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } } + DType::F64 => { + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + let a = rhs_p.as_ptr() as *const f64; + let b = lhs_p.as_ptr() as *const f64; + let c = dst_p.as_mut_ptr() as *mut f64; + let a = std::slice::from_raw_parts(a, a_skip); + let b = std::slice::from_raw_parts(b, b_skip); + let c = std::slice::from_raw_parts_mut(c, c_skip); + blas::dgemm( + transa, transb, /* m= */ n as i32, /* n= */ m as i32, + /* k= */ k as i32, /* alpha= */ 1., /* a= */ a, + /* lda= */ lda, /* b= */ b, /* ldb= */ ldb, + /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32, + ) + } + } + } + dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?, } Ok(dst) }