mirror of https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00

Compare commits: metal4_arc ... tmpm4 (2 commits)

Commits: 7e49e0af96, 181d2299b2
@@ -61,7 +61,10 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+#metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../metal-rs", features = ["mps"] }
+dispatch = "0.2.0"
+rustc-hash = "1.1"
 
 [profile.release-with-debug]
 inherits = "release"
@@ -30,6 +30,8 @@ safetensors = { workspace = true }
 thiserror = { workspace = true }
 yoke = { workspace = true }
 zip = { workspace = true }
+dispatch = { workspace = true, optional = true }
+rustc-hash = { workspace = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -41,4 +43,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal", "dep:candle-metal-kernels"]
+metal = ["dep:metal", "dep:candle-metal-kernels", "dep:dispatch"]
@@ -7,9 +7,10 @@ use candle_metal_kernels::Kernels;
 use half::f16;
 use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
-use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 use std::sync::{Arc, RwLock};
 use std::collections::HashMap;
+use rustc_hash::FxHashMap;
+use dispatch::{Queue, QueueAttribute};
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -38,8 +39,9 @@ pub struct MetalDevice {
     device: metal::Device,
     command_queue: metal::CommandQueue,
     command_buffer: Arc<RwLock<metal::CommandBuffer>>,
+    buffers: Arc<RwLock<FxHashMap<usize, Vec<Buffer>>>>,
+    queue : Queue,
     kernels: Arc<candle_metal_kernels::Kernels>,
-    buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -61,10 +63,6 @@ impl MetalDevice {
         self.registry_id()
     }
 
-    pub fn metal_device(&self) -> &metal::Device {
-        &self.device
-    }
-
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
@@ -73,28 +71,10 @@ impl MetalDevice {
         self.command_buffer.try_read().unwrap()
     }
 
-    pub fn commit(&self) {
-        let mut old = self.command_buffer.try_write().unwrap();
-        match old.status(){
-            metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
-                old.commit();
-                let command_buffer = self.command_queue.new_command_buffer().to_owned();
-                *old = command_buffer;
-            }
-            _ => {}
-        }
-        // self.command_buffer.replace_with(|_| command_buffer)
-    }
-
     pub fn wait_until_completed(&self) {
         let mut old = self.command_buffer.try_write().unwrap();
-        match old.status(){
-            metal::MTLCommandBufferStatus::NotEnqueued | metal::MTLCommandBufferStatus::Enqueued => {
-                old.commit();
-                old.wait_until_completed();
-            }
-            _ => {}
-        }
+        old.commit();
+        old.wait_until_completed();
         let command_buffer = self.command_queue.new_command_buffer().to_owned();
         *old = command_buffer;
         // self.command_buffer.replace_with(|_| command_buffer)
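Note: both versions of this code rotate a single shared command buffer: the buffer behind the RwLock is committed (and, in wait_until_completed, blocked on) and then swapped for a fresh one from the command queue so later operations have something to encode into. A minimal sketch of that rotation, assuming hypothetical FakeQueue/FakeCommandBuffer stand-ins rather than the real metal-rs types:

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};

// Hypothetical stand-ins for the metal-rs CommandQueue / CommandBuffer types.
#[derive(Clone, Debug)]
struct FakeCommandBuffer { id: usize, committed: bool }
struct FakeQueue { next_id: AtomicUsize }

impl FakeQueue {
    fn new_command_buffer(&self) -> FakeCommandBuffer {
        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
        FakeCommandBuffer { id, committed: false }
    }
}

struct Device {
    queue: FakeQueue,
    command_buffer: Arc<RwLock<FakeCommandBuffer>>,
}

impl Device {
    // Mirrors the new-side wait_until_completed: commit the shared buffer,
    // (block on the GPU in the real code), then install a fresh buffer.
    fn wait_until_completed(&self) {
        let mut old = self.command_buffer.write().unwrap();
        old.committed = true; // old.commit()
        // old.wait_until_completed() would block here on a real command buffer.
        *old = self.queue.new_command_buffer();
    }
}

fn main() {
    let device = Device {
        queue: FakeQueue { next_id: AtomicUsize::new(1) },
        command_buffer: Arc::new(RwLock::new(FakeCommandBuffer { id: 0, committed: false })),
    };
    device.wait_until_completed();
    println!("{:?}", device.command_buffer.read().unwrap());
}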
@@ -108,72 +88,41 @@ impl MetalDevice {
         &self.device
     }
 
-    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer >{
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self._new_buffer(size, MTLResourceOptions::StorageModePrivate)
-    }
-
-    fn _new_buffer(&self, size: NSUInteger, option: MTLResourceOptions) -> Arc<Buffer >{
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+        let size = element_count * dtype.size_in_bytes();
         let mut buffers = self.buffers.try_write().unwrap();
-        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
+        let subbuffers = buffers.entry(size).or_insert(vec![]);
 
         for sub in &mut *subbuffers{
-            if Arc::strong_count(sub) == 1{
+            if sub.retain_count() == 1{
                 return sub.clone();
+                // println!("{size } {:?}", sub.retain_count());
             }
         }
         let new_buffer = self.device
-            .new_buffer(size as NSUInteger, option);
-        let new_buffer = Arc::new(new_buffer);
+            .new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
         subbuffers.push(new_buffer.clone());
         new_buffer
     }
 
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
-        self._new_buffer(size, MTLResourceOptions::StorageModeManaged)
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
+        self.device
+            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
     }
 
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
-        let tmp = self.device.new_buffer_with_data(
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Buffer {
+        let option = metal::MTLResourceOptions::StorageModeManaged;
+        self.device.new_buffer_with_data(
             data.as_ptr() as *const core::ffi::c_void,
             core::mem::size_of_val(data) as NSUInteger,
-            metal::MTLResourceOptions::StorageModeManaged
-        );
-        let real = self._new_buffer(
-            core::mem::size_of_val(data) as NSUInteger,
-            metal::MTLResourceOptions::StorageModePrivate
-        );
-        {
-            let command = self.command_buffer();
-            let blit = command.new_blit_command_encoder();
-            blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-            blit.end_encoding();
-
-        }
-        real
-    }
-
-    pub fn new_matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), size: NSUInteger, type_id: u32, dtype:DType) -> Result<(Matrix, Arc<Buffer>)>{
-        let elem_count = (b * m * n ) as usize;
-        let out_buffer = self.new_buffer(elem_count, dtype);
-
-        let result_descriptor = MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
-        let result_matrix = Matrix::init_with_buffer_descriptor(
-            &out_buffer,
-            0,
-            &result_descriptor,
+            option,
         )
-        .ok_or_else(|| {
-            MetalError::from("Failed to create matrix multiplication kernel".to_string())
-        })?;
-        Ok((result_matrix, out_buffer))
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
-    buffer: Arc<metal::Buffer>,
-    matrices: Arc<RwLock<HashMap<(NSUInteger, NSUInteger, NSUInteger, bool, NSUInteger, NSUInteger, u32), Matrix>>>,
+    buffer: metal::Buffer,
     device: MetalDevice,
     dtype: DType,
 }
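Note: the two sides of this hunk pool idle GPU buffers differently. The left keys the free-list by (size, resource options) and wraps each buffer in an Arc, treating Arc::strong_count == 1 as "only the pool still holds it, so it can be handed out again"; the right keys an FxHashMap by size alone and asks Metal's retain_count instead. A schematic of the Arc-count variant, with FakeBuffer/Options as hypothetical stand-ins for metal::Buffer and MTLResourceOptions (not the metal-rs API):

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Hypothetical stand-in for a GPU buffer; only the pooling logic is shown.
#[derive(Debug)]
struct FakeBuffer { size: usize }

type Options = u64;

#[derive(Default)]
struct BufferPool {
    // Free-list keyed by (size, options), mirroring the Arc-based side of the diff.
    buffers: RwLock<HashMap<(usize, Options), Vec<Arc<FakeBuffer>>>>,
}

impl BufferPool {
    fn get(&self, size: usize, options: Options) -> Arc<FakeBuffer> {
        let mut buffers = self.buffers.write().unwrap();
        let subbuffers = buffers.entry((size, options)).or_insert_with(Vec::new);
        for sub in subbuffers.iter() {
            // strong_count == 1 means only the pool holds the buffer,
            // so no tensor is still using it and it can be reused.
            if Arc::strong_count(sub) == 1 {
                return sub.clone();
            }
        }
        let new_buffer = Arc::new(FakeBuffer { size });
        subbuffers.push(new_buffer.clone());
        new_buffer
    }
}

fn main() {
    let pool = BufferPool::default();
    let a = pool.get(1024, 0);
    drop(a); // dropping the handle implicitly returns the buffer to the pool
    let b = pool.get(1024, 0); // reuses the same allocation
    println!("reused buffer of {} bytes", b.size);
}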
@@ -195,11 +144,15 @@ impl BackendStorage for MetalStorage {
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         let buffer = self.device.new_buffer_managed(self.buffer.length());
-        let command_buffer = self.device.command_buffer();
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-        blit.end_encoding();
-        drop(command_buffer);
+        {
+            let command = self.device.command_buffer();
+            let blit = command.new_blit_command_encoder();
+            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+            blit.end_encoding();
+        }
+
 
         self.device.wait_until_completed();
 
         match self.dtype {
@@ -234,7 +187,7 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype);
+        let mut buffer = device.new_buffer(el, self.dtype);
         let command_buffer = self.device.command_buffer();
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
@@ -249,7 +202,7 @@ impl BackendStorage for MetalStorage {
                 name,
                 el,
                 &self.buffer,
-                &buffer,
+                &mut buffer,
                 mul as f32,
                 add as f32,
             )
@@ -269,17 +222,17 @@ impl BackendStorage for MetalStorage {
                 &self.buffer,
                 layout.stride(),
                 layout.start_offset() * dtype.size_in_bytes(),
-                &buffer,
+                &mut buffer,
                 mul as f32,
                 add as f32,
             )
             .unwrap();
         }
-        Ok(Self::new(
+        Ok(Self {
             buffer,
-            device.clone(),
+            device: device.clone(),
             dtype,
-        ))
+        })
     }
 
     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -293,9 +246,8 @@ impl BackendStorage for MetalStorage {
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         assert!(sum_dims.len() == 1);
         assert!(sum_dims[0] == layout.shape().rank() - 1);
-        assert!(layout.stride()[sum_dims[0]] == 1);
-        // assert!(layout.is_contiguous());
-        // assert!(layout.start_offset() == 0);
+        assert!(layout.is_contiguous());
+        assert!(layout.start_offset() == 0);
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
@@ -330,10 +282,7 @@ impl BackendStorage for MetalStorage {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
-        if dtype == DType::U32{
-            todo!("Implement this");
-        }
-        let buffer = device.new_buffer(dst_el, dtype);
+        let mut buffer = device.new_buffer(dst_el, dtype);
         let command_buffer = self.device.command_buffer();
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
@@ -343,16 +292,15 @@ impl BackendStorage for MetalStorage {
             src_el,
             dst_el,
             &self.buffer,
-            layout.start_offset() * self.dtype.size_in_bytes(),
-            &buffer,
+            &mut buffer,
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::new(
+        Ok(Self {
             buffer,
             device,
             dtype,
-        ))
+        })
     }
 
     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -363,7 +311,7 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype);
+        let mut buffer = device.new_buffer(el_count, dtype);
         let command_buffer = device.command_buffer();
         if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
@@ -379,7 +327,7 @@ impl BackendStorage for MetalStorage {
                 kernel_name,
                 el_count,
                 &self.buffer,
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         } else {
@@ -398,26 +346,36 @@ impl BackendStorage for MetalStorage {
                 &self.buffer,
                 layout.stride(),
                 layout.start_offset() * self.dtype.size_in_bytes(),
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         }
 
-        Ok(Self::new(
+        Ok(Self {
             buffer,
-            device.clone(),
+            device: device.clone(),
             dtype,
-        ))
+        })
     }
 
     fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
 
+
         let device = self.device();
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
         let buffer = device.new_buffer(el_count, dtype);
-        let command_buffer = device.command_buffer();
+        let metal = self.device.clone();
+        let mut cloned = buffer.clone();
+        let inbuffer = self.buffer.clone();
+        let ldims = layout.dims().to_owned();
+        let lstride = layout.stride().to_owned();
+        let loffset = layout.start_offset() * dtype.size_in_bytes();
         if layout.is_contiguous() && layout.start_offset() == 0 {
+            // self.device.queue.exec_async(move || {
+            let device = metal;
+            let command_buffer = device.command_buffer();
             use candle_metal_kernels::unary::contiguous;
+
             let kernel_name = match (B::KERNEL, dtype) {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
&inbuffer,
|
||||||
&buffer,
|
&mut cloned,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.unwrap();
|
||||||
|
// });
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
// self.device.queue.exec_async(move || {
|
||||||
|
let device = metal;
|
||||||
|
let command_buffer = device.command_buffer();
|
||||||
use candle_metal_kernels::unary::strided;
|
use candle_metal_kernels::unary::strided;
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||||
@@ -495,23 +458,22 @@ impl BackendStorage for MetalStorage {
                 &command_buffer,
                 &device.kernels,
                 kernel_name,
-                layout.dims(),
-                &self.buffer,
-                layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
-                &buffer,
+                &ldims,
+                &inbuffer,
+                &lstride,
+                loffset,
+                &mut cloned,
                 0,
             )
-            .map_err(MetalError::from)?;
+            .unwrap();
+            // });
         }
-        command_buffer.set_label("unary");
-        drop(command_buffer);
-        self.device.commit();
-        Ok(Self::new(
+        Ok(Self {
             buffer,
-            device.clone(),
+            device: device.clone(),
             dtype,
-        ))
+        })
     }
 
     fn binary_impl<B: BinaryOpT>(
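Note: the cloned values introduced in unary_impl (cloned, inbuffer, ldims, lstride, loffset) exist so the kernel-encoding closure could own everything it needs and be handed to the serial dispatch queue added to MetalDevice; the queue.exec_async calls are still commented out in this commit. The sketch below shows the same ownership pattern with a plain thread standing in for the Grand Central Dispatch queue; encode_kernel is a hypothetical placeholder, not a candle API:

use std::thread;

// Hypothetical stand-in for encoding a unary kernel; in the diff this would be
// candle_metal_kernels::call_unary_contiguous / call_unary_strided.
fn encode_kernel(dims: &[usize], stride: &[usize], offset: usize, out: &mut Vec<f32>) {
    out.resize(dims.iter().product(), 0.0);
    let _ = (stride, offset);
}

fn main() {
    // Owned copies, so the closure is 'static + Send — the same reason the diff
    // clones the buffer, dims, stride and offset before the commented-out
    // queue.exec_async call.
    let ldims = vec![2usize, 3];
    let lstride = vec![3usize, 1];
    let loffset = 0usize;
    let mut cloned: Vec<f32> = Vec::new();

    let handle = thread::spawn(move || {
        encode_kernel(&ldims, &lstride, loffset, &mut cloned);
        cloned.len()
    });
    println!("encoded {} elements on the worker", handle.join().unwrap());
}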
@@ -524,7 +486,7 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype);
+        let mut buffer = device.new_buffer(el_count, dtype);
         let command_buffer = device.command_buffer();
         if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
             && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
@@ -558,7 +520,7 @@ impl BackendStorage for MetalStorage {
                 el_count,
                 &self.buffer,
                 &rhs.buffer,
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         } else {
@@ -587,18 +549,15 @@ impl BackendStorage for MetalStorage {
                 &rhs.buffer,
                 rhs_l.stride(),
                 rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
-                &buffer,
+                &mut buffer,
             )
             .map_err(MetalError::from)?;
         }
-        command_buffer.set_label("binary");
-        drop(command_buffer);
-        self.device.commit();
-        Ok(Self::new(
+        Ok(Self {
             buffer,
-            device.clone(),
+            device: device.clone(),
             dtype,
-        ))
+        })
     }
 
     fn where_cond(
@@ -614,7 +573,7 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype);
+        let mut buffer = self.device.new_buffer(el, dtype);
         let command_buffer = self.device.command_buffer();
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
@@ -631,14 +590,14 @@ impl BackendStorage for MetalStorage {
             (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
             &f.buffer,
             (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
-            &buffer,
+            &mut buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(
+        Ok(Self {
             buffer,
             device,
             dtype,
-        ))
+        })
     }
 
     fn conv1d(
@@ -724,7 +683,7 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype);
+        let mut buffer = device.new_buffer(dst_el, dtype);
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
             (DType::U32, DType::F16) => "is_u32_f16",
@@ -741,14 +700,14 @@ impl BackendStorage for MetalStorage {
             dim,
             &self.buffer,
             &ids.buffer,
-            &buffer,
+            &mut buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(
+        Ok(Self {
             buffer,
-            device.clone(),
+            device: device.clone(),
             dtype,
-        ))
+        })
     }
 
     fn index_add(
@@ -771,8 +730,7 @@ impl BackendStorage for MetalStorage {
         rhs_l: &Layout,
     ) -> Result<Self> {
         // Create descriptors
-        // let start = std::time::Instant::now();
-
+        use metal::mps::matrix::*;
 
         let (type_id, size) = match self.dtype {
             DType::F32 => (
@@ -786,6 +744,7 @@ impl BackendStorage for MetalStorage {
             dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
         };
 
+        let elem_count = b * m * n;
 
         let lhs_stride = lhs_l.stride();
         let rhs_stride = rhs_l.stride();
@@ -816,73 +775,115 @@ impl BackendStorage for MetalStorage {
                 mnk: (m, n, k),
             })?
         };
+        let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => m * k,
+            _ => Err(MetalError::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?,
+        } as u64;
+        let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+            [stride] => stride,
+            [] => n * k,
+            _ => Err(MetalError::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?,
+        } as u64;
 
         let b = b as NSUInteger;
         let m = m as NSUInteger;
         let n = n as NSUInteger;
         let k = k as NSUInteger;
 
-        let left_matrix = self.matrix((b, m, k), transpose_left, size,
-            lhs_l.start_offset() as NSUInteger * size, type_id)?;
-        let right_matrix = rhs.matrix((b, k, n), transpose_right, size,
-            rhs_l.start_offset() as NSUInteger * size, type_id)?;
-        let (result_matrix, out_buffer) = self.device.new_matrix((b, m, n), size, type_id, self.dtype)?;
+        let left_descriptor = if transpose_left {
+            MatrixDescriptor::init_single(k, m, m * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(m, k, k * size, type_id)
+        };
+        let right_descriptor = if transpose_right {
+            MatrixDescriptor::init_single(n, k, k * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(k, n, n * size, type_id)
+        };
+        let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
+
+        let out_buffer = self.device.new_buffer(elem_count, self.dtype);
 
         let command_buffer = self.device.command_buffer();
+        for bi in 0..b {
+            // Create matrix objects
+            let left_matrix = Matrix::init_with_buffer_descriptor(
+                &self.buffer,
+                (bi * stride_left + lhs_l.start_offset() as u64) * size,
+                &left_descriptor,
+            )
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+            let right_matrix = Matrix::init_with_buffer_descriptor(
+                &rhs.buffer,
+                (bi * stride_right + rhs_l.start_offset() as u64) * size,
+                &right_descriptor,
+            )
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+            let result_matrix = Matrix::init_with_buffer_descriptor(
+                &out_buffer,
+                bi * m * n * size,
+                &result_descriptor,
+            )
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
 
         let alpha = 1.0f64;
         let beta = 0.0f64;
         // Create kernel
         let matrix_multiplication = MatrixMultiplication::init(
             &self.device,
             transpose_left,
             transpose_right,
             m,
             n,
             k,
             alpha,
             beta,
         )
         .ok_or_else(|| {
             MetalError::from("Failed to create matrix multiplication kernel".to_string())
         })?;
 
-        // matrix_multiplication.set_batch_size(b);
+            // Encode kernel to command buffer
+            matrix_multiplication.encode_to_command_buffer(
+                &command_buffer,
+                &left_matrix,
+                &right_matrix,
+                &result_matrix,
+            );
+        }
 
-        // Encode kernel to command buffer
-        matrix_multiplication.encode_to_command_buffer(
-            &command_buffer,
-            &left_matrix,
-            &right_matrix,
-            &result_matrix,
-        );
-        command_buffer.set_label("matmul");
-        drop(command_buffer);
-        self.device.commit();
-
-        Ok(Self::new(
-            out_buffer,
-            self.device.clone(),
-            self.dtype(),
-        ))
+        Ok(Self {
+            buffer: out_buffer,
+            device: self.device.clone(),
+            dtype: self.dtype(),
+        })
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let command_buffer = self.device.command_buffer();
-        if src_l.is_contiguous(){
-            command_buffer.set_label("copy_contiguous");
-            let blit = command_buffer.new_blit_command_encoder();
-            let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
-            let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, self.buffer.length() - src_offset);
-            blit.end_encoding();
-
-        }else{
         let src_shape = src_l.shape();
         let el_count = src_shape.elem_count();
         if el_count == 0 {
             return Ok(());
         }
+        let command_buffer = self.device.command_buffer();
         let kernel_name = match self.dtype {
             DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
             DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
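Note: the stride_left/stride_right matches above recover the distance between consecutive matrices of a batched operand from its layout strides: a leading [batch_stride] prefix is used directly, an empty prefix (plain 2-D matmul) falls back to rows * cols, and anything else is rejected as non-contiguous. A standalone sketch of that logic, assuming usize strides as in the layout code:

// Standalone version of the batch-stride extraction in this hunk: given the
// strides of a (possibly batched) matrix, find the stride between consecutive
// matrices in the batch. Mirrors the match on lhs_stride[..len - 2].
fn batch_stride(stride: &[usize], dims: &[usize], rows: usize, cols: usize) -> Option<usize> {
    match stride[..stride.len().saturating_sub(2)] {
        [s1, st] if s1 == st * dims[1] => Some(st),
        [st] => Some(st),
        [] => Some(rows * cols),
        _ => None, // non-contiguous batch: the diff returns MatMulNonContiguous here
    }
}

fn main() {
    // A contiguous [b, m, k] = [4, 2, 3] tensor has strides [6, 3, 1];
    // consecutive matrices are therefore 6 elements apart.
    let dims = [4usize, 2, 3];
    let strides = [6usize, 3, 1];
    assert_eq!(batch_stride(&strides, &dims, 2, 3), Some(6));
    // A plain 2-D matrix has an empty batch prefix and falls back to m * k.
    assert_eq!(batch_stride(&[3, 1], &[2, 3], 2, 3), Some(6));
    println!("batch stride checks passed");
}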
@@ -899,58 +900,26 @@ impl BackendStorage for MetalStorage {
             &self.buffer,
             src_l.stride(),
             src_l.start_offset() * self.dtype.size_in_bytes(),
-            &dst.buffer,
+            &mut dst.buffer,
             dst_offset * dst.dtype.size_in_bytes(),
         )
         .map_err(MetalError::from)?;
-        command_buffer.set_label("copy_strided");
-        }
-        drop(command_buffer);
-        self.device.commit();
         Ok(())
     }
 }
 
 impl MetalStorage {
-    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
-        let matrices = Arc::new(RwLock::new(HashMap::new()));
+    pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
         Self {
             buffer,
             device,
             dtype,
-            matrices
         }
     }
 
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
     }
-
-    fn matrix(&self, (b, m, n): (NSUInteger, NSUInteger, NSUInteger), transpose: bool, size: NSUInteger, offset: NSUInteger, type_id: u32) -> Result<Matrix>{
-
-        let key = (b, m, n, transpose, size, offset, type_id);
-
-        let mut matrices = self.matrices.try_write().unwrap();
-        if let Some(matrix) = matrices.get(&key){
-            Ok(matrix.clone())
-        }else{
-            let descriptor = if transpose{
-                MatrixDescriptor::init_multiple(n, m, b, m * size, m * n * size, type_id)
-            } else {
-                MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id)
-            };
-            let matrix = Matrix::init_with_buffer_descriptor(
-                &self.buffer,
-                offset,
-                &descriptor,
-            )
-            .ok_or_else(|| {
-                MetalError::from("Failed to create matrix multiplication kernel".to_string())
-            })?;
-            matrices.insert(key, matrix.clone());
-            Ok(matrix)
-        }
-    }
 }
 
 impl BackendDevice for MetalDevice {
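Note: the removed MetalStorage::matrix helper memoized MPS Matrix objects in a RwLock<HashMap> keyed by (b, m, n, transpose, size, offset, type_id), so repeated matmuls with the same shape reused the same descriptor objects. A minimal sketch of that caching pattern, with FakeMatrix as a hypothetical stand-in for the mps Matrix type:

use std::collections::HashMap;
use std::sync::RwLock;

// Hypothetical stand-in for an MPS Matrix object; only the cache pattern from the
// removed helper is shown here.
#[derive(Clone, Debug)]
struct FakeMatrix { rows: u64, cols: u64, transpose: bool }

type Key = (u64, u64, u64, bool, u64, u64, u32);

#[derive(Default)]
struct MatrixCache {
    matrices: RwLock<HashMap<Key, FakeMatrix>>,
}

impl MatrixCache {
    fn matrix(&self, key: Key) -> FakeMatrix {
        let mut matrices = self.matrices.write().unwrap();
        matrices
            .entry(key)
            .or_insert_with(|| {
                // In the diff this is where Matrix::init_with_buffer_descriptor runs;
                // building the descriptor is the part worth memoizing.
                FakeMatrix { rows: key.1, cols: key.2, transpose: key.3 }
            })
            .clone()
    }
}

fn main() {
    let cache = MatrixCache::default();
    let k: Key = (1, 4, 4, false, 4, 0, 0);
    let a = cache.matrix(k);
    let b = cache.matrix(k); // second call hits the cache
    println!("{:?} {:?}", a, b);
}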
@@ -975,12 +944,14 @@ impl BackendDevice for MetalDevice {
         let command_queue = device.new_command_queue();
         let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
         let kernels = Arc::new(Kernels::new());
-        let buffers = Arc::new(RwLock::new(HashMap::new()));
+        let queue = Queue::create("co.huggingface.candle", QueueAttribute::Serial);
+        let buffers = Arc::new(RwLock::new(FxHashMap::default()));
         Ok(Self {
             device,
             command_queue,
             command_buffer,
             buffers,
+            queue,
             kernels,
         })
     }
@@ -1001,7 +972,11 @@ impl BackendDevice for MetalDevice {
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
         let buffer = self.new_buffer(shape.elem_count(), dtype);
-        Ok(MetalStorage::new(buffer, self.clone(), dtype))
+        Ok(MetalStorage {
+            buffer,
+            device: self.clone(),
+            dtype,
+        })
     }
 
     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -1020,11 +995,11 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
         };
-        Ok(Self::Storage::new(
-            buffer.into(),
-            self.clone(),
-            storage.dtype(),
-        ))
+        Ok(Self::Storage {
+            buffer,
+            device: self.clone(),
+            dtype: storage.dtype(),
+        })
     }
 
     fn rand_uniform(
@@ -57,7 +57,6 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
-metal = ["candle/metal", "candle-nn/metal"]
 
 [[example]]
 name = "llama_multiprocess"
@@ -10,7 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../../metal-rs", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
@@ -298,8 +298,11 @@ pub fn call_unary_contiguous(
     kernel_name: unary::contiguous::Kernel,
     length: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
+    // println!("Kernel {:?}", kernel_name.0);
+    // assert_eq!(input.length(), output.length());
+
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
strides: &[usize],
|
strides: &[usize],
|
||||||
offset: usize,
|
offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
output_offset: usize,
|
output_offset: usize,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
@ -358,7 +361,7 @@ pub fn call_binary_contiguous(
|
|||||||
length: usize,
|
length: usize,
|
||||||
left: &Buffer,
|
left: &Buffer,
|
||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
|
|
||||||
@ -386,7 +389,7 @@ pub fn call_binary_strided(
|
|||||||
right_input: &Buffer,
|
right_input: &Buffer,
|
||||||
right_strides: &[usize],
|
right_strides: &[usize],
|
||||||
right_offset: usize,
|
right_offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||||
|
|
||||||
@ -425,7 +428,7 @@ pub fn call_cast_contiguous(
|
|||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
@ -450,7 +453,7 @@ pub fn call_cast_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_strides: &[usize],
|
input_strides: &[usize],
|
||||||
input_offset: usize,
|
input_offset: usize,
|
||||||
output: &Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
// assert_eq!(input.length(), output.length());
|
// assert_eq!(input.length(), output.length());
|
||||||
@ -481,8 +484,7 @@ pub fn call_reduce_contiguous(
|
|||||||
length: usize,
|
length: usize,
|
||||||
out_length: usize,
|
out_length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_offset: usize,
|
output: &mut Buffer,
|
||||||
output: &Buffer,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
@@ -490,7 +492,7 @@ pub fn call_reduce_contiguous(
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
+    set_params!(encoder, (length, elements_to_sum, input, output));
 
     let thread_group_count = MTLSize {
         width: out_length as u64,
@@ -523,7 +525,7 @@ pub fn call_last_softmax(
     length: usize,
     elements_to_sum: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
     let encoder = command_buffer.new_compute_command_encoder();
@@ -564,7 +566,7 @@ pub fn call_affine(
     name: &'static str,
     size: usize,
     input: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -590,7 +592,7 @@ pub fn call_affine_strided(
     input: &Buffer,
     input_stride: &[usize],
     input_offset: usize,
-    output: &Buffer,
+    output: &mut Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
@@ -632,7 +634,7 @@ pub fn call_where_cond_strided(
     (left_stride, left_offset): (&[usize], usize),
     right: &Buffer,
     (right_stride, right_offset): (&[usize], usize),
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
 
@@ -675,7 +677,7 @@ pub fn call_index_select(
     dim: usize,
     input: &Buffer,
     ids: &Buffer,
-    output: &Buffer,
+    output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let left_size: usize = shape[..dim].iter().product();
     let right_size: usize = shape[dim + 1..].iter().product();
@@ -750,7 +752,7 @@ mod tests {
             name,
             v.len(),
             &input,
-            &output,
+            &mut output,
         )
         .unwrap();
         command_buffer.commit();
@@ -775,7 +777,7 @@ mod tests {
             x.len(),
             &left,
             &right,
-            &output,
+            &mut output,
         )
         .unwrap();
         command_buffer.commit();
@@ -805,7 +807,7 @@ mod tests {
             &input,
             strides,
             offset,
-            &output,
+            &mut output,
             0,
         )
         .unwrap();
@@ -943,7 +945,7 @@ mod tests {
             name,
             v.len(),
             &input,
-            &output,
+            &mut output,
         )
         .unwrap();
         command_buffer.commit();
@@ -984,7 +986,7 @@ mod tests {
             "affine_float",
             size,
             &input,
-            &output,
+            &mut output,
             mul as f32,
             add as f32,
         )
@@ -1021,7 +1023,7 @@ mod tests {
             &input,
             strides,
             0,
-            &output,
+            &mut output,
             mul as f32,
             add as f32,
         )
@@ -1119,7 +1121,7 @@ mod tests {
             dim,
             &embeddings_buffer,
             &ids_buffer,
-            &dst_buffer,
+            &mut dst_buffer,
         )
         .unwrap();
 
@@ -1226,8 +1228,7 @@ mod tests {
             v.len(),
             out_length,
             &input,
-            0,
-            &output,
+            &mut output,
         )
         .unwrap();
         command_buffer.commit();
@@ -1255,7 +1256,7 @@ mod tests {
             v.len(),
             last_dim,
             &input,
-            &output,
+            &mut output,
         )
         .unwrap();
         command_buffer.commit();
@@ -1355,7 +1356,7 @@ mod tests {
             (&left_stride, left_offset),
             &right,
             (&cond_stride, cond_offset),
-            &output,
+            &mut output,
         )
         .unwrap();
         command_buffer.commit();
@@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }
 
-constant int THREADGROUP_SIZE = 1024;
+constant int THREADGROUP_SIZE = 256;
 
 # define REDUCE(FN, NAME, TYPENAME) \
 kernel void NAME( \
@@ -19,7 +19,6 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -30,4 +29,3 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
-metal = ["candle/metal", "dep:candle-metal-kernels"]
@@ -201,37 +201,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
-
-    #[cfg(feature = "metal")]
-    fn metal_fwd(
-        &self,
-        storage: &candle::MetalStorage,
-        layout: &Layout,
-    ) -> Result<(candle::MetalStorage, Shape)> {
-        use candle::backend::{BackendStorage};
-        let device = storage.device();
-        let command_buffer = device.command_buffer();
-        let kernels = device.kernels();
-        let name = "softmax_float";
-        assert!(layout.is_contiguous());
-        assert!(layout.start_offset() == 0);
-        let last_dim = layout.dims()[layout.shape().rank() - 1];
-        let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype());
-        candle_metal_kernels::call_last_softmax(
-            device.metal_device(),
-            &command_buffer,
-            &kernels,
-            name,
-            elem_count,
-            last_dim,
-            storage.buffer(),
-            &mut output,
-        )
-        .unwrap();
-        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
-        Ok((newstorage, layout.shape().clone()))
-    }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {