mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Support dgemm in mkl matmul. (#122)
This commit is contained in:
@ -265,7 +265,7 @@ impl Map2 for MatMul {
|
|||||||
const OP: &'static str = "mat_mul";
|
const OP: &'static str = "mat_mul";
|
||||||
|
|
||||||
#[cfg(not(feature = "mkl"))]
|
#[cfg(not(feature = "mkl"))]
|
||||||
fn f<T: 'static + num_traits::Num + Copy>(
|
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
||||||
&self,
|
&self,
|
||||||
lhs: &[T],
|
lhs: &[T],
|
||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
@ -350,7 +350,7 @@ impl Map2 for MatMul {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
fn f<T: 'static + num_traits::Num + Copy>(
|
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
|
||||||
&self,
|
&self,
|
||||||
lhs: &[T],
|
lhs: &[T],
|
||||||
lhs_l: &Layout,
|
lhs_l: &Layout,
|
||||||
@ -415,24 +415,50 @@ impl Map2 for MatMul {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let mut dst = vec![T::zero(); b * m * n];
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
for step in 0..b {
|
match T::DTYPE {
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
DType::F32 => {
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
for step in 0..b {
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
unsafe {
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
let a = rhs_p.as_ptr() as *const f32;
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
let b = lhs_p.as_ptr() as *const f32;
|
unsafe {
|
||||||
let c = dst_p.as_mut_ptr() as *mut f32;
|
let a = rhs_p.as_ptr() as *const f32;
|
||||||
let a = std::slice::from_raw_parts(a, a_skip);
|
let b = lhs_p.as_ptr() as *const f32;
|
||||||
let b = std::slice::from_raw_parts(b, b_skip);
|
let c = dst_p.as_mut_ptr() as *mut f32;
|
||||||
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
let a = std::slice::from_raw_parts(a, a_skip);
|
||||||
blas::sgemm(
|
let b = std::slice::from_raw_parts(b, b_skip);
|
||||||
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||||
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
blas::sgemm(
|
||||||
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb, /* beta= */ 0.,
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||||
/* c= */ c, /* ldc= */ n as i32,
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||||
)
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||||
|
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
DType::F64 => {
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
|
unsafe {
|
||||||
|
let a = rhs_p.as_ptr() as *const f64;
|
||||||
|
let b = lhs_p.as_ptr() as *const f64;
|
||||||
|
let c = dst_p.as_mut_ptr() as *mut f64;
|
||||||
|
let a = std::slice::from_raw_parts(a, a_skip);
|
||||||
|
let b = std::slice::from_raw_parts(b, b_skip);
|
||||||
|
let c = std::slice::from_raw_parts_mut(c, c_skip);
|
||||||
|
blas::dgemm(
|
||||||
|
transa, transb, /* m= */ n as i32, /* n= */ m as i32,
|
||||||
|
/* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
|
||||||
|
/* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
|
||||||
|
/* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul"))?,
|
||||||
}
|
}
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user