diff --git a/Cargo.toml b/Cargo.toml index 0fea0423..c37bd75b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ exclude = [ "candle-flash-attn", "candle-kernels", + "candle-metal-kernels", "candle-onnx", ] resolver = "2" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 5d5e70a3..592f5bdf 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -13,6 +13,7 @@ readme = "README.md" accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } @@ -40,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] -metal = ["dep:metal"] +metal = ["dep:metal", "dep:candle-metal-kernels"] diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs new file mode 100644 index 00000000..04a2c3dd --- /dev/null +++ b/candle-core/src/metal_backend.rs @@ -0,0 +1,821 @@ +use crate::backend::{BackendDevice, BackendStorage}; +use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; +use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; +use crate::{CpuStorage, DType, Layout, Result, Shape}; +use candle_metal_kernels; +use candle_metal_kernels::{void_ptr, Kernels, Source}; +use core::mem; +use half::{bf16, f16}; +use metal; +use metal::mps::matrix::encode_gemm; +use metal::mps::Float32; +use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger}; +use std::sync::Arc; +use tracing::debug; + +/// Metal related errors +#[derive(thiserror::Error, Debug)] +pub enum MetalError { + #[error("{0}")] + Message(String), + #[error(transparent)] + KernelError(#[from] candle_metal_kernels::MetalKernelError), +} + +impl From for MetalError { + fn from(e: String) -> Self { + MetalError::Message(e) + } +} + +#[derive(Clone)] +pub struct MetalDevice { + device: metal::Device, + command_queue: metal::CommandQueue, + kernels: Arc, +} + +impl std::fmt::Debug for MetalDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MetalDevice({:?})", self.device.registry_id()) + } +} + +impl std::ops::Deref for MetalDevice { + type Target = metal::DeviceRef; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl MetalDevice { + // pub fn metal_device(&self) -> &metal::DeviceRef { + // self.device.as_ref() + // } + + pub fn id(&self) -> u64 { + self.registry_id() + } + + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + pub fn kernels(&self) -> &Kernels { + &self.kernels + } + + pub fn device(&self) -> &metal::Device { + &self.device + } + + pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { + let size = (element_count * dtype.size_in_bytes()) as u64; + // debug!("Allocate 1 - buffer size {size}"); + self.device + .new_buffer(size, MTLResourceOptions::StorageModeManaged) + } +} + +#[derive(Debug, Clone)] +pub struct MetalStorage { + buffer: metal::Buffer, + device: MetalDevice, + dtype: DType, +} + +impl BackendStorage for MetalStorage { + type Device = MetalDevice; + + fn try_clone(&self, _: &Layout) -> Result { + Ok(self.clone()) + } + + fn dtype(&self) -> DType { + self.dtype + } + + fn device(&self) -> &Self::Device { + &self.device + } + + fn to_cpu_storage(&self) -> Result { + match self.dtype { + DType::F32 => Ok(CpuStorage::F32( + self.buffer.read_to_vec(self.buffer.length() as usize / 4), + )), + dtype => todo!("Unsupported dtype {dtype:?}"), + } + } + + fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let device = self.device().clone(); + + let shape = layout.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let dtype = self.dtype; + + assert!(layout.is_contiguous()); + assert_eq!(dtype, DType::F32); + + let mut buffer = device.new_buffer(el, self.dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_affine( + &device.device, + &command_buffer, + &device.kernels, + el, + &self.buffer, + &mut buffer, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + return Ok(Self { + buffer, + device: device.clone(), + dtype, + }); + } + + fn powf(&self, _: &Layout, _: f64) -> Result { + todo!() + } + + fn elu(&self, _: &Layout, _: f64) -> Result { + todo!() + } + + fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { + // debug!("TODO reduce_op {op:?} {sum_dims:?}"); + assert!(sum_dims.len() == 1); + assert!(sum_dims[0] == layout.shape().rank() - 1); + assert!(layout.is_contiguous()); + let device = self.device.clone(); + let src_stride = layout.stride(); + let src_dims = layout.shape().dims(); + let src_el: usize = src_dims.iter().product(); + // Source dims and strides with the sum dims at the end. + let mut dims = vec![]; + let mut stride = vec![]; + let mut dst_el: usize = 1; + for (dim_idx, &d) in src_dims.iter().enumerate() { + if !sum_dims.contains(&dim_idx) { + dst_el *= d; + dims.push(d); + stride.push(src_stride[dim_idx]); + } + } + for &dim_idx in sum_dims.iter() { + dims.push(src_dims[dim_idx]); + stride.push(src_stride[dim_idx]); + } + + let el_to_sum_per_block = src_el / dst_el; + // The reduction loop requires the shared array to be properly initialized and for + // this we want the number of threads to be a power of two. + let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two(); + let (name, check_empty, return_index) = match (op, self.dtype) { + (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), + (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), + (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false), + (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true), + (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true), + _ => todo!("Reduce op for non float"), + }; + if check_empty && layout.shape().elem_count() == 0 { + Err(crate::Error::EmptyTensor { op: "reduce" }.bt())? + } + let dtype = if return_index { DType::U32 } else { self.dtype }; + let mut buffer = device.new_buffer(dst_el, dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_reduce_contiguous( + &device.device, + &command_buffer, + &device.kernels, + name, + src_el, + dst_el, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + + Ok(Self { + buffer, + device, + dtype, + }) + } + + fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result { + todo!() + } + + fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { + let device = self.device(); + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if layout.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32", + (left, right) => todo!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + todo!( + "TODO Implement the kernel calling cast {:?}-{:?}", + self.dtype, + dtype + ); + } + + command_buffer.commit(); + // command_buffer.wait_until_scheduled(); + debug!( + "cast {:?} - {:?} - {:?}", + dtype, + self.buffer.length(), + buffer.length() + ); + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn unary_impl(&self, layout: &Layout) -> Result { + let device = self.device(); + let dtype = self.dtype; + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + // TODO remove + // return Ok(Self { + // buffer, + // device: device.clone(), + // dtype, + // }); + let command_buffer = device.command_queue.new_command_buffer(); + if layout.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + (name, dtype) => todo!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_unary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + todo!("TODO Implement the kernel calling {}", B::KERNEL); + } + + let start = std::time::Instant::now(); + command_buffer.commit(); + // command_buffer.wait_until_scheduled(); + debug!( + "Unary {:?} - {:?} - {:?} - {:?}", + B::KERNEL, + start.elapsed(), + self.buffer.length(), + buffer.length() + ); + + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn binary_impl( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device(); + let dtype = self.dtype; + let shape = lhs_l.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + use candle_metal_kernels::binary::contiguous; + + let kernel_name = match (B::KERNEL, dtype) { + ("add", DType::F32) => contiguous::add::FLOAT, + ("badd", DType::F32) => contiguous::add::FLOAT, + ("sub", DType::F32) => contiguous::sub::FLOAT, + ("bsub", DType::F32) => contiguous::sub::FLOAT, + ("mul", DType::F32) => contiguous::mul::FLOAT, + ("bmul", DType::F32) => contiguous::mul::FLOAT, + ("div", DType::F32) => contiguous::div::FLOAT, + ("bdiv", DType::F32) => contiguous::div::FLOAT, + (name, dtype) => todo!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + use candle_metal_kernels::binary::strided; + + let kernel_name = match (B::KERNEL, dtype) { + ("badd", DType::F32) => strided::add::FLOAT, + ("bsub", DType::F32) => strided::sub::FLOAT, + ("bmul", DType::F32) => strided::mul::FLOAT, + ("bdiv", DType::F32) => strided::div::FLOAT, + (name, dtype) => todo!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + &lhs_l.stride(), + lhs_l.start_offset(), + &rhs.buffer, + &rhs_l.stride(), + rhs_l.start_offset(), + &mut buffer, + ) + .map_err(MetalError::from)?; + } + + let start = std::time::Instant::now(); + command_buffer.commit(); + // command_buffer.wait_until_scheduled(); + debug!( + "Binary {:?} - {:?} - {:?} - {:?}", + B::KERNEL, + start.elapsed(), + self.buffer.length(), + buffer.length() + ); + + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + } + + fn where_cond( + &self, + layout: &Layout, + t: &Self, + t_l: &Layout, + f: &Self, + f_l: &Layout, + ) -> Result { + let device = self.device.clone(); + let shape = t_l.shape(); + let dims = shape.dims(); + let el = shape.elem_count(); + let dtype = t.dtype; + let mut buffer = self.device.new_buffer(el, dtype); + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_where_cond_strided( + &device.device, + &command_buffer, + &device.kernels, + "where_u8_f32", + &dims, + &self.buffer, + (layout.stride(), layout.start_offset()), + &t.buffer, + (&t_l.stride(), t_l.start_offset()), + &f.buffer, + (&f_l.stride(), f_l.start_offset()), + &mut buffer, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + Ok(Self { + buffer, + device, + dtype, + }) + } + + fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConv1D, + ) -> Result { + todo!() + } + + fn conv2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConv2D, + ) -> Result { + todo!() + } + + fn conv_transpose2d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConvTranspose2D, + ) -> Result { + todo!() + } + + fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + todo!() + } + + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + todo!() + } + + fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result { + todo!() + } + + fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result { + todo!() + } + + fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result { + todo!() + } + + fn scatter_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + todo!() + } + + fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { + debug!( + "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}", + self.buffer.length(), + ids.buffer.length(), + ); + let src = self; + let ids_shape = ids_l.shape(); + let ids_dims = ids_shape.dims(); + // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; + // let src = match src_l.contiguous_offsets() { + // Some((o1, o2)) => src.slice(o1..o2), + // None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?, + // }; + let left_size: usize = src_l.dims()[..dim].iter().product(); + let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + let src_dim_size = src_l.dims()[dim]; + let ids_dim_size = ids_shape.elem_count(); + let dst_el = ids_shape.elem_count() * left_size * right_size; + let dtype = self.dtype; + let device = self.device(); + let buffer = device.new_buffer(dst_el, dtype); + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) + // todo!() + } + + fn index_add( + &self, + _: &Layout, + _: &Self, + _: &Layout, + _: &Self, + _: &Layout, + _: usize, + ) -> Result { + todo!() + } + + fn matmul( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let transpose_left = false; + let transpose_right = !rhs_l.is_contiguous(); + let alpha = 1.0; + let beta = 0.0; + self.matmul_generic( + rhs, + (b, m, n, k), + lhs_l, + rhs_l, + transpose_left, + transpose_right, + alpha, + beta, + ) + } + + fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { + let src_shape = src_l.shape(); + let dims = src_shape.dims(); + let el_count = src_shape.elem_count(); + if el_count == 0 { + return Ok(()); + } + if src_l.is_contiguous() { + let command_buffer = self.device.command_queue.new_command_buffer(); + let blip = command_buffer.new_blit_command_encoder(); + blip.copy_from_buffer( + &self.buffer, + src_l.start_offset() as u64, + &dst.buffer, + dst_offset as u64, + self.buffer.length(), + ); + } else { + let command_buffer = self.device.command_queue.new_command_buffer(); + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + dtype => todo!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + &src_l.stride(), + src_l.start_offset(), + &mut dst.buffer, + dst_offset, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + } + Ok(()) + } +} + +impl MetalStorage { + pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self { + Self { + buffer, + device, + dtype, + } + } + pub(crate) fn matmul_generic( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_l: &Layout, + rhs_l: &Layout, + transpose_left: bool, + transpose_right: bool, + alpha: f64, + beta: f64, + ) -> Result { + let elem_count = b * m * n; + match (self.dtype, rhs.dtype) { + (DType::F32, DType::F32) => { + let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); + if b != 1 { + debug!("TODO implement batched matmul for B={b}"); + // bail!("Didn't implemented strided matmul yet"); + return Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }); + } + if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { + debug!( + "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", + lhs_l.is_contiguous(), + rhs_l.is_contiguous(), + rhs_l + ); + return Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }); + } + + debug!("TODO GEMM"); + let command_buffer = self.device.command_queue.new_command_buffer(); + encode_gemm::( + &self.device, + &command_buffer, + transpose_left, + transpose_right, + &self.buffer, + &rhs.buffer, + &mut out_buffer, + m as NSUInteger, + n as NSUInteger, + k as NSUInteger, + alpha as f32, + beta as f32, + Some(b as NSUInteger), + ) + .map_err(MetalError::from)?; + + command_buffer.commit(); + // command_buffer.wait_until_scheduled(); + + Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }) + } + _ => todo!("Unimplemented matmul for this pair"), + } + } + + pub fn buffer(&self) -> &Buffer { + &self.buffer + } +} + +impl BackendDevice for MetalDevice { + type Storage = MetalStorage; + + fn new(ordinal: usize) -> Result { + let device = metal::Device::all().swap_remove(ordinal); + + // let capture = metal::CaptureManager::shared(); + // let descriptor = metal::CaptureDescriptor::new(); + // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + // descriptor.set_capture_device(&device); + // let mut dir = std::env::current_dir()?; + // dir.push("out.gputrace"); + // descriptor.set_output_url(dir); + + // capture + // .start_capture(&descriptor) + // .map_err(MetalError::from)?; + let command_queue = device.new_command_queue(); + // let command_buffer = _command_queue.new_owned_command_buffer(); + let kernels = Arc::new(Kernels::new()); + Ok(Self { + device, + command_queue, + // command_buffer, + kernels, + }) + } + + fn set_seed(&self, _seed: u64) -> Result<()> { + todo!("set_seed") + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Metal + } + + fn same_device(&self, rhs: &Self) -> bool { + self.device.registry_id() == rhs.device.registry_id() + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + // TODO Is there a faster way ? + let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?; + self.storage_from_cpu_storage(&cpu_storage) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + // TODO Is there a faster way ? + let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?; + self.storage_from_cpu_storage(&cpu_storage) + } + + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { + let option = metal::MTLResourceOptions::StorageModeManaged; + let buffer = match storage { + CpuStorage::U8(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + CpuStorage::U32(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + CpuStorage::I64(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + CpuStorage::BF16(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + CpuStorage::F16(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + CpuStorage::F32(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + CpuStorage::F64(storage) => self.device.new_buffer_with_data( + storage.as_ptr() as *const core::ffi::c_void, + (storage.len() * mem::size_of::()) as u64, + option, + ), + }; + // debug!("Allocate 2 - buffer size {}", buffer.length()); + Ok(Self::Storage { + buffer, + device: self.clone(), + dtype: storage.dtype(), + }) + } + + fn rand_uniform( + &self, + shape: &Shape, + dtype: DType, + mean: f64, + stddev: f64, + ) -> Result { + // TODO is there a better way ? + let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?; + self.storage_from_cpu_storage(&cpu_storage) + } + + fn rand_normal( + &self, + shape: &Shape, + dtype: DType, + mean: f64, + stddev: f64, + ) -> Result { + // TODO is there a better way ? + let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?; + self.storage_from_cpu_storage(&cpu_storage) + } +} diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml new file mode 100644 index 00000000..ff5ede1a --- /dev/null +++ b/candle-metal-kernels/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "candle-metal-kernels" +version = "0.3.0" +edition = "2021" + +description = "CUDA kernels for Candle" +repository = "https://github.com/huggingface/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" + +[dependencies] +metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +once_cell = "1.18.0" +thiserror = "1" +tracing = "0.1.37" + +[dev-dependencies] +half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } diff --git a/candle-metal-kernels/README.md b/candle-metal-kernels/README.md new file mode 100644 index 00000000..ec923e9a --- /dev/null +++ b/candle-metal-kernels/README.md @@ -0,0 +1,3 @@ +# candle-metal-kernels + +This crate contains Metal kernels used from candle. \ No newline at end of file diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal new file mode 100644 index 00000000..c388c04e --- /dev/null +++ b/candle-metal-kernels/src/affine.metal @@ -0,0 +1,46 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +using namespace metal; + +#define AFFINE(FN_NAME, TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + constant float &mul, \ + constant float &add, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const TYPENAME m = TYPENAME(mul); \ + const TYPENAME a = TYPENAME(add); \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = input[i] * m + a; \ + } \ +} \ + +AFFINE(affine_float, float) +AFFINE(affine_half, half) + + +#if __METAL_VERSION__ >= 310 +AFFINE(affine_bfloat, bfloat); +#endif diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal new file mode 100644 index 00000000..cfd34416 --- /dev/null +++ b/candle-metal-kernels/src/binary.metal @@ -0,0 +1,78 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +using namespace metal; + +#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *left, \ + device const TYPENAME *right, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + TYPENAME x = left[i]; \ + TYPENAME y = right[i]; \ + output[i] = OUT_TYPENAME(FN); \ + } \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *left_strides, \ + constant size_t *right_strides, \ + device const TYPENAME *left, \ + device const TYPENAME *right, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \ + TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \ + output[i] = OUT_TYPENAME(FN); \ + } \ +} + +#define BINARY_OP(FN, NAME) \ +BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \ +BINARY(FN, half, half, NAME##_half, NAME##_half_strided); + +#define BFLOAT_BINARY_OP(FN, NAME) \ +BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided); + + +BINARY_OP(x + y, add) +BINARY_OP(x - y, sub) +BINARY_OP(x * y, mul) +BINARY_OP(x / y, div) + +#if __METAL_VERSION__ >= 310 +BFLOAT_BINARY_OP(x + y, add) +BFLOAT_BINARY_OP(x - y, sub) +BFLOAT_BINARY_OP(x * y, mul) +BFLOAT_BINARY_OP(x / y, div) +#endif diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal new file mode 100644 index 00000000..52e63662 --- /dev/null +++ b/candle-metal-kernels/src/cast.metal @@ -0,0 +1,58 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + + +using namespace metal; + +#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = RIGHT_TYPENAME(input[i]); \ + } \ +} \ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const LEFT_TYPENAME *input, \ + device RIGHT_TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + } \ +} + + +CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) + +#if __METAL_VERSION__ >= 310 +#endif diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal new file mode 100644 index 00000000..528c109d --- /dev/null +++ b/candle-metal-kernels/src/indexing.metal @@ -0,0 +1,75 @@ +#include +using namespace metal; + +template +void index_add( + device I *ids [[buffer(0)]], + device T *inp [[buffer(1)]], + device T *out [[buffer(2)]], + + constant uint &ids_dim_size, + constant uint &left_size, + constant uint &dst_dim_size, + constant uint &right_size, + + uint threadgroup_size [[threads_per_threadgroup]], + uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], + uint thread_index [[thread_index_in_threadgroup]] +) { + + const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size); + if (gid >= left_size * right_size) { + return; + } + + const uint i = gid; + const uint pre = i / right_size; + const uint post = i % right_size; + + for (uint j = 0; j < ids_dim_size; j++) { + const uint idx = ids[j]; + const uint src_i = (pre * ids_dim_size + j) * right_size + post; + const uint dst_i = (pre * dst_dim_size + idx) * right_size + post; + out[dst_i] += inp[src_i]; + } +} + +#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + device INDEX_TYPENAME *ids [[buffer(0)]], \ + device TYPENAME *inp [[buffer(1)]], \ + device TYPENAME *out [[buffer(2)]], \ + constant uint &ids_dim_size, \ + constant uint &left_size, \ + constant uint &dst_dim_size, \ + constant uint &right_size, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \ + + + +#if __METAL_VERSION__ >= 310 +IA_OP(bfloat, int64_t, ia_i64_bf16) +IA_OP(bfloat, uint32_t, ia_u32_bf16) +IA_OP(bfloat, uint8_t, ia_u8_bf16) +#endif + +IA_OP(half, uint32_t, ia_u32_f16) +IA_OP(half, uint8_t, ia_u8_f16) + +IA_OP(float, int64_t, ia_i64_f32) +IA_OP(uint8_t, int64_t, ia_i64_u8) +IA_OP(int64_t, int64_t, ia_i64_i64) +IA_OP(uint32_t, int64_t, ia_i64_u32) + +IA_OP(float, uint32_t, ia_u32_f32) +IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int64_t, uint32_t, ia_u32_i64) +IA_OP(uint32_t, uint32_t, ia_u32_u32) + +IA_OP(float, uint8_t, ia_u8_f32) +IA_OP(uint8_t, uint8_t, ia_u8_u8) +IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int64_t, uint8_t, ia_u8_i64) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs new file mode 100644 index 00000000..d2c63115 --- /dev/null +++ b/candle-metal-kernels/src/lib.rs @@ -0,0 +1,1246 @@ +#![allow(clippy::too_many_arguments)] +use metal::{ + Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, + MTLSize, +}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::sync::RwLock; + +const AFFINE: &str = include_str!("affine.metal"); +const INDEXING: &str = include_str!("indexing.metal"); +const UNARY: &str = include_str!("unary.metal"); +const BINARY: &str = include_str!("binary.metal"); +const TERNARY: &str = include_str!("ternary.metal"); +const CAST: &str = include_str!("cast.metal"); +const REDUCE: &str = include_str!("reduce.metal"); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Source { + Affine, + Indexing, + Unary, + Binary, + Ternary, + Cast, + Reduce, +} + +macro_rules! ops{ + ($($name:ident),+) => { + + pub mod contiguous { + pub struct Kernel(pub(crate) &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat")); + } + )+ + } + + pub mod strided { + pub struct Kernel(pub(crate) &'static str); + $( + pub mod $name { + use super::Kernel; + pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided")); + pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided")); + pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided")); + } + )+ + } + }; +} + +pub mod unary { + ops!(cos, sin, exp, sqr, sqrt, neg, copy); +} +pub mod binary { + ops!(add, sub, mul, div); +} + +// static LIBRARY_SOURCES: Lazy> = Lazy::new(|| { +// let mut l = HashMap::new(); +// l.insert("affine", AFFINE); +// l.insert("indexing", INDEXING); +// l.insert("unary", UNARY); +// l +// }); +// +#[derive(thiserror::Error, Debug)] +pub enum MetalKernelError { + #[error("Could not lock kernel map: {0}")] + LockError(String), + #[error("Error while loading library: {0}")] + LoadLibraryError(String), + #[error("Error while loading function: {0}")] + LoadFunctionError(String), +} + +impl From> for MetalKernelError { + fn from(e: std::sync::PoisonError) -> Self { + Self::LockError(e.to_string()) + } +} + +type KernelMap = HashMap<&'static str, T>; +type Libraries = HashMap; +type Functions = KernelMap; + +#[derive(Debug, Default)] +pub struct Kernels { + libraries: RwLock, + funcs: RwLock, +} + +impl Kernels { + pub fn new() -> Self { + let libraries = RwLock::new(Libraries::new()); + let funcs = RwLock::new(Functions::new()); + Self { libraries, funcs } + } + + // pub fn init(device: &Device) -> Result { + // let kernels = Self::new(); + // kernels.load_libraries(device)?; + // Ok(kernels) + // } + + // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { + // for name in LIBRARY_SOURCES.keys() { + // self.load_library(device, name)?; + // } + // Ok(()) + // } + + fn get_library_source(&self, source: Source) -> &'static str { + // LIBRARY_SOURCES.get(name).cloned() + match source { + Source::Affine => AFFINE, + Source::Unary => UNARY, + Source::Binary => BINARY, + Source::Ternary => TERNARY, + Source::Indexing => INDEXING, + Source::Cast => CAST, + Source::Reduce => REDUCE, + } + } + + pub fn load_library( + &self, + device: &Device, + source: Source, + ) -> Result { + let mut libraries = self.libraries.write()?; + if let Some(lib) = libraries.get(&source) { + Ok(lib.clone()) + } else { + let source_content = self.get_library_source(source); + let lib = device + .new_library_with_source(source_content, &CompileOptions::new()) + .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?; + libraries.insert(source, lib.clone()); + Ok(lib) + } + } + + pub fn load_function( + &self, + device: &Device, + source: Source, + name: &'static str, + ) -> Result { + let mut funcs = self.funcs.write()?; + if let Some(func) = funcs.get(name) { + Ok(func.clone()) + } else { + let func = self + .load_library(device, source)? + .get_function(name, None) + .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?; + funcs.insert(name, func.clone()); + Ok(func) + } + } +} + +pub fn call_unary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: unary::contiguous::Kernel, + length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + // println!("Kernel {:?}", kernel_name.0); + // assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&length)); + encoder.set_buffer(1, Some(input), 0); + encoder.set_buffer(2, Some(output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} +pub fn call_unary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: unary::strided::Kernel, + shape: &[usize], + input: &Buffer, + strides: &[usize], + offset: usize, + output: &mut Buffer, + output_offset: usize, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Unary, name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let num_dims: usize = shape.len(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); + encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); + encoder.set_bytes( + 2, + std::mem::size_of_val(shape) as u64, + shape.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 3, + std::mem::size_of_val(strides) as u64, + strides.as_ptr() as *const c_void, + ); + + encoder.set_buffer(4, Some(input), offset as u64); + encoder.set_buffer(5, Some(output), output_offset as u64); + + let width = output.length(); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_binary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: binary::contiguous::Kernel, + length: usize, + left: &Buffer, + right: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + // println!("Kernel {:?}", kernel_name.0); + // assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&length)); + encoder.set_buffer(1, Some(left), 0); + encoder.set_buffer(2, Some(right), 0); + encoder.set_buffer(3, Some(output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_binary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: binary::strided::Kernel, + shape: &[usize], + left_input: &Buffer, + left_strides: &[usize], + left_offset: usize, + right_input: &Buffer, + right_strides: &[usize], + right_offset: usize, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Binary, name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let num_dims: usize = shape.len(); + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); + encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); + encoder.set_bytes( + 2, + std::mem::size_of_val(shape) as u64, + shape.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 3, + std::mem::size_of_val(left_strides) as u64, + left_strides.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 4, + std::mem::size_of_val(right_strides) as u64, + right_strides.as_ptr() as *const c_void, + ); + + encoder.set_buffer(5, Some(left_input), left_offset as u64); + encoder.set_buffer(6, Some(right_input), right_offset as u64); + encoder.set_buffer(7, Some(output), 0); + + let width = output.length(); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_cast_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + // println!("Kernel {:?}", kernel_name.0); + // assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, Source::Cast, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&length)); + encoder.set_buffer(1, Some(input), 0); + encoder.set_buffer(2, Some(output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_reduce_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + out_length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Reduce, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let elements_to_sum = length / out_length; + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&length)); + encoder.set_bytes( + 1, + core::mem::size_of::() as u64, + void_ptr(&elements_to_sum), + ); + encoder.set_buffer(2, Some(input), 0); + encoder.set_buffer(3, Some(output), 0); + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + (elements_to_sum as u64 + 2 - 1) / 2, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_last_softmax( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + elements_to_sum: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Reduce, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&length)); + encoder.set_bytes( + 1, + core::mem::size_of::() as u64, + void_ptr(&elements_to_sum), + ); + encoder.set_buffer(2, Some(input), 0); + encoder.set_buffer(3, Some(output), 0); + + let out_length = length / elements_to_sum; + + let thread_group_count = MTLSize { + width: out_length as u64, + height: 1, + depth: 1, + }; + + let width = std::cmp::min( + pipeline.max_total_threads_per_threadgroup(), + // (elements_to_sum as u64 + 2 - 1) / 2, + elements_to_sum as u64, + ) + .next_power_of_two(); + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn void_ptr(v: &T) -> *const c_void { + (v as *const T).cast() +} + +pub fn call_affine( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + size: usize, + input: &Buffer, + output: &mut Buffer, + mul: f32, + add: f32, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Affine, "affine_float")?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); + encoder.set_bytes(1, core::mem::size_of::() as u64, void_ptr(&mul)); + encoder.set_bytes(2, core::mem::size_of::() as u64, void_ptr(&add)); + encoder.set_buffer(3, Some(input), 0); + encoder.set_buffer(4, Some(output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_where_cond_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + cond: &Buffer, + (cond_stride, cond_offset): (&[usize], usize), + left: &Buffer, + (left_stride, left_offset): (&[usize], usize), + right: &Buffer, + (right_stride, right_offset): (&[usize], usize), + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Ternary, name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let size: usize = shape.iter().product(); + encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); + encoder.set_bytes( + 1, + core::mem::size_of::() as u64, + void_ptr(&shape.len()), + ); + encoder.set_bytes( + 2, + std::mem::size_of_val(shape) as u64, + shape.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 3, + std::mem::size_of_val(cond_stride) as u64, + cond_stride.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 4, + std::mem::size_of_val(left_stride) as u64, + left_stride.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 5, + std::mem::size_of_val(right_stride) as u64, + right_stride.as_ptr() as *const c_void, + ); + encoder.set_buffer(6, Some(cond), cond_offset as u64); + encoder.set_buffer(7, Some(left), left_offset as u64); + encoder.set_buffer(8, Some(right), right_offset as u64); + encoder.set_buffer(9, Some(output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use half::f16; + use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; + use std::mem; + + fn device() -> Device { + Device::system_default().unwrap() + } + + fn approx(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t * b) / b).collect() + } + + fn approx_f16(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() + } + + fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + call_unary_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) + } + + fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let left = device.new_buffer_with_data( + x.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(x) as u64, + options, + ); + let right = device.new_buffer_with_data( + y.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(y) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + call_binary_contiguous( + &device, + command_buffer, + &kernels, + name, + x.len(), + &left, + &right, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(x.len()) + } + + fn run_strided( + v: &[T], + kernel: unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + offset: usize, + ) -> Vec { + let device = device(); + let options = MTLResourceOptions::StorageModeManaged; + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let kernels = Kernels::new(); + call_unary_strided( + &device, + command_buffer, + &kernels, + kernel, + shape, + &input, + strides, + offset, + &mut output, + 0, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) + } + + #[test] + fn cos_f32() { + let v = vec![1.0f32, 2.0, 3.0]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); + assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + } + + #[test] + fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + // Shape = [6], strides = [1]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Transposed + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![1, 3]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Very large + let v = vec![1.0f32; 10_000]; + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + } + + #[test] + fn binary_add_f32() { + let left = vec![1.0f32, 2.0, 3.0]; + let right = vec![2.0f32, 3.1, 4.2]; + let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); + let expected: Vec<_> = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| x + y) + .collect(); + assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); + assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); + } + + fn cast(v: &[T], name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, options); + call_cast_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) + } + + #[test] + fn cast_u32_f32() { + let v = vec![1u32, 2, 3]; + let results = cast(&v, "cast_u32_f32"); + let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); + assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); + assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + } + + fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + + let size = v.len(); + + call_affine( + &device, + command_buffer, + &kernels, + size, + &input, + &mut output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) + } + + #[test] + fn affine() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + + let input = [1.0f32; 40_000]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6; 40_000]); + } + + #[test] + fn index_add() { + let device = Device::system_default().expect("no device found"); + + let options = CompileOptions::new(); + let library = device.new_library_with_source(INDEXING, &options).unwrap(); + + let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let right = [1.0f32; 15]; + let index = [0u32, 4, 2]; + let ids_dim_size = index.len() as u32; + let dst_dim_size: u32 = 15; + let left_size: u32 = 3; + let right_size: u32 = 3; + + let function = library.get_function("ia_u32_f32", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + let options = MTLResourceOptions::StorageModeManaged; + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + let ids_size = (index.len() * mem::size_of::()) as NSUInteger; + let input_size = (left.len() * mem::size_of::()) as NSUInteger; + let output_size = (right.len() * mem::size_of::()) as NSUInteger; + + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); + + let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options); + let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options); + let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options); + + encoder.set_buffer(0, Some(&index_buffer), 0); + encoder.set_buffer(1, Some(&inputs_buffer), 0); + encoder.set_buffer(2, Some(&outputs_buffer), 0); + + encoder.set_bytes(3, 4, void_ptr(&ids_dim_size)); + encoder.set_bytes(4, 4, void_ptr(&left_size)); + encoder.set_bytes(5, 4, void_ptr(&dst_dim_size)); + encoder.set_bytes(6, 4, void_ptr(&right_size)); + + let grid_size = MTLSize { + width: right.len() as NSUInteger, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: pipeline.max_total_threads_per_threadgroup(), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = vec![ + 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, + ]; + let result = outputs_buffer.read_to_vec::(right.len()); + assert_eq!(result, expected); + } + + #[test] + fn cos_f16() { + let v: Vec = [1.0f32, 2.0, 3.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let results = run(&v, unary::contiguous::cos::HALF); + let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); + assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]); + assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + } + + fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = + device.new_buffer((out_length * core::mem::size_of::()) as u64, options); + call_reduce_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + out_length, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(out_length) + } + + fn run_softmax( + v: &[T], + last_dim: usize, + name: &'static str, + ) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + call_last_softmax( + &device, + command_buffer, + &kernels, + name, + v.len(), + last_dim, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) + } + + #[test] + fn reduce_sum() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 1; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![21.0]); + } + + #[test] + fn reduce_sum2() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 2; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![6.0, 15.0]); + } + + #[test] + fn softmax() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 3; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] + ); + } + + fn run_where_cond( + shape: &[usize], + cond: &[I], + (cond_stride, cond_offset): (Vec, usize), + left_true: &[T], + (left_stride, left_offset): (Vec, usize), + right_false: &[T], + (_right_stride, _right_offset): (Vec, usize), + name: &'static str, + ) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let length = cond.len(); + let cond = device.new_buffer_with_data( + cond.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(cond) as u64, + options, + ); + let left = device.new_buffer_with_data( + left_true.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + let right = device.new_buffer_with_data( + right_false.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + + let mut output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + call_where_cond_strided( + &device, + command_buffer, + &kernels, + name, + shape, + &cond, + (&cond_stride, cond_offset), + &left, + (&left_stride, left_offset), + &right, + (&cond_stride, cond_offset), + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(length) + } + + #[test] + fn where_cond() { + let shape = vec![6]; + let cond = vec![0u8, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u8_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); + } +} diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal new file mode 100644 index 00000000..4dfc46c2 --- /dev/null +++ b/candle-metal-kernels/src/reduce.metal @@ -0,0 +1,124 @@ +#include +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +constant int THREADGROUP_SIZE = 256; + +kernel void fast_sum_float( + constant size_t &src_numel, + constant size_t &el_to_sum_per_block, + device const float *src, + device float *dst, + uint id [[ thread_position_in_grid ]], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint blockDim [[ threads_per_threadgroup ]] +) { + + threadgroup float shared_memory[THREADGROUP_SIZE]; + + shared_memory[tid] = 0; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + shared_memory[tid] += src[idx]; + idx += blockDim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + dst[dst_id] = shared_memory[0]; +} + +kernel void softmax_float( + constant size_t &src_numel, + constant size_t &el_to_sum_per_block, + device const float *src, + device float *dst, + uint id [[ thread_position_in_grid ]], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint blockDim [[ threads_per_threadgroup ]] +) { + + threadgroup float shared_memory[THREADGROUP_SIZE]; + + shared_memory[tid] = -INFINITY; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + shared_memory[tid] = max(shared_memory[tid], src[idx]); + idx += blockDim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); + } + threadgroup_barrier(mem_flags::mem_none); + } + + float max = shared_memory[0]; + + shared_memory[tid] = 0; + + // Restart + idx = start_idx + tid; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + const float val = exp(src[idx] - max); + dst[idx] = val; + shared_memory[tid] += val; + idx += blockDim; + } + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + const float inv_acc = 1/shared_memory[0]; + idx = start_idx + tid; + while (idx < stop_idx) { + dst[idx] *= inv_acc; + idx += blockDim; + } +} diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal new file mode 100644 index 00000000..0945b355 --- /dev/null +++ b/candle-metal-kernels/src/ternary.metal @@ -0,0 +1,57 @@ +#include +# +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + + +#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID_TYPENAME *ids, \ + device const TYPENAME *t, \ + device const TYPENAME *f, \ + device TYPENAME *out ,\ + uint i [[ thread_position_in_grid ]] \ +) { \ + uint strided_i = get_strided_index(i, num_dims, dims, strides); \ + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ + out[i] = ids[strided_i] ? t[strided_i_t] : f[strided_i_f]; \ +} \ + +// WHERE_OP(float, int64_t, where_i64_f32) +// WHERE_OP(double, int64_t, where_i64_f64) +// WHERE_OP(uint8_t, int64_t, where_i64_u8) +// WHERE_OP(uint32_t, int64_t, where_i64_u32) +// WHERE_OP(int64_t, int64_t, where_i64_i64) +// +// WHERE_OP(float, uint32_t, where_u32_f32) +// WHERE_OP(double, uint32_t, where_u32_f64) +// WHERE_OP(uint8_t, uint32_t, where_u32_u8) +// WHERE_OP(uint32_t, uint32_t, where_u32_u32) +// WHERE_OP(int64_t, uint32_t, where_u32_i64) + +WHERE_OP(float, uint8_t, where_u8_f32) +// WHERE_OP(double, uint8_t, where_u8_f64) +// WHERE_OP(uint8_t, uint8_t, where_u8_u8) +// WHERE_OP(uint32_t, uint8_t, where_u8_u32) +// WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal new file mode 100644 index 00000000..77de214e --- /dev/null +++ b/candle-metal-kernels/src/unary.metal @@ -0,0 +1,82 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T neg(T in){ return -in; } +template METAL_FUNC T id(T in){ return in; } + + +using namespace metal; + +#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = TYPENAME(FN(input[i])); \ + } \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ + } \ +} + +#define UNARY_OP(NAME) \ +UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ +UNARY(NAME, half, NAME##_half, NAME##_half_strided); + +#define BFLOAT_UNARY_OP(NAME) \ +UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); + + +UNARY_OP(cos) +UNARY_OP(sin) +UNARY_OP(sqr) +UNARY_OP(sqrt) +UNARY_OP(neg) +UNARY_OP(exp) +UNARY(id, float, copy_float, copy_float_strided) +UNARY(id, half, copy_half, copy_half_strided) + +#if __METAL_VERSION__ >= 310 +BFLOAT_UNARY_OP(cos) +BFLOAT_UNARY_OP(sin) +BFLOAT_UNARY_OP(sqr) +BFLOAT_UNARY_OP(sqrt) +BFLOAT_UNARY_OP(neg) +BFLOAT_UNARY_OP(exp) +#endif