Compare commits

..

2 Commits

4 changed files with 194 additions and 231 deletions

View File

@ -61,10 +61,8 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0" wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] } yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false } zip = { version = "0.6.6", default-features = false }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } # metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] } metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
[profile.release-with-debug] [profile.release-with-debug]
inherits = "release" inherits = "release"

View File

@ -30,8 +30,6 @@ safetensors = { workspace = true }
thiserror = { workspace = true } thiserror = { workspace = true }
yoke = { workspace = true } yoke = { workspace = true }
zip = { workspace = true } zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -43,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"] cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"] mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"] accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"] metal = ["dep:metal", "dep:candle-metal-kernels"]

View File

@ -6,11 +6,8 @@ use candle_metal_kernels;
use candle_metal_kernels::Kernels; use candle_metal_kernels::Kernels;
use half::f16; use half::f16;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use dispatch::{Queue, QueueAttribute};
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -38,9 +35,8 @@ impl From<String> for MetalError {
pub struct MetalDevice { pub struct MetalDevice {
device: metal::Device, device: metal::Device,
command_queue: metal::CommandQueue, command_queue: metal::CommandQueue,
heap: metal::Heap,
command_buffer: Arc<RwLock<metal::CommandBuffer>>, command_buffer: Arc<RwLock<metal::CommandBuffer>>,
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>, kernels: Arc<candle_metal_kernels::Kernels>,
} }
@ -68,15 +64,31 @@ impl MetalDevice {
} }
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> { pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
self.command_buffer.try_read().unwrap() self.command_buffer.read().unwrap()
} }
pub fn wait_until_completed(&self) { pub fn commit_wait_until_completed(&self) {
let mut old = self.command_buffer.try_write().unwrap(); let mut old = self.command_buffer.try_write().unwrap();
let status = old.status();
use metal::MTLCommandBufferStatus::{
Committed, Completed, Enqueued, Error, NotEnqueued, Scheduled,
};
// match old.status() {}
if old.status() == metal::MTLCommandBufferStatus::Completed {
return;
}
old.commit(); old.commit();
old.wait_until_completed(); old.wait_until_completed();
// let count = old.retain_count();
// println!("Count {count:?}");
let command_buffer = self.command_queue.new_command_buffer().to_owned(); let command_buffer = self.command_queue.new_command_buffer().to_owned();
*old = command_buffer; *old = command_buffer;
// let count = old.retain_count();
// // println!("Count after {count:?}");
// old.release();
// let count = old.retain_count();
// println!("Count after release {count:?}");
// self.command_buffer.replace_with(|_| command_buffer) // self.command_buffer.replace_with(|_| command_buffer)
} }
@ -89,34 +101,22 @@ impl MetalDevice {
} }
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = element_count * dtype.size_in_bytes(); let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
let mut buffers = self.buffers.try_write().unwrap(); // println!("Creating buffer {size}");
let subbuffers = buffers.entry(size).or_insert(vec![]); let buffer = self
.heap
for sub in &mut *subbuffers{ .new_buffer(size, MTLResourceOptions::StorageModeShared)
if sub.retain_count() == 1{ .expect("New buffer");
return sub.clone(); // println!("{:?}", self.heap.used_size());
// println!("{size } {:?}", sub.retain_count()); buffer
}
}
let new_buffer = self.device
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
subbuffers.push(new_buffer.clone());
new_buffer
}
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
} }
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer { pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
let option = metal::MTLResourceOptions::StorageModeManaged; let size = core::mem::size_of_val(data) as NSUInteger;
self.device.new_buffer_with_data( let option = metal::MTLResourceOptions::StorageModeShared;
data.as_ptr() as *const core::ffi::c_void, // println!("Creating data buffer {size}");
core::mem::size_of_val(data) as NSUInteger, self.device
option, .new_buffer_with_data(data.as_ptr() as *const core::ffi::c_void, size, option)
)
} }
} }
@ -143,39 +143,29 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length()); self.device.commit_wait_until_completed();
{
let command = self.device.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
}
self.device.wait_until_completed();
match self.dtype { match self.dtype {
DType::U8 => Ok(CpuStorage::U8( DType::U8 => Ok(CpuStorage::U8(
buffer.read_to_vec(buffer.length() as usize), self.buffer.read_to_vec(self.buffer.length() as usize),
)), )),
DType::U32 => Ok(CpuStorage::U32( DType::U32 => Ok(CpuStorage::U32(
buffer.read_to_vec(buffer.length() as usize / 4), self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)), )),
DType::I64 => Ok(CpuStorage::I64( DType::I64 => Ok(CpuStorage::I64(
buffer.read_to_vec(buffer.length() as usize / 8), self.buffer.read_to_vec(self.buffer.length() as usize / 8),
)), )),
DType::F16 => Ok(CpuStorage::F16( DType::F16 => Ok(CpuStorage::F16(
buffer.read_to_vec(buffer.length() as usize / 2), self.buffer.read_to_vec(self.buffer.length() as usize / 2),
)), )),
DType::BF16 => Ok(CpuStorage::BF16( DType::BF16 => Ok(CpuStorage::BF16(
buffer.read_to_vec(buffer.length() as usize / 2), self.buffer.read_to_vec(self.buffer.length() as usize / 2),
)), )),
DType::F32 => Ok(CpuStorage::F32( DType::F32 => Ok(CpuStorage::F32(
buffer.read_to_vec(buffer.length() as usize / 4), self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)), )),
DType::F64 => Ok(CpuStorage::F64( DType::F64 => Ok(CpuStorage::F64(
buffer.read_to_vec(buffer.length() as usize / 8), self.buffer.read_to_vec(self.buffer.length() as usize / 8),
)), )),
} }
} }
@ -359,116 +349,101 @@ impl BackendStorage for MetalStorage {
} }
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> { fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device(); let device = self.device();
let dtype = self.dtype; let dtype = self.dtype;
let shape = layout.shape(); let shape = layout.shape();
let el_count = shape.elem_count(); let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
let metal = self.device.clone(); {
let mut cloned = buffer.clone(); let command_buffer = device.command_buffer();
let inbuffer = self.buffer.clone(); if layout.is_contiguous() && layout.start_offset() == 0 {
let ldims = layout.dims().to_owned(); use candle_metal_kernels::unary::contiguous;
let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 {
// self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) { let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT, ("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT, ("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT, ("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT, ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT, ("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT, ("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT, ("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT, ("uerf", DType::F32) => contiguous::erf::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT, ("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT, ("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT, ("uround", DType::F32) => contiguous::round::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF, ("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF, ("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF, ("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF, ("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF, ("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF, ("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF, ("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF, ("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF, ("uerf", DType::F16) => contiguous::erf::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF, ("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF, ("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF, ("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"), (name, dtype) => todo!("Match {name} - {dtype:?}"),
}; };
candle_metal_kernels::call_unary_contiguous( candle_metal_kernels::call_unary_contiguous(
&device.device, &device.device,
&command_buffer, &command_buffer,
&device.kernels, &device.kernels,
kernel_name, kernel_name,
el_count, el_count,
&inbuffer, &self.buffer,
&mut cloned, &mut buffer,
) )
.unwrap(); .map_err(MetalError::from)?;
// }); } else {
use candle_metal_kernels::unary::strided;
} else { let kernel_name = match (B::KERNEL, dtype) {
// self.device.queue.exec_async(move || { ("ucos", DType::F32) => strided::cos::FLOAT,
let device = metal; ("usin", DType::F32) => strided::sin::FLOAT,
let command_buffer = device.command_buffer(); ("usqr", DType::F32) => strided::sqr::FLOAT,
use candle_metal_kernels::unary::strided; ("usqrt", DType::F32) => strided::sqrt::FLOAT,
let kernel_name = match (B::KERNEL, dtype) { ("uneg", DType::F32) => strided::neg::FLOAT,
("ucos", DType::F32) => strided::cos::FLOAT, ("uexp", DType::F32) => strided::exp::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT, ("ulog", DType::F32) => strided::log::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT, ("ugelu", DType::F32) => strided::gelu::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT, ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT, ("uerf", DType::F32) => strided::erf::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT, ("uceil", DType::F32) => strided::ceil::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT, ("uround", DType::F32) => strided::round::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, ("ucos", DType::F16) => strided::cos::HALF,
("uerf", DType::F32) => strided::erf::FLOAT, ("usin", DType::F16) => strided::sin::HALF,
("uceil", DType::F32) => strided::ceil::FLOAT, ("usqr", DType::F16) => strided::sqr::HALF,
("ufloor", DType::F32) => strided::floor::FLOAT, ("usqrt", DType::F16) => strided::sqrt::HALF,
("uround", DType::F32) => strided::round::FLOAT, ("uneg", DType::F16) => strided::neg::HALF,
("ucos", DType::F16) => strided::cos::HALF, ("uexp", DType::F16) => strided::exp::HALF,
("usin", DType::F16) => strided::sin::HALF, ("ulog", DType::F16) => strided::log::HALF,
("usqr", DType::F16) => strided::sqr::HALF, ("ugelu", DType::F16) => strided::gelu::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF, ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uneg", DType::F16) => strided::neg::HALF, ("uerf", DType::F16) => strided::erf::HALF,
("uexp", DType::F16) => strided::exp::HALF, ("uceil", DType::F16) => strided::ceil::HALF,
("ulog", DType::F16) => strided::log::HALF, ("ufloor", DType::F16) => strided::floor::HALF,
("ugelu", DType::F16) => strided::gelu::HALF, ("uround", DType::F16) => strided::round::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, (name, dtype) => todo!("Match {name} - {dtype:?}"),
("uerf", DType::F16) => strided::erf::HALF, };
("uceil", DType::F16) => strided::ceil::HALF, candle_metal_kernels::call_unary_strided(
("ufloor", DType::F16) => strided::floor::HALF, &device.device,
("uround", DType::F16) => strided::round::HALF, &command_buffer,
(name, dtype) => todo!("Match {name} - {dtype:?}"), &device.kernels,
}; kernel_name,
candle_metal_kernels::call_unary_strided( layout.dims(),
&device.device, &self.buffer,
&command_buffer, layout.stride(),
&device.kernels, layout.start_offset() * self.dtype.size_in_bytes(),
kernel_name, &mut buffer,
&ldims, 0,
&inbuffer, )
&lstride, .map_err(MetalError::from)?;
loffset, }
&mut cloned,
0,
)
.unwrap();
// });
} }
Ok(Self { Ok(Self {
buffer, buffer,
device: device.clone(), device: device.clone(),
@ -815,59 +790,61 @@ impl BackendStorage for MetalStorage {
let out_buffer = self.device.new_buffer(elem_count, self.dtype); let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer(); {
for bi in 0..b { let command_buffer = self.device.command_buffer();
// Create matrix objects for bi in 0..b {
let left_matrix = Matrix::init_with_buffer_descriptor( // Create matrix objects
&self.buffer, let left_matrix = Matrix::init_with_buffer_descriptor(
(bi * stride_left + lhs_l.start_offset() as u64) * size, &self.buffer,
&left_descriptor, (bi * stride_left + lhs_l.start_offset() as u64) * size,
) &left_descriptor,
.ok_or_else(|| { )
MetalError::from("Failed to create matrix multiplication kernel".to_string()) .ok_or_else(|| {
})?; MetalError::from("Failed to create matrix multiplication kernel".to_string())
let right_matrix = Matrix::init_with_buffer_descriptor( })?;
&rhs.buffer, let right_matrix = Matrix::init_with_buffer_descriptor(
(bi * stride_right + rhs_l.start_offset() as u64) * size, &rhs.buffer,
&right_descriptor, (bi * stride_right + rhs_l.start_offset() as u64) * size,
) &right_descriptor,
.ok_or_else(|| { )
MetalError::from("Failed to create matrix multiplication kernel".to_string()) .ok_or_else(|| {
})?; MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor( let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer, &out_buffer,
bi * m * n * size, bi * m * n * size,
&result_descriptor, &result_descriptor,
) )
.ok_or_else(|| { .ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string()) MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?; })?;
let alpha = 1.0f64; let alpha = 1.0f64;
let beta = 0.0f64; let beta = 0.0f64;
// Create kernel // Create kernel
let matrix_multiplication = MatrixMultiplication::init( let matrix_multiplication = MatrixMultiplication::init(
&self.device, &self.device,
transpose_left, transpose_left,
transpose_right, transpose_right,
m, m,
n, n,
k, k,
alpha, alpha,
beta, beta,
) )
.ok_or_else(|| { .ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string()) MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?; })?;
// Encode kernel to command buffer // Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer( matrix_multiplication.encode_to_command_buffer(
&command_buffer, &command_buffer,
&left_matrix, &left_matrix,
&right_matrix, &right_matrix,
&result_matrix, &result_matrix,
); );
}
} }
Ok(Self { Ok(Self {
@ -928,30 +905,22 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> { fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal); let device = metal::Device::all().swap_remove(ordinal);
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue(); let command_queue = device.new_command_queue();
let descriptor = HeapDescriptor::new();
let mut size =
device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared);
size.size += (size.size & (size.align - 1)) + size.align;
descriptor.set_size(size.size);
descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
let heap = device.new_heap(&descriptor);
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned())); let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new()); let kernels = Arc::new(Kernels::new());
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
Ok(Self { Ok(Self {
device, device,
heap,
command_queue, command_queue,
command_buffer, command_buffer,
buffers,
queue,
kernels, kernels,
}) })
} }

View File

@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
input: &Buffer, input: &Buffer,
output: &mut Buffer, output: &mut Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
@ -1136,6 +1133,7 @@ mod tests {
let device = Device::system_default().expect("no device found"); let device = Device::system_default().expect("no device found");
let options = CompileOptions::new(); let options = CompileOptions::new();
options.set_fast_math_enabled(true);
let library = device.new_library_with_source(INDEXING, &options).unwrap(); let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];