Hook the MLX matmul kernels in candle-core. (#2473)

This commit is contained in:
Laurent Mazare
2024-09-12 12:52:59 +01:00
committed by GitHub
parent 0cb0bd1dfa
commit 72d649058b
2 changed files with 38 additions and 0 deletions

View File

@ -70,6 +70,8 @@ pub struct MetalDevice {
pub(crate) buffers: AllocatedBuffers,
/// Seed for random number generation.
pub(crate) seed: Arc<Mutex<Buffer>>,
/// Whether to use the MLX matmul kernels instead of the MFA ones.
pub(crate) use_mlx_mm: bool,
}
impl std::fmt::Debug for MetalDevice {
@ -87,6 +89,10 @@ impl std::ops::Deref for MetalDevice {
}
impl MetalDevice {
/// Sets the flag recording whether the MLX matmul kernels should be used
/// instead of the MFA ones.
///
/// Only the `use_mlx_mm` field is updated; nothing else on the device is touched.
pub fn set_use_mlx_mm(&mut self, use_mlx_mm: bool) {
    self.use_mlx_mm = use_mlx_mm;
}
/// Returns the [`DeviceId`] stored on this device.
pub fn id(&self) -> DeviceId {
self.id
}