mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
2 Commits
metal4_arc
...
tmpm4
Author | SHA1 | Date | |
---|---|---|---|
7e49e0af96 | |||
181d2299b2 |
@ -61,7 +61,10 @@ tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
metal = { path = "../metal-rs", features = ["mps"] }
|
||||
dispatch = "0.2.0"
|
||||
rustc-hash = "1.1"
|
||||
|
||||
[profile.release-with-debug]
|
||||
inherits = "release"
|
||||
|
@ -30,6 +30,8 @@ safetensors = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
yoke = { workspace = true }
|
||||
zip = { workspace = true }
|
||||
dispatch = { workspace = true, optional = true }
|
||||
rustc-hash = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -41,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
||||
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
|
||||
|
@ -7,9 +7,10 @@ use candle_metal_kernels::Kernels;
|
||||
use half::f16;
|
||||
use metal;
|
||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::collections::HashMap;
|
||||
use rustc_hash::FxHashMap;
|
||||
use dispatch::{Queue, QueueAttribute};
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -38,8 +39,9 @@ pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
command_queue: metal::CommandQueue,
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
|
||||
queue : Queue,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MetalDevice {
|
||||
@ -61,10 +63,6 @@ impl MetalDevice {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
pub fn metal_device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
@ -73,28 +71,10 @@ impl MetalDevice {
|
||||
self.command_buffer.try_read().unwrap()
|
||||
}
|
||||
|
||||
pub fn commit(&self) {
|
||||
let mut old = self.command_buffer.try_write().unwrap();
|
||||
match old.status(){
|
||||
metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
|
||||
old.commit();
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*old = command_buffer;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
// self.command_buffer.replace_with(|_| command_buffer)
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) {
|
||||
let mut old = self.command_buffer.try_write().unwrap();
|
||||
match old.status(){
|
||||
metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
|
||||
old.commit();
|
||||
old.wait_until_completed();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
old.commit();
|
||||
old.wait_until_completed();
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*old = command_buffer;
|
||||
// self.command_buffer.replace_with(|_| command_buffer)
|
||||
@ -108,72 +88,41 @@ impl MetalDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer >{
|
||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
|
||||
}
|
||||
|
||||
fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer >{
|
||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||
let size = element_count * dtype.size_in_bytes();
|
||||
let mut buffers = self.buffers.try_write().unwrap();
|
||||
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
||||
let subbuffers = buffers.entry(size).or_insert(vec![]);
|
||||
|
||||
for sub in &mut *subbuffers{
|
||||
if Arc::strong_count(sub) == 1{
|
||||
if sub.retain_count() == 1{
|
||||
return sub.clone();
|
||||
// println!("{size } {:?}", sub.retain_count());
|
||||
}
|
||||
}
|
||||
let new_buffer = self.device
|
||||
.new_buffer(size as NSUInteger, option);
|
||||
let new_buffer = Arc::new(new_buffer);
|
||||
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
|
||||
subbuffers.push(new_buffer.clone());
|
||||
new_buffer
|
||||
}
|
||||
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
|
||||
self.device
|
||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
|
||||
let option = metal::MTLResourceOptions::StorageModeManaged;
|
||||
self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
core::mem::size_of_val(data) as NSUInteger,
|
||||
metal::MTLResourceOptions::StorageModeManaged
|
||||
);
|
||||
let real = self._new_buffer(
|
||||
core::mem::size_of_val(data) as NSUInteger,
|
||||
metal::MTLResourceOptions::StorageModePrivate
|
||||
);
|
||||
{
|
||||
let command = self.command_buffer();
|
||||
let blit = command.new_blit_command_encoder();
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.end_encoding();
|
||||
|
||||
}
|
||||
real
|
||||
}
|
||||
|
||||
pub fn new_matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), size: NSUInteger, type_id: u32, dtype:DType) -> Result<(Matrix, Arc<Buffer>)>{
|
||||
let elem_count = (b * m * n ) as usize;
|
||||
let out_buffer = self.new_buffer(elem_count, dtype);
|
||||
|
||||
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&out_buffer,
|
||||
0,
|
||||
&result_descriptor,
|
||||
option,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
Ok((result_matrix, out_buffer))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MetalStorage {
|
||||
buffer: Arc<metal::Buffer>,
|
||||
matrices: Arc<RwLock<HashMap<(NSUInteger, NSUInteger, NSUInteger, bool, NSUInteger, NSUInteger, u32), Matrix>>>,
|
||||
buffer: metal::Buffer,
|
||||
device: MetalDevice,
|
||||
dtype: DType,
|
||||
}
|
||||
@ -195,11 +144,15 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||
let command_buffer = self.device.command_buffer();
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
{
|
||||
let command = self.device.command_buffer();
|
||||
let blit = command.new_blit_command_encoder();
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
drop(command_buffer);
|
||||
|
||||
}
|
||||
|
||||
|
||||
self.device.wait_until_completed();
|
||||
|
||||
match self.dtype {
|
||||
@ -234,7 +187,7 @@ impl BackendStorage for MetalStorage {
|
||||
let el = shape.elem_count();
|
||||
let dtype = self.dtype;
|
||||
|
||||
let buffer = device.new_buffer(el, self.dtype);
|
||||
let mut buffer = device.new_buffer(el, self.dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let name = match self.dtype {
|
||||
@ -249,7 +202,7 @@ impl BackendStorage for MetalStorage {
|
||||
name,
|
||||
el,
|
||||
&self.buffer,
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
@ -269,17 +222,17 @@ impl BackendStorage for MetalStorage {
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
Ok(Self::new(
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device.clone(),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
@ -293,9 +246,8 @@ impl BackendStorage for MetalStorage {
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
assert!(sum_dims.len() == 1);
|
||||
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
||||
assert!(layout.stride()[sum_dims[0]] == 1);
|
||||
// assert!(layout.is_contiguous());
|
||||
// assert!(layout.start_offset() == 0);
|
||||
assert!(layout.is_contiguous());
|
||||
assert!(layout.start_offset() == 0);
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
@ -330,10 +282,7 @@ impl BackendStorage for MetalStorage {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
if dtype == DType::U32{
|
||||
todo!("Implement this");
|
||||
}
|
||||
let buffer = device.new_buffer(dst_el, dtype);
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
&device.device,
|
||||
@ -343,16 +292,15 @@ impl BackendStorage for MetalStorage {
|
||||
src_el,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::new(
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
@ -363,7 +311,7 @@ impl BackendStorage for MetalStorage {
|
||||
let device = self.device();
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
@ -379,7 +327,7 @@ impl BackendStorage for MetalStorage {
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
@ -398,26 +346,36 @@ impl BackendStorage for MetalStorage {
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
Ok(Self::new(
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device.clone(),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
|
||||
|
||||
let device = self.device();
|
||||
let dtype = self.dtype;
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
let metal = self.device.clone();
|
||||
let mut cloned = buffer.clone();
|
||||
let inbuffer = self.buffer.clone();
|
||||
let ldims = layout.dims().to_owned();
|
||||
let lstride = layout.stride().to_owned();
|
||||
let loffset = layout.start_offset() * dtype.size_in_bytes();
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
// self.device.queue.exec_async(move || {
|
||||
let device = metal;
|
||||
let command_buffer = device.command_buffer();
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
@ -455,11 +413,16 @@ impl BackendStorage for MetalStorage {
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&buffer,
|
||||
&inbuffer,
|
||||
&mut cloned,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
.unwrap();
|
||||
// });
|
||||
|
||||
} else {
|
||||
// self.device.queue.exec_async(move || {
|
||||
let device = metal;
|
||||
let command_buffer = device.command_buffer();
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||
@ -495,23 +458,22 @@ impl BackendStorage for MetalStorage {
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
&self.buffer,
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
&ldims,
|
||||
&inbuffer,
|
||||
&lstride,
|
||||
loffset,
|
||||
&mut cloned,
|
||||
0,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
.unwrap();
|
||||
// });
|
||||
}
|
||||
command_buffer.set_label("unary");
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(Self::new(
|
||||
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device.clone(),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
@ -524,7 +486,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
let shape = lhs_l.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype);
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_buffer();
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
@ -558,7 +520,7 @@ impl BackendStorage for MetalStorage {
|
||||
el_count,
|
||||
&self.buffer,
|
||||
&rhs.buffer,
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
@ -587,18 +549,15 @@ impl BackendStorage for MetalStorage {
|
||||
&rhs.buffer,
|
||||
rhs_l.stride(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("binary");
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(Self::new(
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device.clone(),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn where_cond(
|
||||
@ -614,7 +573,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let dtype = t.dtype;
|
||||
let buffer = self.device.new_buffer(el, dtype);
|
||||
let mut buffer = self.device.new_buffer(el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
&device.device,
|
||||
@ -631,14 +590,14 @@ impl BackendStorage for MetalStorage {
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -724,7 +683,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dst_el = ids_el * left_size * right_size;
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype);
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
@ -741,14 +700,14 @@ impl BackendStorage for MetalStorage {
|
||||
dim,
|
||||
&self.buffer,
|
||||
&ids.buffer,
|
||||
&buffer,
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device.clone(),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
@ -771,8 +730,7 @@ impl BackendStorage for MetalStorage {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// Create descriptors
|
||||
|
||||
// let start = std::time::Instant::now();
|
||||
use metal::mps::matrix::*;
|
||||
|
||||
let (type_id, size) = match self.dtype {
|
||||
DType::F32 => (
|
||||
@ -786,6 +744,7 @@ impl BackendStorage for MetalStorage {
|
||||
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
||||
};
|
||||
|
||||
let elem_count = b * m * n;
|
||||
|
||||
let lhs_stride = lhs_l.stride();
|
||||
let rhs_stride = rhs_l.stride();
|
||||
@ -816,73 +775,115 @@ impl BackendStorage for MetalStorage {
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
let n = n as NSUInteger;
|
||||
let k = k as NSUInteger;
|
||||
|
||||
let left_matrix = self.matrix((b, m, k), transpose_left, size,
|
||||
lhs_l.start_offset() as NSUInteger * size, type_id)?;
|
||||
let right_matrix = rhs.matrix((b, k, n), transpose_right, size,
|
||||
rhs_l.start_offset() as NSUInteger * size, type_id)?;
|
||||
let (result_matrix, out_buffer) = self.device.new_matrix((b, m, n), size, type_id, self.dtype)?;
|
||||
let left_descriptor = if transpose_left {
|
||||
MatrixDescriptor::init_single(k, m, m * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(m, k, k * size, type_id)
|
||||
};
|
||||
let right_descriptor = if transpose_right {
|
||||
MatrixDescriptor::init_single(n, k, k * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_single(k, n, n * size, type_id)
|
||||
};
|
||||
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
|
||||
|
||||
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
for bi in 0..b {
|
||||
// Create matrix objects
|
||||
let left_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&self.buffer,
|
||||
(bi * stride_left + lhs_l.start_offset() as u64) * size,
|
||||
&left_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let right_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&rhs.buffer,
|
||||
(bi * stride_right + rhs_l.start_offset() as u64) * size,
|
||||
&right_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(
|
||||
&out_buffer,
|
||||
bi * m * n * size,
|
||||
&result_descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
|
||||
// matrix_multiplication.set_batch_size(b);
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
}
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
&command_buffer,
|
||||
&left_matrix,
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
command_buffer.set_label("matmul");
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
|
||||
Ok(Self::new(
|
||||
out_buffer,
|
||||
self.device.clone(),
|
||||
self.dtype(),
|
||||
))
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
if src_l.is_contiguous(){
|
||||
command_buffer.set_label("copy_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, self.buffer.length() - src_offset);
|
||||
blit.end_encoding();
|
||||
|
||||
}else{
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let command_buffer = self.device.command_buffer();
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
@ -899,58 +900,26 @@ impl BackendStorage for MetalStorage {
|
||||
&self.buffer,
|
||||
src_l.stride(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&dst.buffer,
|
||||
&mut dst.buffer,
|
||||
dst_offset * dst.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy_strided");
|
||||
}
|
||||
drop(command_buffer);
|
||||
self.device.commit();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalStorage {
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
|
||||
let matrices = Arc::new(RwLock::new(HashMap::new()));
|
||||
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
|
||||
Self {
|
||||
buffer,
|
||||
device,
|
||||
dtype,
|
||||
matrices
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &Buffer {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
fn matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), transpose: bool, size: NSUInteger, offset: NSUInteger, type_id: u32) -> Result<Matrix>{
|
||||
|
||||
let key = (b, m, n, transpose, size, offset, type_id);
|
||||
|
||||
let mut matrices = self.matrices.try_write().unwrap();
|
||||
if let Some(matrix) = matrices.get(&key){
|
||||
Ok(matrix.clone())
|
||||
}else{
|
||||
let descriptor = if transpose{
|
||||
MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
|
||||
} else {
|
||||
MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
|
||||
};
|
||||
let matrix = Matrix::init_with_buffer_descriptor(
|
||||
&self.buffer,
|
||||
offset,
|
||||
&descriptor,
|
||||
)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
matrices.insert(key, matrix.clone());
|
||||
Ok(matrix)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
@ -975,12 +944,14 @@ impl BackendDevice for MetalDevice {
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
|
||||
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
|
||||
Ok(Self {
|
||||
device,
|
||||
command_queue,
|
||||
command_buffer,
|
||||
buffers,
|
||||
queue,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
@ -1001,7 +972,11 @@ impl BackendDevice for MetalDevice {
|
||||
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype);
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
Ok(MetalStorage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
@ -1020,11 +995,11 @@ impl BackendDevice for MetalDevice {
|
||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||
};
|
||||
Ok(Self::Storage::new(
|
||||
buffer.into(),
|
||||
self.clone(),
|
||||
storage.dtype(),
|
||||
))
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
dtype: storage.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(
|
||||
|
@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
|
@ -10,7 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
||||
metal = { path = "../../metal-rs", features = ["mps"] }
|
||||
once_cell = "1.18.0"
|
||||
thiserror = "1"
|
||||
tracing = "0.1.37"
|
||||
|
@ -298,8 +298,11 @@ pub fn call_unary_contiguous(
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -320,7 +323,7 @@ pub fn call_unary_strided(
|
||||
input: &Buffer,
|
||||
strides: &[usize],
|
||||
offset: usize,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
output_offset: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
@ -358,7 +361,7 @@ pub fn call_binary_contiguous(
|
||||
length: usize,
|
||||
left: &Buffer,
|
||||
right: &Buffer,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
@ -386,7 +389,7 @@ pub fn call_binary_strided(
|
||||
right_input: &Buffer,
|
||||
right_strides: &[usize],
|
||||
right_offset: usize,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||
|
||||
@ -425,7 +428,7 @@ pub fn call_cast_contiguous(
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
@ -450,7 +453,7 @@ pub fn call_cast_strided(
|
||||
input: &Buffer,
|
||||
input_strides: &[usize],
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
@ -481,8 +484,7 @@ pub fn call_reduce_contiguous(
|
||||
length: usize,
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
@ -490,7 +492,7 @@ pub fn call_reduce_contiguous(
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
|
||||
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -523,7 +525,7 @@ pub fn call_last_softmax(
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -564,7 +566,7 @@ pub fn call_affine(
|
||||
name: &'static str,
|
||||
size: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
@ -590,7 +592,7 @@ pub fn call_affine_strided(
|
||||
input: &Buffer,
|
||||
input_stride: &[usize],
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
@ -632,7 +634,7 @@ pub fn call_where_cond_strided(
|
||||
(left_stride, left_offset): (&[usize], usize),
|
||||
right: &Buffer,
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
@ -675,7 +677,7 @@ pub fn call_index_select(
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
ids: &Buffer,
|
||||
output: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
@ -750,7 +752,7 @@ mod tests {
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
&output,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -775,7 +777,7 @@ mod tests {
|
||||
x.len(),
|
||||
&left,
|
||||
&right,
|
||||
&output,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -805,7 +807,7 @@ mod tests {
|
||||
&input,
|
||||
strides,
|
||||
offset,
|
||||
&output,
|
||||
&mut output,
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
@ -943,7 +945,7 @@ mod tests {
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
&output,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -984,7 +986,7 @@ mod tests {
|
||||
"affine_float",
|
||||
size,
|
||||
&input,
|
||||
&output,
|
||||
&mut output,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
@ -1021,7 +1023,7 @@ mod tests {
|
||||
&input,
|
||||
strides,
|
||||
0,
|
||||
&output,
|
||||
&mut output,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
@ -1119,7 +1121,7 @@ mod tests {
|
||||
dim,
|
||||
&embeddings_buffer,
|
||||
&ids_buffer,
|
||||
&dst_buffer,
|
||||
&mut dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1226,8 +1228,7 @@ mod tests {
|
||||
v.len(),
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
&output,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -1255,7 +1256,7 @@ mod tests {
|
||||
v.len(),
|
||||
last_dim,
|
||||
&input,
|
||||
&output,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -1355,7 +1356,7 @@ mod tests {
|
||||
(&left_stride, left_offset),
|
||||
&right,
|
||||
(&cond_stride, cond_offset),
|
||||
&output,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
|
@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 1024;
|
||||
constant int THREADGROUP_SIZE = 256;
|
||||
|
||||
# define REDUCE(FN, NAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
|
@ -19,7 +19,6 @@ num-traits = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -30,4 +29,3 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
metal = ["candle/metal", "dep:candle-metal-kernels"]
|
||||
|
@ -201,37 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &candle::MetalStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::{BackendStorage};
|
||||
let device = storage.device();
|
||||
let command_buffer = device.command_buffer();
|
||||
let kernels = device.kernels();
|
||||
let name = "softmax_float";
|
||||
assert!(layout.is_contiguous());
|
||||
assert!(layout.start_offset() == 0);
|
||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let mut output = device.new_buffer(elem_count, storage.dtype());
|
||||
candle_metal_kernels::call_last_softmax(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
elem_count,
|
||||
last_dim,
|
||||
storage.buffer(),
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
||||
Ok((newstorage, layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user