diff --git a/Cargo.toml b/Cargo.toml index 1a8145ba..fd3ec2e1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +dispatch = "0.2.0" [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 71ec99c4..c5a071e2 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -30,6 +30,7 @@ safetensors = { workspace = true } thiserror = { workspace = true } yoke = { workspace = true } zip = { workspace = true } +dispatch = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } @@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] -metal = ["dep:metal", "dep:candle-metal-kernels"] +metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"] diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 7145d42b..727749e8 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -8,6 +8,7 @@ use half::f16; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::{Arc, RwLock}; +use dispatch::{Queue, QueueAttribute}; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -36,6 +37,7 @@ pub struct MetalDevice { device: metal::Device, command_queue: metal::CommandQueue, command_buffer: Arc>, + queue : Queue, kernels: Arc, } @@ -328,13 +330,23 @@ impl BackendStorage for MetalStorage { } fn unary_impl(&self, layout: &Layout) -> Result { + + let device = self.device(); let dtype = self.dtype; let shape = layout.shape(); let el_count = shape.elem_count(); - let mut buffer = device.new_buffer(el_count, dtype); - let command_buffer = device.command_buffer(); + let buffer = device.new_buffer(el_count, dtype); + let metal = self.device.clone(); + let mut cloned = buffer.clone(); + let inbuffer = self.buffer.clone(); + let ldims = layout.dims().to_owned(); + let lstride = layout.stride().to_owned(); + let loffset = layout.start_offset() * dtype.size_in_bytes(); if layout.is_contiguous() && layout.start_offset() == 0 { + self.device.queue.exec_async(move || { + let device = metal; + let command_buffer = device.command_buffer(); use candle_metal_kernels::unary::contiguous; let kernel_name = match (B::KERNEL, dtype) { @@ -372,11 +384,15 @@ impl BackendStorage for MetalStorage { &device.kernels, kernel_name, el_count, - &self.buffer, - &mut buffer, + &inbuffer, + &mut cloned, ) - .map_err(MetalError::from)?; + .unwrap(); + }); } else { + self.device.queue.exec_async(move || { + let device = metal; + let command_buffer = device.command_buffer(); use candle_metal_kernels::unary::strided; let kernel_name = match (B::KERNEL, dtype) { ("ucos", DType::F32) => strided::cos::FLOAT, @@ -412,15 +428,17 @@ impl BackendStorage for MetalStorage { &command_buffer, &device.kernels, kernel_name, - layout.dims(), - &self.buffer, - layout.stride(), - layout.start_offset() * self.dtype.size_in_bytes(), - &mut buffer, + &ldims, + &inbuffer, + &lstride, + loffset, + &mut cloned, 0, ) - .map_err(MetalError::from)?; + .unwrap(); + }); } + Ok(Self { buffer, device: device.clone(), @@ -896,10 +914,12 @@ impl BackendDevice for MetalDevice { let command_queue = device.new_command_queue(); let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let kernels = Arc::new(Kernels::new()); + let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial); Ok(Self { device, command_queue, command_buffer, + queue, kernels, }) }