Compare commits

..

3 Commits

Author SHA1 Message Date
7e49e0af96 Tmp for allocator. 2023-11-16 12:50:41 +01:00
181d2299b2 TMp. 2023-11-16 11:41:06 +01:00
2801541e5f new_owned -> new()..to_owned(). 2023-11-16 11:07:56 +01:00
4 changed files with 231 additions and 194 deletions

View File

@ -61,8 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
[profile.release-with-debug]
inherits = "release"

View File

@ -30,6 +30,8 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -41,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@ -6,8 +6,11 @@ use candle_metal_kernels;
use candle_metal_kernels::Kernels;
use half::f16;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, HeapDescriptor, MTLResourceOptions, NSUInteger};
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::sync::{Arc, RwLock};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use dispatch::{Queue, QueueAttribute};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -35,8 +38,9 @@ impl From<String> for MetalError {
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
heap: metal::Heap,
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>,
}
@ -64,31 +68,15 @@ impl MetalDevice {
}
pub fn command_buffer(&self) -> std::sync::RwLockReadGuard<CommandBuffer> {
self.command_buffer.read().unwrap()
self.command_buffer.try_read().unwrap()
}
pub fn commit_wait_until_completed(&self) {
pub fn wait_until_completed(&self) {
let mut old = self.command_buffer.try_write().unwrap();
let status = old.status();
use metal::MTLCommandBufferStatus::{
Committed, Completed, Enqueued, Error, NotEnqueued, Scheduled,
};
// match old.status() {}
if old.status() == metal::MTLCommandBufferStatus::Completed {
return;
}
old.commit();
old.wait_until_completed();
// let count = old.retain_count();
// println!("Count {count:?}");
let command_buffer = self.command_queue.new_command_buffer().to_owned();
*old = command_buffer;
// let count = old.retain_count();
// // println!("Count after {count:?}");
// old.release();
// let count = old.retain_count();
// println!("Count after release {count:?}");
// self.command_buffer.replace_with(|_| command_buffer)
}
@ -101,22 +89,34 @@ impl MetalDevice {
}
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
// println!("Creating buffer {size}");
let buffer = self
.heap
.new_buffer(size, MTLResourceOptions::StorageModeShared)
.expect("New buffer");
// println!("{:?}", self.heap.used_size());
buffer
let size = element_count * dtype.size_in_bytes();
let mut buffers = self.buffers.try_write().unwrap();
let subbuffers = buffers.entry(size).or_insert(vec![]);
for sub in &mut *subbuffers{
if sub.retain_count() == 1{
return sub.clone();
// println!("{size } {:?}", sub.retain_count());
}
}
let new_buffer = self.device
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
subbuffers.push(new_buffer.clone());
new_buffer
}
/// Allocates a new buffer of length `size` directly on the underlying Metal
/// device, using the `StorageModeManaged` resource option.
///
/// The buffer is requested straight from `self.device` on every call; nothing
/// is cached or reused by this method.
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    self.device.new_buffer(size, options)
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
let size = core::mem::size_of_val(data) as NSUInteger;
let option = metal::MTLResourceOptions::StorageModeShared;
// println!("Creating data buffer {size}");
self.device
.new_buffer_with_data(data.as_ptr() as *const core::ffi::c_void, size, option)
let option = metal::MTLResourceOptions::StorageModeManaged;
self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(data) as NSUInteger,
option,
)
}
}
@ -143,29 +143,39 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
self.device.commit_wait_until_completed();
let buffer = self.device.new_buffer_managed(self.buffer.length());
{
let command = self.device.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
}
self.device.wait_until_completed();
match self.dtype {
DType::U8 => Ok(CpuStorage::U8(
self.buffer.read_to_vec(self.buffer.length() as usize),
buffer.read_to_vec(buffer.length() as usize),
)),
DType::U32 => Ok(CpuStorage::U32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
buffer.read_to_vec(buffer.length() as usize / 4),
)),
DType::I64 => Ok(CpuStorage::I64(
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
buffer.read_to_vec(buffer.length() as usize / 8),
)),
DType::F16 => Ok(CpuStorage::F16(
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
buffer.read_to_vec(buffer.length() as usize / 2),
)),
DType::BF16 => Ok(CpuStorage::BF16(
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
buffer.read_to_vec(buffer.length() as usize / 2),
)),
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
buffer.read_to_vec(buffer.length() as usize / 4),
)),
DType::F64 => Ok(CpuStorage::F64(
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
buffer.read_to_vec(buffer.length() as usize / 8),
)),
}
}
@ -349,101 +359,116 @@ impl BackendStorage for MetalStorage {
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
{
let command_buffer = device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
let buffer = device.new_buffer(el_count, dtype);
let metal = self.device.clone();
let mut cloned = buffer.clone();
let inbuffer = self.buffer.clone();
let ldims = layout.dims().to_owned();
let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 {
// self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&mut buffer,
0,
)
.map_err(MetalError::from)?;
}
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
("ucos", DType::F16) => contiguous::cos::HALF,
("usin", DType::F16) => contiguous::sin::HALF,
("usqr", DType::F16) => contiguous::sqr::HALF,
("usqrt", DType::F16) => contiguous::sqrt::HALF,
("uneg", DType::F16) => contiguous::neg::HALF,
("uexp", DType::F16) => contiguous::exp::HALF,
("ulog", DType::F16) => contiguous::log::HALF,
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
("uround", DType::F16) => contiguous::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&inbuffer,
&mut cloned,
)
.unwrap();
// });
} else {
// self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
("usin", DType::F32) => strided::sin::FLOAT,
("usqr", DType::F32) => strided::sqr::FLOAT,
("usqrt", DType::F32) => strided::sqrt::FLOAT,
("uneg", DType::F32) => strided::neg::FLOAT,
("uexp", DType::F32) => strided::exp::FLOAT,
("ulog", DType::F32) => strided::log::FLOAT,
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
("usqrt", DType::F16) => strided::sqrt::HALF,
("uneg", DType::F16) => strided::neg::HALF,
("uexp", DType::F16) => strided::exp::HALF,
("ulog", DType::F16) => strided::log::HALF,
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,
("uround", DType::F16) => strided::round::HALF,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
&ldims,
&inbuffer,
&lstride,
loffset,
&mut cloned,
0,
)
.unwrap();
// });
}
Ok(Self {
buffer,
device: device.clone(),
@ -790,61 +815,59 @@ impl BackendStorage for MetalStorage {
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
{
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
bi * m * n * size,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
bi * m * n * size,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
}
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
}
Ok(Self {
@ -905,22 +928,30 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
let descriptor = HeapDescriptor::new();
let mut size =
device.heap_buffer_size_and_align(100_000_000, MTLResourceOptions::StorageModeShared);
size.size += (size.size & (size.align - 1)) + size.align;
descriptor.set_size(size.size);
descriptor.set_storage_mode(metal::MTLStorageMode::Shared);
let heap = device.new_heap(&descriptor);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new());
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
Ok(Self {
device,
heap,
command_queue,
command_buffer,
buffers,
queue,
kernels,
})
}

View File

@ -300,6 +300,9 @@ pub fn call_unary_contiguous(
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -1133,7 +1136,6 @@ mod tests {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
options.set_fast_math_enabled(true);
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];