Remove the dependency to blas and use mkl directly. (#125)

This commit is contained in:
Laurent Mazare
2023-07-10 15:52:03 +01:00
committed by GitHub
parent 221b1aff65
commit 548b1df7ea
4 changed files with 190 additions and 4 deletions

View File

@ -11,7 +11,6 @@ license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
blas = { version = "0.22.0", optional = true }
byteorder = "1.4.3"
candle-kernels = { path = "../candle-kernels", optional = true }
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
@ -22,6 +21,7 @@ cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vectorize-pack" }
half = { version = "2.3.1", features = ["num-traits"] }
intel-mkl-src = {version="0.8.1", optional=true, features = ["mkl-dynamic-lp64-iomp"]}
libc = { version = "0.2.147", optional = true }
memmap2 = "0.7.1"
num-traits = "0.2.15"
num_cpus = "1.15.0"
@ -35,4 +35,4 @@ anyhow = { version = "1", features = ["backtrace"] }
[features]
default = ["cuda"]
cuda = ["dep:cudarc", "dep:candle-kernels"]
mkl = ["dep:blas", "dep:intel-mkl-src"]
mkl = ["dep:libc", "dep:intel-mkl-src"]

View File

@ -416,6 +416,36 @@ impl Map2 for MatMul {
let mut dst = vec![T::zero(); b * m * n];
match T::DTYPE {
DType::F16 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
let dst_p = &mut dst[step * c_skip..];
unsafe {
let a = rhs_p.as_ptr() as *const f16;
let b = lhs_p.as_ptr() as *const f16;
let c = dst_p.as_mut_ptr() as *mut f16;
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
crate::mkl::hgemm(
transa,
transb,
/* m= */ n as i32,
/* n= */ m as i32,
/* k= */ k as i32,
/* alpha= */ f16::ONE,
/* a= */ a,
/* lda= */ lda,
/* b= */ b,
/* ldb= */ ldb,
/* beta= */ f16::ZERO,
/* c= */ c,
/* ldc= */ n as i32,
)
}
}
}
DType::F32 => {
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
@ -428,7 +458,7 @@ impl Map2 for MatMul {
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
blas::sgemm(
crate::mkl::sgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
@ -449,7 +479,7 @@ impl Map2 for MatMul {
let a = std::slice::from_raw_parts(a, a_skip);
let b = std::slice::from_raw_parts(b, b_skip);
let c = std::slice::from_raw_parts_mut(c, c_skip);
blas::dgemm(
crate::mkl::dgemm(
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,

View File

@ -44,6 +44,8 @@ mod dtype;
mod dummy_cuda_backend;
mod error;
mod layout;
#[cfg(feature = "mkl")]
mod mkl;
mod npy;
mod op;
pub mod safetensors;

154
candle-core/src/mkl.rs Normal file
View File

@ -0,0 +1,154 @@
use libc::{c_char, c_double, c_float, c_int};
mod ffi {
use super::*;
extern "C" {
pub fn sgemm_(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_float,
a: *const c_float,
lda: *const c_int,
b: *const c_float,
ldb: *const c_int,
beta: *const c_float,
c: *mut c_float,
ldc: *const c_int,
);
pub fn dgemm_(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const c_double,
a: *const c_double,
lda: *const c_int,
b: *const c_double,
ldb: *const c_int,
beta: *const c_double,
c: *mut c_double,
ldc: *const c_int,
);
pub fn hgemm_(
transa: *const c_char,
transb: *const c_char,
m: *const c_int,
n: *const c_int,
k: *const c_int,
alpha: *const half::f16,
a: *const half::f16,
lda: *const c_int,
b: *const half::f16,
ldb: *const c_int,
beta: *const half::f16,
c: *mut half::f16,
ldc: *const c_int,
);
}
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn sgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f32,
a: &[f32],
lda: i32,
b: &[f32],
ldb: i32,
beta: f32,
c: &mut [f32],
ldc: i32,
) {
ffi::sgemm_(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn dgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: f64,
a: &[f64],
lda: i32,
b: &[f64],
ldb: i32,
beta: f64,
c: &mut [f64],
ldc: i32,
) {
ffi::dgemm_(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}
#[allow(clippy::too_many_arguments)]
#[inline]
pub unsafe fn hgemm(
transa: u8,
transb: u8,
m: i32,
n: i32,
k: i32,
alpha: half::f16,
a: &[half::f16],
lda: i32,
b: &[half::f16],
ldb: i32,
beta: half::f16,
c: &mut [half::f16],
ldc: i32,
) {
ffi::hgemm_(
&(transa as c_char),
&(transb as c_char),
&m,
&n,
&k,
&alpha,
a.as_ptr(),
&lda,
b.as_ptr(),
&ldb,
&beta,
c.as_mut_ptr(),
&ldc,
)
}