mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
TMp.
This commit is contained in:
@ -62,6 +62,7 @@ wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
dispatch = "0.2.0"
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
@ -30,6 +30,7 @@ safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
dispatch = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
|
||||
|
@ -8,6 +8,7 @@ use half::f16;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use dispatch::{Queue, QueueAttribute};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -36,6 +37,7 @@ pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
command_queue: metal::CommandQueue,
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
queue : Queue,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
}
|
||||
|
||||
@ -328,13 +330,23 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
|
||||
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let metal = self.device.clone();
|
||||
let mut cloned = buffer.clone();
|
||||
let inbuffer = self.buffer.clone();
|
||||
let ldims = layout.dims().to_owned();
|
||||
let lstride = layout.stride().to_owned();
|
||||
let loffset = layout.start_offset() * dtype.size_in_bytes();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
self.device.queue.exec_async(move || {
|
||||
let device = metal;
|
||||
let command_buffer = device.command_buffer();
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
@ -372,11 +384,15 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&mut buffer,
|
||||
&inbuffer,
|
||||
&mut cloned,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
.unwrap();
|
||||
});
|
||||
} else {
|
||||
self.device.queue.exec_async(move || {
|
||||
let device = metal;
|
||||
let command_buffer = device.command_buffer();
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||
@ -412,15 +428,17 @@ impl BackendStorage for MetalStorage {
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
&ldims,
|
||||
&inbuffer,
|
||||
&lstride,
|
||||
loffset,
|
||||
&mut cloned,
|
||||
0,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
.unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -896,10 +914,12 @@ impl BackendDevice for MetalDevice {
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
|
||||
Ok(Self {
|
||||
device,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
queue,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
|
Reference in New Issue
Block a user