This commit is contained in:
Nicolas Patry
2023-11-16 11:41:06 +01:00
parent 2801541e5f
commit 181d2299b2
3 changed files with 34 additions and 12 deletions

View File

@ -62,6 +62,7 @@ wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] } yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false } zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
dispatch = "0.2.0"
[profile.release-with-debug] [profile.release-with-debug]
inherits = "release" inherits = "release"

View File

@ -30,6 +30,7 @@ safetensors = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
yoke = { workspace = true } yoke = { workspace = true }
zip = { workspace = true } zip = { workspace = true }
dispatch = { workspace = true, optional = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"] cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"] mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"] accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"] metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@ -8,6 +8,7 @@ use half::f16;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use dispatch::{Queue, QueueAttribute};
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -36,6 +37,7 @@ pub struct MetalDevice {
device: metal::Device, device: metal::Device,
command_queue: metal::CommandQueue, command_queue: metal::CommandQueue,
command_buffer: Arc<RwLock<metal::CommandBuffer>>, command_buffer: Arc<RwLock<metal::CommandBuffer>>,
queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
} }
@ -328,13 +330,23 @@ impl BackendStorage for MetalStorage {
} }
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device(); let device = self.device();
let dtype = self.dtype; let dtype = self.dtype;
let shape = layout.shape(); let shape = layout.shape();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype); let buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_buffer(); let metal = self.device.clone();
let mut cloned = buffer.clone();
let inbuffer = self.buffer.clone();
let ldims = layout.dims().to_owned();
let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 { if layout.is_contiguous() && layout.start_offset() == 0 {
self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::contiguous; use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) { let kernel_name = match (B::KERNEL, dtype) {
@ -372,11 +384,15 @@ impl BackendStorage for MetalStorage {
&device.kernels, &device.kernels,
kernel_name, kernel_name,
el_count, el_count,
&self.buffer, &inbuffer,
&mut buffer, &mut cloned,
) )
.map_err(MetalError::from)?; .unwrap();
});
} else { } else {
self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::strided; use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) { let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT, ("ucos", DType::F32) => strided::cos::FLOAT,
@ -412,15 +428,17 @@ impl BackendStorage for MetalStorage {
&command_buffer, &command_buffer,
&device.kernels, &device.kernels,
kernel_name, kernel_name,
layout.dims(), &ldims,
&self.buffer, &inbuffer,
layout.stride(), &lstride,
layout.start_offset() * self.dtype.size_in_bytes(), loffset,
&mut buffer, &mut cloned,
0, 0,
) )
.map_err(MetalError::from)?; .unwrap();
});
} }
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -896,10 +914,12 @@ impl BackendDevice for MetalDevice {
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new()); let kernels = Arc::new(Kernels::new());
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
Ok(Self { Ok(Self {
device, device,
command_queue, command_queue,
command_buffer, command_buffer,
queue,
kernels, kernels,
}) })
} }