Compare commits

..

2 Commits

Author SHA1 Message Date
7e49e0af96 Tmp for allocator. 2023-11-16 12:50:41 +01:00
181d2299b2 TMp. 2023-11-16 11:41:06 +01:00
9 changed files with 233 additions and 285 deletions

View File

@ -61,7 +61,10 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
dispatch = "0.2.0"
rustc-hash = "1.1"
[profile.release-with-debug]
inherits = "release"

View File

@ -30,6 +30,8 @@ safetensors = { workspace = true }
thiserror = { workspace = true }
yoke = { workspace = true }
zip = { workspace = true }
dispatch = { workspace = true, optional = true }
rustc-hash = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -41,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]

View File

@ -7,9 +7,10 @@ use candle_metal_kernels::Kernels;
use half::f16;
use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use std::sync::{Arc, RwLock};
use std::collections::HashMap;
use rustc_hash::FxHashMap;
use dispatch::{Queue, QueueAttribute};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -38,8 +39,9 @@ pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
queue : Queue,
kernels: Arc<candle_metal_kernels::Kernels>,
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
}
impl std::fmt::Debug for MetalDevice {
@ -61,10 +63,6 @@ impl MetalDevice {
self.registry_id()
}
/// Returns a shared reference to the `device` field (the wrapped `metal::Device`).
pub fn metal_device(&self) -> &metal::Device {
&self.device
}
/// Returns a shared reference to the `command_queue` field.
pub fn command_queue(&self) -> &CommandQueue {
&self.command_queue
}
@ -73,28 +71,10 @@ impl MetalDevice {
self.command_buffer.try_read().unwrap()
}
/// Commits the current command buffer and installs a freshly created one in its place.
///
/// Nothing happens unless the current buffer's status is `NotEnqueued` or `Enqueued`.
///
/// # Panics
///
/// Panics when `try_write` on the `command_buffer` lock returns an error.
pub fn commit(&self) {
    let mut current = self.command_buffer.try_write().unwrap();
    let still_pending = matches!(
        current.status(),
        metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued
    );
    if still_pending {
        current.commit();
        // Swap in a new buffer taken from the same queue.
        *current = self.command_queue.new_command_buffer().to_owned();
    }
}
pub fn wait_until_completed(&self) {
let mut old = self.command_buffer.try_write().unwrap();
match old.status(){
metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
old.commit();
old.wait_until_completed();
}
_ => {}
}
old.commit();
old.wait_until_completed();
let command_buffer = self.command_queue.new_command_buffer().to_owned();
*old = command_buffer;
// self.command_buffer.replace_with(|_| command_buffer)
@ -108,72 +88,41 @@ impl MetalDevice {
&self.device
}
/// Returns a buffer with room for `element_count` elements of `dtype`.
///
/// The byte length is `element_count * dtype.size_in_bytes()`; the request is forwarded to
/// `_new_buffer` with `MTLResourceOptions::StorageModePrivate`.
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer> {
    let byte_len = element_count * dtype.size_in_bytes();
    self._new_buffer(byte_len as NSUInteger, MTLResourceOptions::StorageModePrivate)
}
fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer >{
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = element_count * dtype.size_in_bytes();
let mut buffers = self.buffers.try_write().unwrap();
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
let subbuffers = buffers.entry(size).or_insert(vec![]);
for sub in &mut *subbuffers{
if Arc::strong_count(sub) == 1{
if sub.retain_count() == 1{
return sub.clone();
// println!("{size } {:?}", sub.retain_count());
}
}
let new_buffer = self.device
.new_buffer(size as NSUInteger, option);
let new_buffer = Arc::new(new_buffer);
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
subbuffers.push(new_buffer.clone());
new_buffer
}
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
let tmp = self.device.new_buffer_with_data(
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
let option = metal::MTLResourceOptions::StorageModeManaged;
self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
core::mem::size_of_val(data) as NSUInteger,
metal::MTLResourceOptions::StorageModeManaged
);
let real = self._new_buffer(
core::mem::size_of_val(data) as NSUInteger,
metal::MTLResourceOptions::StorageModePrivate
);
{
let command = self.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.end_encoding();
}
real
}
pub fn new_matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), size: NSUInteger, type_id: u32, dtype:DType) -> Result<(Matrix, Arc<Buffer>)>{
let elem_count = (b * m * n ) as usize;
let out_buffer = self.new_buffer(elem_count, dtype);
let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
0,
&result_descriptor,
option,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
Ok((result_matrix, out_buffer))
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
buffer: Arc<metal::Buffer>,
matrices: Arc<RwLock<HashMap<(NSUInteger, NSUInteger, NSUInteger, bool, NSUInteger, NSUInteger, u32), Matrix>>>,
buffer: metal::Buffer,
device: MetalDevice,
dtype: DType,
}
@ -195,11 +144,15 @@ impl BackendStorage for MetalStorage {
fn to_cpu_storage(&self) -> Result<CpuStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length());
let command_buffer = self.device.command_buffer();
let blit = command_buffer.new_blit_command_encoder();
{
let command = self.device.command_buffer();
let blit = command.new_blit_command_encoder();
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
drop(command_buffer);
}
self.device.wait_until_completed();
match self.dtype {
@ -234,7 +187,7 @@ impl BackendStorage for MetalStorage {
let el = shape.elem_count();
let dtype = self.dtype;
let buffer = device.new_buffer(el, self.dtype);
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_buffer();
if layout.is_contiguous() && layout.start_offset() == 0 {
let name = match self.dtype {
@ -249,7 +202,7 @@ impl BackendStorage for MetalStorage {
name,
el,
&self.buffer,
&buffer,
&mut buffer,
mul as f32,
add as f32,
)
@ -269,17 +222,17 @@ impl BackendStorage for MetalStorage {
&self.buffer,
layout.stride(),
layout.start_offset() * dtype.size_in_bytes(),
&buffer,
&mut buffer,
mul as f32,
add as f32,
)
.unwrap();
}
Ok(Self::new(
Ok(Self {
buffer,
device.clone(),
device: device.clone(),
dtype,
))
})
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@ -293,9 +246,8 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
assert!(sum_dims.len() == 1);
assert!(sum_dims[0] == layout.shape().rank() - 1);
assert!(layout.stride()[sum_dims[0]] == 1);
// assert!(layout.is_contiguous());
// assert!(layout.start_offset() == 0);
assert!(layout.is_contiguous());
assert!(layout.start_offset() == 0);
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
@ -330,10 +282,7 @@ impl BackendStorage for MetalStorage {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let dtype = if return_index { DType::U32 } else { self.dtype };
if dtype == DType::U32{
todo!("Implement this");
}
let buffer = device.new_buffer(dst_el, dtype);
let mut buffer = device.new_buffer(dst_el, dtype);
let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_reduce_contiguous(
&device.device,
@ -343,16 +292,15 @@ impl BackendStorage for MetalStorage {
src_el,
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
Ok(Self {
buffer,
device,
dtype,
))
})
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -363,7 +311,7 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype);
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_buffer();
if layout.is_contiguous() {
let kernel_name = match (self.dtype, dtype) {
@ -379,7 +327,7 @@ impl BackendStorage for MetalStorage {
kernel_name,
el_count,
&self.buffer,
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
@ -398,26 +346,36 @@ impl BackendStorage for MetalStorage {
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
}
Ok(Self::new(
Ok(Self {
buffer,
device.clone(),
device: device.clone(),
dtype,
))
})
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_buffer();
let metal = self.device.clone();
let mut cloned = buffer.clone();
let inbuffer = self.buffer.clone();
let ldims = layout.dims().to_owned();
let lstride = layout.stride().to_owned();
let loffset = layout.start_offset() * dtype.size_in_bytes();
if layout.is_contiguous() && layout.start_offset() == 0 {
// self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
@ -455,11 +413,16 @@ impl BackendStorage for MetalStorage {
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&buffer,
&inbuffer,
&mut cloned,
)
.map_err(MetalError::from)?;
.unwrap();
// });
} else {
// self.device.queue.exec_async(move || {
let device = metal;
let command_buffer = device.command_buffer();
use candle_metal_kernels::unary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => strided::cos::FLOAT,
@ -495,23 +458,22 @@ impl BackendStorage for MetalStorage {
&command_buffer,
&device.kernels,
kernel_name,
layout.dims(),
&self.buffer,
layout.stride(),
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
&ldims,
&inbuffer,
&lstride,
loffset,
&mut cloned,
0,
)
.map_err(MetalError::from)?;
.unwrap();
// });
}
command_buffer.set_label("unary");
drop(command_buffer);
self.device.commit();
Ok(Self::new(
Ok(Self {
buffer,
device.clone(),
device: device.clone(),
dtype,
))
})
}
fn binary_impl<B: BinaryOpT>(
@ -524,7 +486,7 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let shape = lhs_l.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype);
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_buffer();
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
@ -558,7 +520,7 @@ impl BackendStorage for MetalStorage {
el_count,
&self.buffer,
&rhs.buffer,
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
@ -587,18 +549,15 @@ impl BackendStorage for MetalStorage {
&rhs.buffer,
rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("binary");
drop(command_buffer);
self.device.commit();
Ok(Self::new(
Ok(Self {
buffer,
device.clone(),
device: device.clone(),
dtype,
))
})
}
fn where_cond(
@ -614,7 +573,7 @@ impl BackendStorage for MetalStorage {
let dims = shape.dims();
let el = shape.elem_count();
let dtype = t.dtype;
let buffer = self.device.new_buffer(el, dtype);
let mut buffer = self.device.new_buffer(el, dtype);
let command_buffer = self.device.command_buffer();
candle_metal_kernels::call_where_cond_strided(
&device.device,
@ -631,14 +590,14 @@ impl BackendStorage for MetalStorage {
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
Ok(Self {
buffer,
device,
dtype,
))
})
}
fn conv1d(
@ -724,7 +683,7 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_el * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype);
let mut buffer = device.new_buffer(dst_el, dtype);
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
@ -741,14 +700,14 @@ impl BackendStorage for MetalStorage {
dim,
&self.buffer,
&ids.buffer,
&buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
Ok(Self::new(
Ok(Self {
buffer,
device.clone(),
device: device.clone(),
dtype,
))
})
}
fn index_add(
@ -771,8 +730,7 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
// let start = std::time::Instant::now();
use metal::mps::matrix::*;
let (type_id, size) = match self.dtype {
DType::F32 => (
@ -786,6 +744,7 @@ impl BackendStorage for MetalStorage {
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
};
let elem_count = b * m * n;
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
@ -816,73 +775,115 @@ impl BackendStorage for MetalStorage {
mnk: (m, n, k),
})?
};
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
[stride] => stride,
[] => m * k,
_ => Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
} as u64;
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
[stride] => stride,
[] => n * k,
_ => Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?,
} as u64;
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
let left_matrix = self.matrix((b, m, k), transpose_left, size,
lhs_l.start_offset() as NSUInteger * size, type_id)?;
let right_matrix = rhs.matrix((b, k, n), transpose_right, size,
rhs_l.start_offset() as NSUInteger * size, type_id)?;
let (result_matrix, out_buffer) = self.device.new_matrix((b, m, n), size, type_id, self.dtype)?;
let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id)
} else {
MatrixDescriptor::init_single(m, k, k * size, type_id)
};
let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id)
} else {
MatrixDescriptor::init_single(k, n, n * size, type_id)
};
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let command_buffer = self.device.command_buffer();
for bi in 0..b {
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(
&self.buffer,
(bi * stride_left + lhs_l.start_offset() as u64) * size,
&left_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(
&rhs.buffer,
(bi * stride_right + rhs_l.start_offset() as u64) * size,
&right_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let result_matrix = Matrix::init_with_buffer_descriptor(
&out_buffer,
bi * m * n * size,
&result_descriptor,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
// matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
}
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
command_buffer.set_label("matmul");
drop(command_buffer);
self.device.commit();
Ok(Self::new(
out_buffer,
self.device.clone(),
self.dtype(),
))
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer();
if src_l.is_contiguous(){
command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder();
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, self.buffer.length() - src_offset);
blit.end_encoding();
}else{
let src_shape = src_l.shape();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
let command_buffer = self.device.command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
@ -899,58 +900,26 @@ impl BackendStorage for MetalStorage {
&self.buffer,
src_l.stride(),
src_l.start_offset() * self.dtype.size_in_bytes(),
&dst.buffer,
&mut dst.buffer,
dst_offset * dst.dtype.size_in_bytes(),
)
.map_err(MetalError::from)?;
command_buffer.set_label("copy_strided");
}
drop(command_buffer);
self.device.commit();
Ok(())
}
}
impl MetalStorage {
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
let matrices = Arc::new(RwLock::new(HashMap::new()));
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
Self {
buffer,
device,
dtype,
matrices
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}
/// Returns the `Matrix` view over `self.buffer` for the given parameters, consulting the
/// `matrices` cache first and storing a newly built view on a miss.
///
/// The cache key is `(b, m, n, transpose, size, offset, type_id)`.
///
/// # Errors
///
/// Returns a `MetalError` when `Matrix::init_with_buffer_descriptor` yields `None`.
///
/// # Panics
///
/// Panics when `try_write` on the `matrices` lock returns an error.
fn matrix(
    &self,
    (b, m, n): (NSUInteger, NSUInteger, NSUInteger),
    transpose: bool,
    size: NSUInteger,
    offset: NSUInteger,
    type_id: u32,
) -> Result<Matrix> {
    let key = (b, m, n, transpose, size, offset, type_id);
    let mut cache = self.matrices.try_write().unwrap();

    if let Some(cached) = cache.get(&key) {
        return Ok(cached.clone());
    }

    // When transposed, the two leading dimensions are swapped and the fourth argument
    // follows the new second dimension.
    let (first, second) = if transpose { (n, m) } else { (m, n) };
    let descriptor =
        MatrixDescriptor::init_multiple(first, second, b, second * size, m * n * size, type_id);

    let created = Matrix::init_with_buffer_descriptor(&self.buffer, offset, &descriptor)
        .ok_or_else(|| {
            MetalError::from("Failed to create matrix multiplication kernel".to_string())
        })?;

    cache.insert(key, created.clone());
    Ok(created)
}
}
impl BackendDevice for MetalDevice {
@ -975,12 +944,14 @@ impl BackendDevice for MetalDevice {
let command_queue = device.new_command_queue();
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
let buffers = Arc::new(RwLock::new(FxHashMap::default()));
Ok(Self {
device,
command_queue,
command_buffer,
buffers,
queue,
kernels,
})
}
@ -1001,7 +972,11 @@ impl BackendDevice for MetalDevice {
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let buffer = self.new_buffer(shape.elem_count(), dtype);
Ok(MetalStorage::new(buffer, self.clone(), dtype))
Ok(MetalStorage {
buffer,
device: self.clone(),
dtype,
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@ -1020,11 +995,11 @@ impl BackendDevice for MetalDevice {
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
};
Ok(Self::Storage::new(
buffer.into(),
self.clone(),
storage.dtype(),
))
Ok(Self::Storage {
buffer,
device: self.clone(),
dtype: storage.dtype(),
})
}
fn rand_uniform(

View File

@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
[[example]]
name = "llama_multiprocess"

View File

@ -10,7 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -298,8 +298,11 @@ pub fn call_unary_contiguous(
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
@ -320,7 +323,7 @@ pub fn call_unary_strided(
input: &Buffer,
strides: &[usize],
offset: usize,
output: &Buffer,
output: &mut Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
@ -358,7 +361,7 @@ pub fn call_binary_contiguous(
length: usize,
left: &Buffer,
right: &Buffer,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@ -386,7 +389,7 @@ pub fn call_binary_strided(
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@ -425,7 +428,7 @@ pub fn call_cast_contiguous(
kernel_name: &'static str,
length: usize,
input: &Buffer,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@ -450,7 +453,7 @@ pub fn call_cast_strided(
input: &Buffer,
input_strides: &[usize],
input_offset: usize,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
@ -481,8 +484,7 @@ pub fn call_reduce_contiguous(
length: usize,
out_length: usize,
input: &Buffer,
input_offset: usize,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let elements_to_sum = length / out_length;
@ -490,7 +492,7 @@ pub fn call_reduce_contiguous(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
set_params!(encoder, (length, elements_to_sum, input, output));
let thread_group_count = MTLSize {
width: out_length as u64,
@ -523,7 +525,7 @@ pub fn call_last_softmax(
length: usize,
elements_to_sum: usize,
input: &Buffer,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
@ -564,7 +566,7 @@ pub fn call_affine(
name: &'static str,
size: usize,
input: &Buffer,
output: &Buffer,
output: &mut Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
@ -590,7 +592,7 @@ pub fn call_affine_strided(
input: &Buffer,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
output: &mut Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
@ -632,7 +634,7 @@ pub fn call_where_cond_strided(
(left_stride, left_offset): (&[usize], usize),
right: &Buffer,
(right_stride, right_offset): (&[usize], usize),
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@ -675,7 +677,7 @@ pub fn call_index_select(
dim: usize,
input: &Buffer,
ids: &Buffer,
output: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let left_size: usize = shape[..dim].iter().product();
let right_size: usize = shape[dim + 1..].iter().product();
@ -750,7 +752,7 @@ mod tests {
name,
v.len(),
&input,
&output,
&mut output,
)
.unwrap();
command_buffer.commit();
@ -775,7 +777,7 @@ mod tests {
x.len(),
&left,
&right,
&output,
&mut output,
)
.unwrap();
command_buffer.commit();
@ -805,7 +807,7 @@ mod tests {
&input,
strides,
offset,
&output,
&mut output,
0,
)
.unwrap();
@ -943,7 +945,7 @@ mod tests {
name,
v.len(),
&input,
&output,
&mut output,
)
.unwrap();
command_buffer.commit();
@ -984,7 +986,7 @@ mod tests {
"affine_float",
size,
&input,
&output,
&mut output,
mul as f32,
add as f32,
)
@ -1021,7 +1023,7 @@ mod tests {
&input,
strides,
0,
&output,
&mut output,
mul as f32,
add as f32,
)
@ -1119,7 +1121,7 @@ mod tests {
dim,
&embeddings_buffer,
&ids_buffer,
&dst_buffer,
&mut dst_buffer,
)
.unwrap();
@ -1226,8 +1228,7 @@ mod tests {
v.len(),
out_length,
&input,
0,
&output,
&mut output,
)
.unwrap();
command_buffer.commit();
@ -1255,7 +1256,7 @@ mod tests {
v.len(),
last_dim,
&input,
&output,
&mut output,
)
.unwrap();
command_buffer.commit();
@ -1355,7 +1356,7 @@ mod tests {
(&left_stride, left_offset),
&right,
(&cond_stride, cond_offset),
&output,
&mut output,
)
.unwrap();
command_buffer.commit();

View File

@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
constant int THREADGROUP_SIZE = 1024;
constant int THREADGROUP_SIZE = 256;
# define REDUCE(FN, NAME, TYPENAME) \
kernel void NAME( \

View File

@ -19,7 +19,6 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
anyhow = { workspace = true }
@ -30,4 +29,3 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -201,37 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
/// Metal forward pass: runs the `"softmax_float"` kernel over `storage` and returns the
/// result in a newly obtained buffer together with the unchanged shape.
///
/// # Panics
///
/// Panics if `layout` is not contiguous, if its start offset is non-zero, or if
/// `call_last_softmax` returns an error (the result is unwrapped).
#[cfg(feature = "metal")]
fn metal_fwd(
    &self,
    storage: &candle::MetalStorage,
    layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
    use candle::backend::BackendStorage;

    let device = storage.device();
    let command_buffer = device.command_buffer();
    let kernels = device.kernels();

    // The input buffer is handed to the kernel without strides or an offset,
    // so only dense layouts starting at element 0 are accepted.
    assert!(layout.is_contiguous());
    assert!(layout.start_offset() == 0);

    let shape = layout.shape();
    let last_dim = layout.dims()[shape.rank() - 1];
    let elem_count = shape.elem_count();
    let mut output = device.new_buffer(elem_count, storage.dtype());

    // NOTE(review): the kernel name is fixed while the output dtype follows `storage`;
    // confirm that non-f32 inputs are handled elsewhere.
    candle_metal_kernels::call_last_softmax(
        device.metal_device(),
        &command_buffer,
        &kernels,
        "softmax_float",
        elem_count,
        last_dim,
        storage.buffer(),
        &mut output,
    )
    .unwrap();

    let result = candle::MetalStorage::new(output, device.clone(), storage.dtype());
    Ok((result, shape.clone()))
}
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {