Support the Accelerate BLAS on macOS. (#325)

* Add the accelerate feature.

* FFI tweaks.
This commit is contained in:
Laurent Mazare
2023-08-05 17:25:24 +01:00
committed by GitHub
parent 0b175fcbbd
commit b278834267
12 changed files with 241 additions and 9 deletions

View File

@ -974,7 +974,7 @@ impl MatMul {
impl Map2 for MatMul {
const OP: &'static str = "mat_mul";
#[cfg(not(feature = "mkl"))]
#[cfg(all(not(feature = "mkl"), not(feature = "accelerate")))]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,
lhs: &[T],
@ -1053,6 +1053,109 @@ impl Map2 for MatMul {
Ok(dst)
}
/// Batched matrix multiplication dispatched to the Accelerate `sgemm` / `dgemm` wrappers.
///
/// `self.0` is read as `(b, m, n, k)`: `b` batches, with `lhs` laid out as `m x k` and
/// `rhs` as `k x n` per batch (as implied by the stride checks below). The returned
/// vector holds `b * m * n` elements, one contiguous `m * n` result per batch.
///
/// # Errors
///
/// * A striding error (built by `self.striding_error`) when the batch dimensions or the
///   two innermost dimensions of either operand are not in one of the accepted layouts.
/// * An error for `DType::F16`, which this backend does not handle.
/// * `Error::UnsupportedDTypeForOp` for every dtype other than `F32` / `F64`.
///
/// # Panics
///
/// Panics if the layouts have fewer than two dimensions, or if a start offset / batch
/// offset lies beyond the end of the corresponding input slice.
#[cfg(feature = "accelerate")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
    &self,
    lhs: &[T],
    lhs_l: &Layout,
    rhs: &[T],
    rhs_l: &Layout,
) -> Result<Vec<T>> {
    let (b, m, n, k) = self.0;
    let lhs = &lhs[lhs_l.start_offset()..];
    let rhs = &rhs[rhs_l.start_offset()..];

    let lhs_stride = lhs_l.stride();
    let rhs_stride = rhs_l.stride();
    let rank = lhs_stride.len();

    // Number of elements separating two consecutive batches of the lhs. Up to two
    // leading batch dimensions are accepted, provided they collapse into one stride.
    let a_skip: usize = match lhs_stride[..rank - 2] {
        [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
        [stride] => stride,
        [] => m * k,
        _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?,
    };
    // Same for the rhs.
    let b_skip: usize = match rhs_stride[..rank - 2] {
        [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
        [stride] => stride,
        [] => n * k,
        _ => Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?,
    };
    let c_skip: usize = m * n;

    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];

    // The operands are handed to gemm swapped: the rhs is passed as `a` and the lhs as
    // `b`, with `m` and `n` exchanged (presumably because BLAS works on column-major
    // data while these buffers are row-major -- confirm against the BLAS reference).
    // NOTE(review): the `as i32` casts below truncate silently for very large dims.
    let (lda, transa) = if rhs_m1 == 1 && rhs_m2 == n {
        (n as i32, b'N')
    } else if rhs_m1 == k && rhs_m2 == 1 {
        (k as i32, b'T')
    } else {
        Err(self.striding_error(lhs_l, rhs_l, "non-contiguous rhs"))?
    };
    // The gemm `b` operand is the lhs, with dims (batching, m, k).
    let (ldb, transb) = if lhs_m1 == 1 && lhs_m2 == k {
        (k as i32, b'N')
    } else if lhs_m1 == m && lhs_m2 == 1 {
        (m as i32, b'T')
    } else {
        Err(self.striding_error(lhs_l, rhs_l, "non-contiguous lhs"))?
    };

    let mut dst = vec![T::zero(); b * m * n];
    match T::DTYPE {
        DType::F16 => {
            crate::bail!("the accelerate backend does not support f16 matmul")
        }
        DType::F32 => {
            for step in 0..b {
                let lhs_p = &lhs[step * a_skip..];
                let rhs_p = &rhs[step * b_skip..];
                let dst_p = &mut dst[step * c_skip..];
                // SAFETY: this arm is only reached when `T::DTYPE` is `F32`, i.e. the
                // elements are `f32` (assuming `DTYPE` identifies the element type).
                // Each reinterpreted slice spans exactly the memory of the slice its
                // pointer was taken from (`rhs_p`, `lhs_p`) or a prefix of it (`dst_p`
                // holds at least `c_skip` elements because `step < b`), so no slice
                // reaches past its backing allocation.
                unsafe {
                    let a = rhs_p.as_ptr() as *const f32;
                    let b = lhs_p.as_ptr() as *const f32;
                    let c = dst_p.as_mut_ptr() as *mut f32;
                    let a = std::slice::from_raw_parts(a, rhs_p.len());
                    let b = std::slice::from_raw_parts(b, lhs_p.len());
                    let c = std::slice::from_raw_parts_mut(c, c_skip);
                    crate::accelerate::sgemm(
                        transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                        /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                        /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                        /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                    )
                }
            }
        }
        DType::F64 => {
            for step in 0..b {
                let lhs_p = &lhs[step * a_skip..];
                let rhs_p = &rhs[step * b_skip..];
                let dst_p = &mut dst[step * c_skip..];
                // SAFETY: same reasoning as the `F32` arm, with `f64` elements.
                unsafe {
                    let a = rhs_p.as_ptr() as *const f64;
                    let b = lhs_p.as_ptr() as *const f64;
                    let c = dst_p.as_mut_ptr() as *mut f64;
                    let a = std::slice::from_raw_parts(a, rhs_p.len());
                    let b = std::slice::from_raw_parts(b, lhs_p.len());
                    let c = std::slice::from_raw_parts_mut(c, c_skip);
                    crate::accelerate::dgemm(
                        transa, transb, /* m= */ n as i32, /* n= */ m as i32,
                        /* k= */ k as i32, /* alpha= */ 1., /* a= */ a,
                        /* lda= */ lda, /* b= */ b, /* ldb= */ ldb,
                        /* beta= */ 0., /* c= */ c, /* ldc= */ n as i32,
                    )
                }
            }
        }
        dtype => Err(Error::UnsupportedDTypeForOp(dtype, "matmul").bt())?,
    }
    Ok(dst)
}
#[cfg(feature = "mkl")]
fn f<T: 'static + WithDType + num_traits::Num + Copy>(
&self,