This commit is contained in:
Nicolas Patry
2023-11-16 11:41:06 +01:00
parent 2801541e5f
commit 181d2299b2
3 changed files with 34 additions and 12 deletions

View File

@ -62,6 +62,7 @@ wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
dispatch = "0.2.0"
[profile.release-with-debug]
inherits = "release"

View File

@ -30,6 +30,7 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@ -8,6 +8,7 @@ use half::f16;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock};
use dispatch::{Queue, QueueAttribute};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -36,6 +37,7 @@ pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>,
}
@ -328,13 +330,23 @@ impl BackendStorage for MetalStorage {
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_buffer();
let buffer = device.new_buffer(el_count, dtype);
let metal = self.device.clone();
let mut cloned = buffer.clone();
let inbuffer = self.buffer.clone();
let ldims = layout.dims().to_owned();
let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 {
self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -372,11 +384,15 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
&inbuffer,
&mut cloned,
)
.map_err(MetalError::from)?;
.unwrap();
});
} else {
self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
@ -412,15 +428,17 @@ impl BackendStorage for MetalStorage {
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
&ldims,
&inbuffer,
&lstride,
loffset,
&mut cloned,
0,
)
.map_err(MetalError::from)?;
.unwrap();
});
}
Ok(Self {
buffer,
device: device.clone(),
@ -896,10 +914,12 @@ impl BackendDevice for MetalDevice {
let command_queue = device.new_command_queue();
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new());
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
Ok(Self {
device,
command_queue,
command_buffer,
queue,
kernels,
})
}