mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
TMp.
This commit is contained in:
@ -62,6 +62,7 @@ wav = "1.0.0"
|
|||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||||
|
dispatch = "0.2.0"
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
@ -30,6 +30,7 @@ safetensors = { workspace = true }
|
|||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
yoke = { workspace = true }
|
yoke = { workspace = true }
|
||||||
zip = { workspace = true }
|
zip = { workspace = true }
|
||||||
|
dispatch = { workspace = true, optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
@ -41,4 +42,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
|
||||||
|
@ -8,6 +8,7 @@ use half::f16;
|
|||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
use dispatch::{Queue, QueueAttribute};
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -36,6 +37,7 @@ pub struct MetalDevice {
|
|||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||||
|
queue : Queue,
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -328,13 +330,23 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||||
|
|
||||||
|
|
||||||
let device = self.device();
|
let device = self.device();
|
||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let mut buffer = device.new_buffer(el_count, dtype);
|
let buffer = device.new_buffer(el_count, dtype);
|
||||||
let command_buffer = device.command_buffer();
|
let metal = self.device.clone();
|
||||||
|
let mut cloned = buffer.clone();
|
||||||
|
let inbuffer = self.buffer.clone();
|
||||||
|
let ldims = layout.dims().to_owned();
|
||||||
|
let lstride = layout.stride().to_owned();
|
||||||
|
let loffset = layout.start_offset() * dtype.size_in_bytes();
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||||
|
self.device.queue.exec_async(move || {
|
||||||
|
let device = metal;
|
||||||
|
let command_buffer = device.command_buffer();
|
||||||
use candle_metal_kernels::unary::contiguous;
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
@ -372,11 +384,15 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
&inbuffer,
|
||||||
&mut buffer,
|
&mut cloned,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.unwrap();
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
|
self.device.queue.exec_async(move || {
|
||||||
|
let device = metal;
|
||||||
|
let command_buffer = device.command_buffer();
|
||||||
use candle_metal_kernels::unary::strided;
|
use candle_metal_kernels::unary::strided;
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||||
@ -412,15 +428,17 @@ impl BackendStorage for MetalStorage {
|
|||||||
&command_buffer,
|
&command_buffer,
|
||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
layout.dims(),
|
&ldims,
|
||||||
&self.buffer,
|
&inbuffer,
|
||||||
layout.stride(),
|
&lstride,
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
loffset,
|
||||||
&mut buffer,
|
&mut cloned,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.unwrap();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer,
|
buffer,
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
@ -896,10 +914,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
|
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
|
queue,
|
||||||
kernels,
|
kernels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user