Compare commits


2 Commits

SHA1       Message             Date
7e49e0af96 Tmp for allocator.  2023-11-16 12:50:41 +01:00
181d2299b2 TMp.                2023-11-16 11:41:06 +01:00
9 changed files with 233 additions and 285 deletions

View File

@@ -61,7 +61,10 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../metal-rs", features = ["mps"] }
+dispatch = "0.2.0"
+rustc-hash = "1.1"
 
 [profile.release-with-debug]
 inherits = "release"

View File

@@ -30,6 +30,8 @@ safetensors = { workspace = true }
 thiserror = { workspace = true }
 yoke = { workspace = true }
 zip = { workspace = true }
+dispatch = { workspace = true, optional = true }
+rustc-hash = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -41,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal", "dep:candle-metal-kernels"]
+metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@@ -7,9 +7,10 @@ use candle_metal_kernels::Kernels;
 use half::f16;
 use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 use std::sync::{Arc, RwLock};
 use std::collections::HashMap;
+use rustc_hash::FxHashMap;
+use dispatch::{Queue, QueueAttribute};
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -38,8 +39,9 @@ pub struct MetalDevice {
 device: metal::Device,
 command_queue: metal::CommandQueue,
 command_buffer: Arc<RwLock<metal::CommandBuffer>>,
+buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
+queue : Queue,
 kernels: Arc<candle_metal_kernels::Kernels>,
-buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -61,10 +63,6 @@ impl MetalDevice {
 self.registry_id()
 }
 
-pub fn metal_device(&self) -> &metal::Device {
-&self.device
-}
-
 pub fn command_queue(&self) -> &CommandQueue {
 &self.command_queue
 }
@@ -73,28 +71,10 @@ impl MetalDevice {
 self.command_buffer.try_read().unwrap()
 }
 
-pub fn commit(&self) {
-let mut old = self.command_buffer.try_write().unwrap();
-match old.status(){
-metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
-old.commit();
-let command_buffer = self.command_queue.new_command_buffer().to_owned();
-*old = command_buffer;
-}
-_ => {}
-}
-// self.command_buffer.replace_with(|_| command_buffer)
-}
-
 pub fn wait_until_completed(&self) {
 let mut old = self.command_buffer.try_write().unwrap();
-match old.status(){
-metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
-old.commit();
-old.wait_until_completed();
-}
-_ => {}
-}
+old.commit();
+old.wait_until_completed();
 let command_buffer = self.command_queue.new_command_buffer().to_owned();
 *old = command_buffer;
 // self.command_buffer.replace_with(|_| command_buffer)
@@ -108,72 +88,41 @@ impl MetalDevice {
 &self.device
 }
 
-pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer >{
-let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
-}
-
-fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer >{
+pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+let size = element_count * dtype.size_in_bytes();
 let mut buffers = self.buffers.try_write().unwrap();
-let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+let subbuffers = buffers.entry(size).or_insert(vec![]);
 for sub in &mut *subbuffers{
-if Arc::strong_count(sub) == 1{
+if sub.retain_count() == 1{
 return sub.clone();
+// println!("{size } {:?}", sub.retain_count());
 }
 }
 let new_buffer = self.device
-.new_buffer(size as NSUInteger, option);
-let new_buffer = Arc::new(new_buffer);
+.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
 subbuffers.push(new_buffer.clone());
 new_buffer
 }
 
-pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
-self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
+pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
+self.device
+.new_buffer(size, MTLResourceOptions::StorageModeManaged)
 }
 
-pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
-let tmp = self.device.new_buffer_with_data(
+pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
+let option = metal::MTLResourceOptions::StorageModeManaged;
+self.device.new_buffer_with_data(
 data.as_ptr() as *const core::ffi::c_void,
 core::mem::size_of_val(data) as NSUInteger,
-metal::MTLResourceOptions::StorageModeManaged
-);
-let real = self._new_buffer(
-core::mem::size_of_val(data) as NSUInteger,
-metal::MTLResourceOptions::StorageModePrivate
-);
-{
-let command = self.command_buffer();
-let blit = command.new_blit_command_encoder();
-blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-blit.end_encoding();
-}
-real
-}
-
-pub fn new_matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), size: NSUInteger, type_id: u32, dtype:DType) -> Result<(Matrix, Arc<Buffer>)>{
-let elem_count = (b * m * n ) as usize;
-let out_buffer = self.new_buffer(elem_count, dtype);
-let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
-let result_matrix = Matrix::init_with_buffer_descriptor(
-&out_buffer,
-0,
-&result_descriptor,
+option,
 )
-.ok_or_else(|| {
-MetalError::from("Failed to create matrix multiplication kernel".to_string())
-})?;
-Ok((result_matrix, out_buffer))
 }
 }
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
-buffer: Arc<metal::Buffer>,
-matrices: Arc<RwLock<HashMap<(NSUInteger, NSUInteger, NSUInteger, bool, NSUInteger, NSUInteger, u32), Matrix>>>,
+buffer: metal::Buffer,
 device: MetalDevice,
 dtype: DType,
 }
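Note on the allocator change above: both sides of MetalDevice::new_buffer implement the same reuse scheme — keep a per-size free list and hand back any buffer that nothing else still references (Arc::strong_count on the old side, Metal's retain_count on the new one). Below is a self-contained sketch of that pattern, with plain Vec<u8> standing in for Metal buffers; all names here are invented for the sketch and are not part of the crate.

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// A buffer counts as "free" when the pool holds the only strong reference to it.
struct Pool {
    buffers: RwLock<HashMap<usize, Vec<Arc<Vec<u8>>>>>,
}

impl Pool {
    fn new() -> Self {
        Self { buffers: RwLock::new(HashMap::new()) }
    }

    fn alloc(&self, size: usize) -> Arc<Vec<u8>> {
        let mut buffers = self.buffers.write().unwrap();
        let subbuffers = buffers.entry(size).or_insert_with(Vec::new);
        // Reuse any allocation of the right size that nobody else still holds.
        for sub in subbuffers.iter() {
            if Arc::strong_count(sub) == 1 {
                return sub.clone();
            }
        }
        let fresh = Arc::new(vec![0u8; size]);
        subbuffers.push(fresh.clone());
        fresh
    }
}

fn main() {
    let pool = Pool::new();
    let a = pool.alloc(1024);
    drop(a);                  // dropping the handle implicitly returns it to the pool
    let b = pool.alloc(1024); // reuses the same backing allocation
    assert_eq!(b.len(), 1024);
}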
@@ -195,11 +144,15 @@ impl BackendStorage for MetalStorage {
 fn to_cpu_storage(&self) -> Result<CpuStorage> {
 let buffer = self.device.new_buffer_managed(self.buffer.length());
-let command_buffer = self.device.command_buffer();
-let blit = command_buffer.new_blit_command_encoder();
+{
+let command = self.device.command_buffer();
+let blit = command.new_blit_command_encoder();
 blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
 blit.end_encoding();
-drop(command_buffer);
+}
 self.device.wait_until_completed();
 
 match self.dtype {
@@ -234,7 +187,7 @@ impl BackendStorage for MetalStorage {
 let el = shape.elem_count();
 let dtype = self.dtype;
-let buffer = device.new_buffer(el, self.dtype);
+let mut buffer = device.new_buffer(el, self.dtype);
 let command_buffer = self.device.command_buffer();
 if layout.is_contiguous() && layout.start_offset() == 0 {
 let name = match self.dtype {
@@ -249,7 +202,7 @@
 name,
 el,
 &self.buffer,
-&buffer,
+&mut buffer,
 mul as f32,
 add as f32,
 )
@@ -269,17 +222,17 @@
 &self.buffer,
 layout.stride(),
 layout.start_offset() * dtype.size_in_bytes(),
-&buffer,
+&mut buffer,
 mul as f32,
 add as f32,
 )
 .unwrap();
 }
-Ok(Self::new(
+Ok(Self {
 buffer,
-device.clone(),
+device: device.clone(),
 dtype,
-))
+})
 }
 
 fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -293,9 +246,8 @@
 fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
 assert!(sum_dims.len() == 1);
 assert!(sum_dims[0] == layout.shape().rank() - 1);
-assert!(layout.stride()[sum_dims[0]] == 1);
-// assert!(layout.is_contiguous());
-// assert!(layout.start_offset() == 0);
+assert!(layout.is_contiguous());
+assert!(layout.start_offset() == 0);
 let device = self.device.clone();
 let src_stride = layout.stride();
 let src_dims = layout.shape().dims();
@@ -330,10 +282,7 @@
 Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
 }
 let dtype = if return_index { DType::U32 } else { self.dtype };
-if dtype == DType::U32{
-todo!("Implement this");
-}
-let buffer = device.new_buffer(dst_el, dtype);
+let mut buffer = device.new_buffer(dst_el, dtype);
 let command_buffer = self.device.command_buffer();
 candle_metal_kernels::call_reduce_contiguous(
 &device.device,
@@ -343,16 +292,15 @@
 src_el,
 dst_el,
 &self.buffer,
-layout.start_offset() * self.dtype.size_in_bytes(),
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
-Ok(Self::new(
+Ok(Self {
 buffer,
 device,
 dtype,
-))
+})
 }
 
 fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -363,7 +311,7 @@ impl BackendStorage for MetalStorage {
 let device = self.device();
 let shape = layout.shape();
 let el_count = shape.elem_count();
-let buffer = device.new_buffer(el_count, dtype);
+let mut buffer = device.new_buffer(el_count, dtype);
 let command_buffer = device.command_buffer();
 if layout.is_contiguous() {
 let kernel_name = match (self.dtype, dtype) {
@@ -379,7 +327,7 @@
 kernel_name,
 el_count,
 &self.buffer,
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
 } else {
@@ -398,26 +346,36 @@
 &self.buffer,
 layout.stride(),
 layout.start_offset() * self.dtype.size_in_bytes(),
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
 }
-Ok(Self::new(
+Ok(Self {
 buffer,
-device.clone(),
+device: device.clone(),
 dtype,
-))
+})
 }
 
 fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
 let device = self.device();
 let dtype = self.dtype;
 let shape = layout.shape();
 let el_count = shape.elem_count();
 let buffer = device.new_buffer(el_count, dtype);
-let command_buffer = device.command_buffer();
+let metal = self.device.clone();
+let mut cloned = buffer.clone();
+let inbuffer = self.buffer.clone();
+let ldims = layout.dims().to_owned();
+let lstride = layout.stride().to_owned();
+let loffset = layout.start_offset() * dtype.size_in_bytes();
 if layout.is_contiguous() && layout.start_offset() == 0 {
+// self.device.queue.exec_async(move || {
+let device = metal;
+let command_buffer = device.command_buffer();
 use candle_metal_kernels::unary::contiguous;
 
 let kernel_name = match (B::KERNEL, dtype) {
@@ -455,11 +413,16 @@
 &device.kernels,
 kernel_name,
 el_count,
-&self.buffer,
-&buffer,
+&inbuffer,
+&mut cloned,
 )
-.map_err(MetalError::from)?;
+.unwrap();
+// });
 } else {
+// self.device.queue.exec_async(move || {
+let device = metal;
+let command_buffer = device.command_buffer();
 use candle_metal_kernels::unary::strided;
 let kernel_name = match (B::KERNEL, dtype) {
 ("ucos", DType::F32) => strided::cos::FLOAT,
@@ -495,23 +458,22 @@
 &command_buffer,
 &device.kernels,
 kernel_name,
-layout.dims(),
-&self.buffer,
-layout.stride(),
-layout.start_offset() * self.dtype.size_in_bytes(),
-&buffer,
+&ldims,
+&inbuffer,
+&lstride,
+loffset,
+&mut cloned,
 0,
 )
-.map_err(MetalError::from)?;
+.unwrap();
+// });
 }
-command_buffer.set_label("unary");
-drop(command_buffer);
-self.device.commit();
-Ok(Self::new(
+Ok(Self {
 buffer,
-device.clone(),
+device: device.clone(),
 dtype,
-))
+})
 }
 
 fn binary_impl<B: BinaryOpT>(
@@ -524,7 +486,7 @@ impl BackendStorage for MetalStorage {
 let dtype = self.dtype;
 let shape = lhs_l.shape();
 let el_count = shape.elem_count();
-let buffer = device.new_buffer(el_count, dtype);
+let mut buffer = device.new_buffer(el_count, dtype);
 let command_buffer = device.command_buffer();
 if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
 && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
@@ -558,7 +520,7 @@
 el_count,
 &self.buffer,
 &rhs.buffer,
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
 } else {
@@ -587,18 +549,15 @@
 &rhs.buffer,
 rhs_l.stride(),
 rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
 }
-command_buffer.set_label("binary");
-drop(command_buffer);
-self.device.commit();
-Ok(Self::new(
+Ok(Self {
 buffer,
-device.clone(),
+device: device.clone(),
 dtype,
-))
+})
 }
 
 fn where_cond(
@@ -614,7 +573,7 @@
 let dims = shape.dims();
 let el = shape.elem_count();
 let dtype = t.dtype;
-let buffer = self.device.new_buffer(el, dtype);
+let mut buffer = self.device.new_buffer(el, dtype);
 let command_buffer = self.device.command_buffer();
 candle_metal_kernels::call_where_cond_strided(
 &device.device,
@@ -631,14 +590,14 @@
 (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
 &f.buffer,
 (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
-Ok(Self::new(
+Ok(Self {
 buffer,
 device,
 dtype,
-))
+})
 }
 
 fn conv1d(
@@ -724,7 +683,7 @@
 let dst_el = ids_el * left_size * right_size;
 let dtype = self.dtype;
 let device = self.device();
-let buffer = device.new_buffer(dst_el, dtype);
+let mut buffer = device.new_buffer(dst_el, dtype);
 let name = match (ids.dtype, self.dtype) {
 (DType::U32, DType::F32) => "is_u32_f32",
 (DType::U32, DType::F16) => "is_u32_f16",
@@ -741,14 +700,14 @@
 dim,
 &self.buffer,
 &ids.buffer,
-&buffer,
+&mut buffer,
 )
 .map_err(MetalError::from)?;
-Ok(Self::new(
+Ok(Self {
 buffer,
-device.clone(),
+device: device.clone(),
 dtype,
-))
+})
 }
 
 fn index_add(
@@ -771,8 +730,7 @@
 rhs_l: &Layout,
 ) -> Result<Self> {
 // Create descriptors
-use metal::mps::matrix::*;
-// let start = std::time::Instant::now();
 
 let (type_id, size) = match self.dtype {
 DType::F32 => (
@@ -786,6 +744,7 @@
 dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
 };
 
+let elem_count = b * m * n;
 let lhs_stride = lhs_l.stride();
 let rhs_stride = rhs_l.stride();
@@ -816,73 +775,115 @@ impl BackendStorage for MetalStorage {
 mnk: (m, n, k),
 })?
 };
 
+let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
+[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+[stride] => stride,
+[] => m * k,
+_ => Err(MetalError::MatMulNonContiguous {
+lhs_stride: lhs_stride.to_vec(),
+rhs_stride: rhs_stride.to_vec(),
+mnk: (m, n, k),
+})?,
+} as u64;
+let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
+[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+[stride] => stride,
+[] => n * k,
+_ => Err(MetalError::MatMulNonContiguous {
+lhs_stride: lhs_stride.to_vec(),
+rhs_stride: rhs_stride.to_vec(),
+mnk: (m, n, k),
+})?,
+} as u64;
+
 let b = b as NSUInteger;
 let m = m as NSUInteger;
 let n = n as NSUInteger;
 let k = k as NSUInteger;
 
-let left_matrix = self.matrix((b, m, k), transpose_left, size,
-lhs_l.start_offset() as NSUInteger * size, type_id)?;
-let right_matrix = rhs.matrix((b, k, n), transpose_right, size,
-rhs_l.start_offset() as NSUInteger * size, type_id)?;
-let (result_matrix, out_buffer) = self.device.new_matrix((b, m, n), size, type_id, self.dtype)?;
+let left_descriptor = if transpose_left {
+MatrixDescriptor::init_single(k, m, m * size, type_id)
+} else {
+MatrixDescriptor::init_single(m, k, k * size, type_id)
+};
+let right_descriptor = if transpose_right {
+MatrixDescriptor::init_single(n, k, k * size, type_id)
+} else {
+MatrixDescriptor::init_single(k, n, n * size, type_id)
+};
+let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
+
+let out_buffer = self.device.new_buffer(elem_count, self.dtype);
 let command_buffer = self.device.command_buffer();
+for bi in 0..b {
+// Create matrix objects
+let left_matrix = Matrix::init_with_buffer_descriptor(
+&self.buffer,
+(bi * stride_left + lhs_l.start_offset() as u64) * size,
+&left_descriptor,
+)
+.ok_or_else(|| {
+MetalError::from("Failed to create matrix multiplication kernel".to_string())
+})?;
+let right_matrix = Matrix::init_with_buffer_descriptor(
+&rhs.buffer,
+(bi * stride_right + rhs_l.start_offset() as u64) * size,
+&right_descriptor,
+)
+.ok_or_else(|| {
+MetalError::from("Failed to create matrix multiplication kernel".to_string())
+})?;
+let result_matrix = Matrix::init_with_buffer_descriptor(
+&out_buffer,
+bi * m * n * size,
+&result_descriptor,
+)
+.ok_or_else(|| {
+MetalError::from("Failed to create matrix multiplication kernel".to_string())
+})?;
 
 let alpha = 1.0f64;
 let beta = 0.0f64;
 // Create kernel
 let matrix_multiplication = MatrixMultiplication::init(
 &self.device,
 transpose_left,
 transpose_right,
 m,
 n,
 k,
 alpha,
 beta,
 )
 .ok_or_else(|| {
 MetalError::from("Failed to create matrix multiplication kernel".to_string())
 })?;
-// matrix_multiplication.set_batch_size(b);
-
-// Encode kernel to command buffer
-matrix_multiplication.encode_to_command_buffer(
-&command_buffer,
-&left_matrix,
-&right_matrix,
-&result_matrix,
-);
-command_buffer.set_label("matmul");
-drop(command_buffer);
-self.device.commit();
-Ok(Self::new(
-out_buffer,
-self.device.clone(),
-self.dtype(),
-))
+
+// Encode kernel to command buffer
+matrix_multiplication.encode_to_command_buffer(
+&command_buffer,
+&left_matrix,
+&right_matrix,
+&result_matrix,
+);
+}
+
+Ok(Self {
+buffer: out_buffer,
+device: self.device.clone(),
+dtype: self.dtype(),
+})
 }
 
 fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-let command_buffer = self.device.command_buffer();
-if src_l.is_contiguous(){
-command_buffer.set_label("copy_contiguous");
-let blit = command_buffer.new_blit_command_encoder();
-let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
-let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, self.buffer.length() - src_offset);
-blit.end_encoding();
-}else{
 let src_shape = src_l.shape();
 let el_count = src_shape.elem_count();
 if el_count == 0 {
 return Ok(());
 }
+let command_buffer = self.device.command_buffer();
 let kernel_name = match self.dtype {
 DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
 DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
@@ -899,58 +900,26 @@ impl BackendStorage for MetalStorage {
 &self.buffer,
 src_l.stride(),
 src_l.start_offset() * self.dtype.size_in_bytes(),
-&dst.buffer,
+&mut dst.buffer,
 dst_offset * dst.dtype.size_in_bytes(),
 )
 .map_err(MetalError::from)?;
-command_buffer.set_label("copy_strided");
-}
-drop(command_buffer);
-self.device.commit();
 Ok(())
 }
 }
 
 impl MetalStorage {
-pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
-let matrices = Arc::new(RwLock::new(HashMap::new()));
+pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
 Self {
 buffer,
 device,
 dtype,
-matrices
 }
 }
 
 pub fn buffer(&self) -> &Buffer {
 &self.buffer
 }
-
-fn matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), transpose: bool, size: NSUInteger, offset: NSUInteger, type_id: u32) -> Result<Matrix>{
-let key = (b, m, n, transpose, size, offset, type_id);
-let mut matrices = self.matrices.try_write().unwrap();
-if let Some(matrix) = matrices.get(&key){
-Ok(matrix.clone())
-}else{
-let descriptor = if transpose{
-MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
-} else {
-MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
-};
-let matrix = Matrix::init_with_buffer_descriptor(
-&self.buffer,
-offset,
-&descriptor,
-)
-.ok_or_else(|| {
-MetalError::from("Failed to create matrix multiplication kernel".to_string())
-})?;
-matrices.insert(key, matrix.clone());
-Ok(matrix)
-}
-}
 }
 
 impl BackendDevice for MetalDevice {
@@ -975,12 +944,14 @@ impl BackendDevice for MetalDevice {
 let command_queue = device.new_command_queue();
 let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
 let kernels = Arc::new(Kernels::new());
-let buffers = Arc::new(RwLock::new(HashMap::new()));
+let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
+let buffers = Arc::new(RwLock::new(FxHashMap::default()));
 Ok(Self {
 device,
 command_queue,
 command_buffer,
 buffers,
+queue,
 kernels,
 })
 }
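Note on the queue created above: in this pair of commits it is only exercised through the commented-out exec_async calls in unary_impl, so the serialization it would provide is not yet active. Below is a minimal sketch of what the dispatch crate's serial queue gives you, assuming dispatch 0.2's Queue::create / exec_async / exec_sync API (the same calls the diff references) and a platform with libdispatch, i.e. macOS; the queue label and variable names are invented here.

use dispatch::{Queue, QueueAttribute};
use std::sync::{Arc, Mutex};

fn main() {
    // A serial queue runs submitted closures one at a time, in submission order.
    let queue = Queue::create("co.example.encoder", QueueAttribute::Serial);
    let log = Arc::new(Mutex::new(Vec::new()));

    for i in 0..4 {
        let log = log.clone();
        // exec_async returns immediately; the closure runs later on the queue.
        queue.exec_async(move || {
            log.lock().unwrap().push(i);
        });
    }

    // exec_sync blocks until this closure, and everything queued before it, has run.
    queue.exec_sync(|| {});
    assert_eq!(*log.lock().unwrap(), vec![0, 1, 2, 3]);
}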
@@ -1001,7 +972,11 @@ impl BackendDevice for MetalDevice {
 fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
 let buffer = self.new_buffer(shape.elem_count(), dtype);
-Ok(MetalStorage::new(buffer, self.clone(), dtype))
+Ok(MetalStorage {
+buffer,
+device: self.clone(),
+dtype,
+})
 }
 
 fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -1020,11 +995,11 @@
 CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
 CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
 };
-Ok(Self::Storage::new(
-buffer.into(),
-self.clone(),
-storage.dtype(),
-))
+Ok(Self::Storage {
+buffer,
+device: self.clone(),
+dtype: storage.dtype(),
+})
 }
 
 fn rand_uniform(
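Note on the matmul encoding earlier in this file: each MPS MatrixDescriptor is built from a row stride in bytes (columns times element size), and each batch entry is addressed by a byte offset into the flat buffer — (bi * stride + start_offset) * size for the inputs, bi * m * n * size for the output. The following is a small worked example of that arithmetic for the contiguous f32 case; it is pure arithmetic with no Metal calls, and every name in it is local to the sketch.

// Byte layout used by the per-batch descriptors, spelled out for f32, no transpose.
fn main() {
    let (b, m, n, k) = (2usize, 4, 3, 5); // batch, rows(lhs), cols(rhs), inner dim
    let size = std::mem::size_of::<f32>(); // `size` in the diff: bytes per element

    // init_single(rows, cols, row_bytes, type_id): one row is `cols * size` bytes.
    let lhs_row_bytes = k * size; // m x k matrix
    let rhs_row_bytes = n * size; // k x n matrix
    let out_row_bytes = n * size; // m x n matrix

    // Offsets of batch `bi` inside each flat buffer, assuming start_offset == 0
    // and contiguous inputs (so stride_left = m * k and stride_right = k * n).
    for bi in 0..b {
        let lhs_off = bi * (m * k) * size;
        let rhs_off = bi * (k * n) * size;
        let out_off = bi * (m * n) * size;
        println!(
            "batch {bi}: lhs@{lhs_off} ({lhs_row_bytes} B/row), rhs@{rhs_off} ({rhs_row_bytes} B/row), out@{out_off} ({out_row_bytes} B/row)"
        );
    }
}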

View File

@@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
-metal = ["candle/metal", "candle-nn/metal"]
 
 [[example]]
 name = "llama_multiprocess"

View File

@@ -10,7 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../../metal-rs", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"

View File

@@ -298,8 +298,11 @@ pub fn call_unary_contiguous(
 kernel_name: unary::contiguous::Kernel,
 length: usize,
 input: &Buffer,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
+// println!("Kernel {:?}", kernel_name.0);
+// assert_eq!(input.length(), output.length());
 let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
 let encoder = command_buffer.new_compute_command_encoder();
 encoder.set_compute_pipeline_state(&pipeline);
@@ -320,7 +323,7 @@ pub fn call_unary_strided(
 input: &Buffer,
 strides: &[usize],
 offset: usize,
-output: &Buffer,
+output: &mut Buffer,
 output_offset: usize,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@@ -358,7 +361,7 @@ pub fn call_binary_contiguous(
 length: usize,
 left: &Buffer,
 right: &Buffer,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@@ -386,7 +389,7 @@ pub fn call_binary_strided(
 right_input: &Buffer,
 right_strides: &[usize],
 right_offset: usize,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@@ -425,7 +428,7 @@ pub fn call_cast_contiguous(
 kernel_name: &'static str,
 length: usize,
 input: &Buffer,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -450,7 +453,7 @@ pub fn call_cast_strided(
 input: &Buffer,
 input_strides: &[usize],
 input_offset: usize,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 // println!("Kernel {:?}", kernel_name.0);
 // assert_eq!(input.length(), output.length());
@@ -481,8 +484,7 @@ pub fn call_reduce_contiguous(
 length: usize,
 out_length: usize,
 input: &Buffer,
-input_offset: usize,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
 let elements_to_sum = length / out_length;
@@ -490,7 +492,7 @@ pub fn call_reduce_contiguous(
 let encoder = command_buffer.new_compute_command_encoder();
 encoder.set_compute_pipeline_state(&pipeline);
 
-set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
+set_params!(encoder, (length, elements_to_sum, input, output));
 
 let thread_group_count = MTLSize {
 width: out_length as u64,
@@ -523,7 +525,7 @@ pub fn call_last_softmax(
 length: usize,
 elements_to_sum: usize,
 input: &Buffer,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
 let encoder = command_buffer.new_compute_command_encoder();
@@ -564,7 +566,7 @@ pub fn call_affine(
 name: &'static str,
 size: usize,
 input: &Buffer,
-output: &Buffer,
+output: &mut Buffer,
 mul: f32,
 add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -590,7 +592,7 @@ pub fn call_affine_strided(
 input: &Buffer,
 input_stride: &[usize],
 input_offset: usize,
-output: &Buffer,
+output: &mut Buffer,
 mul: f32,
 add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -632,7 +634,7 @@ pub fn call_where_cond_strided(
 (left_stride, left_offset): (&[usize], usize),
 right: &Buffer,
 (right_stride, right_offset): (&[usize], usize),
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@@ -675,7 +677,7 @@ pub fn call_index_select(
 dim: usize,
 input: &Buffer,
 ids: &Buffer,
-output: &Buffer,
+output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
 let left_size: usize = shape[..dim].iter().product();
 let right_size: usize = shape[dim + 1..].iter().product();
@@ -750,7 +752,7 @@ mod tests {
 name,
 v.len(),
 &input,
-&output,
+&mut output,
 )
 .unwrap();
 command_buffer.commit();
@@ -775,7 +777,7 @@ mod tests {
 x.len(),
 &left,
 &right,
-&output,
+&mut output,
 )
 .unwrap();
 command_buffer.commit();
@@ -805,7 +807,7 @@ mod tests {
 &input,
 strides,
 offset,
-&output,
+&mut output,
 0,
 )
 .unwrap();
@@ -943,7 +945,7 @@ mod tests {
 name,
 v.len(),
 &input,
-&output,
+&mut output,
 )
 .unwrap();
 command_buffer.commit();
@@ -984,7 +986,7 @@ mod tests {
 "affine_float",
 size,
 &input,
-&output,
+&mut output,
 mul as f32,
 add as f32,
 )
@@ -1021,7 +1023,7 @@ mod tests {
 &input,
 strides,
 0,
-&output,
+&mut output,
 mul as f32,
 add as f32,
 )
@@ -1119,7 +1121,7 @@ mod tests {
 dim,
 &embeddings_buffer,
 &ids_buffer,
-&dst_buffer,
+&mut dst_buffer,
 )
 .unwrap();
@@ -1226,8 +1228,7 @@ mod tests {
 v.len(),
 out_length,
 &input,
-0,
-&output,
+&mut output,
 )
 .unwrap();
 command_buffer.commit();
@@ -1255,7 +1256,7 @@ mod tests {
 v.len(),
 last_dim,
 &input,
-&output,
+&mut output,
 )
 .unwrap();
 command_buffer.commit();
@@ -1355,7 +1356,7 @@ mod tests {
 (&left_stride, left_offset),
 &right,
 (&cond_stride, cond_offset),
-&output,
+&mut output,
 )
 .unwrap();
 command_buffer.commit();
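Note on the test changes above: every call site follows the same host-side shape — encode work into a command buffer, commit, wait_until_completed, then read a CPU-visible buffer. Below is a stripped-down sketch of that round trip using plain metal-rs only (a blit copy instead of the candle kernels). It uses Shared storage to keep the sketch simple, where the backend uses Private device buffers plus Managed staging; buffer and variable names are ours, and it needs macOS with a Metal device to run.

use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device found");
    let queue = device.new_command_queue();

    let src_data: Vec<f32> = (0..16).map(|v| v as f32).collect();
    let byte_len = (src_data.len() * std::mem::size_of::<f32>()) as u64;

    // Shared buffers are visible to both CPU and GPU, which keeps the example short.
    let src = device.new_buffer_with_data(
        src_data.as_ptr() as *const core::ffi::c_void,
        byte_len,
        MTLResourceOptions::StorageModeShared,
    );
    let dst = device.new_buffer(byte_len, MTLResourceOptions::StorageModeShared);

    // Encode, commit, wait — the same lifecycle as the tests above.
    let command_buffer = queue.new_command_buffer();
    let blit = command_buffer.new_blit_command_encoder();
    blit.copy_from_buffer(&src, 0, &dst, 0, byte_len);
    blit.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    // Read the result back on the CPU through the buffer's contents pointer.
    let out = unsafe {
        std::slice::from_raw_parts(dst.contents() as *const f32, src_data.len())
    };
    assert_eq!(out, &src_data[..]);
}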

View File

@@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
 return strided_i;
 }
 
-constant int THREADGROUP_SIZE = 1024;
+constant int THREADGROUP_SIZE = 256;
 
 # define REDUCE(FN, NAME, TYPENAME) \
 kernel void NAME( \
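Note on THREADGROUP_SIZE above: the constant bounds how many threads, and how many threadgroup-memory slots, one reduction group uses; lowering it from 1024 to 256 changes occupancy and shared-memory footprint, not the result. The sketch below is a CPU model of the strided-accumulate-then-tree-reduce pattern the REDUCE macro implements, written only to show what the constant controls — it is our reading of the kernel, not a copy of it, and the function name is invented.

// CPU model: `width` plays the role of THREADGROUP_SIZE (must be a power of two).
fn threadgroup_reduce(items: &[f32], width: usize) -> f32 {
    // Each "thread" first strides over the input and accumulates privately.
    let mut shared: Vec<f32> = (0..width)
        .map(|tid| items.iter().skip(tid).step_by(width).sum())
        .collect();
    // Then the classic shared-memory tree: stride = width/2, width/4, ..., 1.
    let mut s = width / 2;
    while s > 0 {
        for tid in 0..s {
            shared[tid] += shared[tid + s];
        }
        s /= 2;
    }
    shared[0]
}

fn main() {
    let items: Vec<f32> = (1..=1000).map(|v| v as f32).collect();
    // Shrinking the group from 1024 to 256 leaves the reduction result unchanged.
    assert_eq!(threadgroup_reduce(&items, 256), 500500.0);
    assert_eq!(threadgroup_reduce(&items, 1024), 500500.0);
}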

View File

@@ -19,7 +19,6 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -30,4 +29,3 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
-metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@@ -201,37 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
 };
 Ok((dst, layout.shape().clone()))
 }
-
-#[cfg(feature = "metal")]
-fn metal_fwd(
-&self,
-storage: &candle::MetalStorage,
-layout: &Layout,
-) -> Result<(candle::MetalStorage, Shape)> {
-use candle::backend::{BackendStorage};
-let device = storage.device();
-let command_buffer = device.command_buffer();
-let kernels = device.kernels();
-let name = "softmax_float";
-assert!(layout.is_contiguous());
-assert!(layout.start_offset() == 0);
-let last_dim = layout.dims()[layout.shape().rank() - 1];
-let elem_count = layout.shape().elem_count();
-let mut output = device.new_buffer(elem_count, storage.dtype());
-candle_metal_kernels::call_last_softmax(
-device.metal_device(),
-&command_buffer,
-&kernels,
-name,
-elem_count,
-last_dim,
-storage.buffer(),
-&mut output,
-)
-.unwrap();
-let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
-Ok((newstorage, layout.shape().clone()))
-}
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
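Note on the removed metal_fwd above: with candle-nn's metal feature and its candle-metal-kernels dependency dropped on this branch, SoftmaxLastDim keeps only the backend overrides still defined elsewhere in this file, so Metal tensors no longer have a dedicated path here. For reference, a small usage sketch of the public entry point on the CPU device follows; it is written against the candle API of this era (Tensor::new, Device::Cpu, to_vec2) using the external crate names candle_core and candle_nn, and should be treated as a sketch rather than a pinned example.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A 2x3 input; softmax is taken over the last dimension (each row).
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;

    let rows = ys.to_vec2::<f32>()?;
    for row in rows {
        let total: f32 = row.iter().sum();
        assert!((total - 1.0).abs() < 1e-5); // each row sums to 1 after softmax
    }
    Ok(())
}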