From a4c4a564299d89e8b2047ddd34d5daba0c1349e1 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 9 Nov 2023 19:30:59 +0100
Subject: [PATCH 01/15] Metal part 1 - Scaffolding for metal.

---
 candle-core/src/device.rs | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index de57c03a..2665f243 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -1,6 +1,6 @@
 use crate::backend::BackendDevice;
 use crate::cpu_backend::CpuDevice;
-use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
+use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
 
 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
 /// can live on the same location (typically for cuda devices).
@@ -105,14 +105,14 @@ impl<S: NdArray> NdArray for Vec<S> {
     fn shape(&self) -> Result<Shape> {
         if self.is_empty() {
-            crate::bail!("empty array")
+            bail!("empty array")
         }
         let shape0 = self[0].shape()?;
         let n = self.len();
         for v in self.iter() {
             let shape = v.shape()?;
             if shape != shape0 {
-                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
+                bail!("two elements have different shapes {shape:?} {shape0:?}")
             }
         }
         Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@@ -203,7 +203,11 @@
             Device::Metal(_device) => {
                 // let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 // Ok(Storage::Metal(storage))
+<<<<<<< HEAD
                 crate::bail!("Metal rand_uniform not implemented")
+=======
+                bail!("Metal rand_uniform not implemented")
+>>>>>>> 8cf39d27 (Metal part 1 - Scaffolding for metal.)
             }
         }
     }

From 976ad9f9c29bd45d4e7a74298a31b0d85fec623a Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 9 Nov 2023 19:41:08 +0100
Subject: [PATCH 02/15] Remove tracing.

---
 candle-core/src/device.rs | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 2665f243..de57c03a 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -1,6 +1,6 @@
 use crate::backend::BackendDevice;
 use crate::cpu_backend::CpuDevice;
-use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
+use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
 
 /// A `DeviceLocation` represents a physical device whereas multiple `Device`
 /// can live on the same location (typically for cuda devices).
@@ -105,14 +105,14 @@ impl<S: NdArray> NdArray for Vec<S> {
     fn shape(&self) -> Result<Shape> {
         if self.is_empty() {
-            bail!("empty array")
+            crate::bail!("empty array")
         }
         let shape0 = self[0].shape()?;
         let n = self.len();
         for v in self.iter() {
             let shape = v.shape()?;
             if shape != shape0 {
-                bail!("two elements have different shapes {shape:?} {shape0:?}")
+                crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
             }
         }
         Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@@ -203,11 +203,7 @@
             Device::Metal(_device) => {
                 // let storage = device.rand_uniform(shape, dtype, lo, up)?;
                 // Ok(Storage::Metal(storage))
-<<<<<<< HEAD
                 crate::bail!("Metal rand_uniform not implemented")
-=======
-                bail!("Metal rand_uniform not implemented")
->>>>>>> 8cf39d27 (Metal part 1 - Scaffolding for metal.)
             }
         }
     }
From 39406a67214b01f85d5f3e2095ee36eb13d3cbf3 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 9 Nov 2023 19:53:14 +0100
Subject: [PATCH 03/15] Adding the actual backend

---
 Cargo.toml                              |    1 +
 candle-core/Cargo.toml                  |    3 +-
 candle-core/src/metal_backend.rs        |  821 +++++++++++++++
 candle-metal-kernels/Cargo.toml         |   19 +
 candle-metal-kernels/README.md          |    3 +
 candle-metal-kernels/src/affine.metal   |   46 +
 candle-metal-kernels/src/binary.metal   |   78 ++
 candle-metal-kernels/src/cast.metal     |   58 ++
 candle-metal-kernels/src/indexing.metal |   75 ++
 candle-metal-kernels/src/lib.rs         | 1246 +++++++++++++++++++++++
 candle-metal-kernels/src/reduce.metal   |  124 +++
 candle-metal-kernels/src/ternary.metal  |   57 ++
 candle-metal-kernels/src/unary.metal    |   82 ++
 13 files changed, 2612 insertions(+), 1 deletion(-)
 create mode 100644 candle-core/src/metal_backend.rs
 create mode 100644 candle-metal-kernels/Cargo.toml
 create mode 100644 candle-metal-kernels/README.md
 create mode 100644 candle-metal-kernels/src/affine.metal
 create mode 100644 candle-metal-kernels/src/binary.metal
 create mode 100644 candle-metal-kernels/src/cast.metal
 create mode 100644 candle-metal-kernels/src/indexing.metal
 create mode 100644 candle-metal-kernels/src/lib.rs
 create mode 100644 candle-metal-kernels/src/reduce.metal
 create mode 100644 candle-metal-kernels/src/ternary.metal
 create mode 100644 candle-metal-kernels/src/unary.metal

diff --git a/Cargo.toml b/Cargo.toml
index 0fea0423..c37bd75b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ members = [
 exclude = [
    "candle-flash-attn",
    "candle-kernels",
+   "candle-metal-kernels",
    "candle-onnx",
 ]
 resolver = "2"
diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 5d5e70a3..592f5bdf 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -13,6 +13,7 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }
@@ -40,4 +41,4 @@ cuda = ["cudarc", "dep:candle-kernels"]
 cudnn = ["cuda", "cudarc/cudnn"]
 mkl = ["dep:libc", "dep:intel-mkl-src"]
 accelerate = ["dep:libc", "dep:accelerate-src"]
-metal = ["dep:metal"]
+metal = ["dep:metal", "dep:candle-metal-kernels"]
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
new file mode 100644
index 00000000..04a2c3dd
--- /dev/null
+++ b/candle-core/src/metal_backend.rs
@@ -0,0 +1,821 @@
+use crate::backend::{BackendDevice, BackendStorage};
+use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
+use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+use candle_metal_kernels;
+use candle_metal_kernels::{void_ptr, Kernels, Source};
+use core::mem;
+use half::{bf16, f16};
+use metal;
+use metal::mps::matrix::encode_gemm;
+use metal::mps::Float32;
+use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
+use std::sync::Arc;
+use tracing::debug;
+
+/// Metal related errors
+#[derive(thiserror::Error, Debug)]
+pub enum MetalError {
+    #[error("{0}")]
+    Message(String),
+    #[error(transparent)]
+    KernelError(#[from] candle_metal_kernels::MetalKernelError),
+}
+
+impl From<String> for MetalError {
+    fn from(e: String) -> Self {
+        MetalError::Message(e)
+    }
+}
+
+#[derive(Clone)]
+pub struct MetalDevice {
+    device: metal::Device,
+    command_queue: metal::CommandQueue,
+    kernels: Arc<Kernels>,
+}
+
+impl std::fmt::Debug for MetalDevice {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "MetalDevice({:?})", self.device.registry_id())
+    }
+}
+
+impl std::ops::Deref for MetalDevice {
+    type Target = metal::DeviceRef;
+
+    fn deref(&self) -> &Self::Target {
+        &self.device
+    }
+}
+
+impl MetalDevice {
+    // pub fn metal_device(&self) -> &metal::DeviceRef {
+    //     self.device.as_ref()
+    // }
+
+    pub fn id(&self) -> u64 {
+        self.registry_id()
+    }
+
+    pub fn command_queue(&self) -> &CommandQueue {
+        &self.command_queue
+    }
+
+    pub fn kernels(&self) -> &Kernels {
+        &self.kernels
+    }
+
+    pub fn device(&self) -> &metal::Device {
+        &self.device
+    }
+
+    pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
+        let size = (element_count * dtype.size_in_bytes()) as u64;
+        // debug!("Allocate 1 - buffer size {size}");
+        self.device
+            .new_buffer(size, MTLResourceOptions::StorageModeManaged)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct MetalStorage {
+    buffer: metal::Buffer,
+    device: MetalDevice,
+    dtype: DType,
+}
+
+impl BackendStorage for MetalStorage {
+    type Device = MetalDevice;
+
+    fn try_clone(&self, _: &Layout) -> Result<Self> {
+        Ok(self.clone())
+    }
+
+    fn dtype(&self) -> DType {
+        self.dtype
+    }
+
+    fn device(&self) -> &Self::Device {
+        &self.device
+    }
+
+    fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        match self.dtype {
+            DType::F32 => Ok(CpuStorage::F32(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
+            )),
+            dtype => todo!("Unsupported dtype {dtype:?}"),
+        }
+    }
+
+    fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
+        let device = self.device().clone();
+
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let dtype = self.dtype;
+
+        assert!(layout.is_contiguous());
+        assert_eq!(dtype, DType::F32);
+
+        let mut buffer = device.new_buffer(el, self.dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_affine(
+            &device.device,
+            &command_buffer,
+            &device.kernels,
+            el,
+            &self.buffer,
+            &mut buffer,
+            mul as f32,
+            add as f32,
+        )
+        .unwrap();
+        command_buffer.commit();
+        return Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        });
+    }
+
+    fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
+        todo!()
+    }
+
+    fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
+        todo!()
+    }
+
+    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        // debug!("TODO reduce_op {op:?} {sum_dims:?}");
+        assert!(sum_dims.len() == 1);
+        assert!(sum_dims[0] == layout.shape().rank() - 1);
+        assert!(layout.is_contiguous());
+        let device = self.device.clone();
+        let src_stride = layout.stride();
+        let src_dims = layout.shape().dims();
+        let src_el: usize = src_dims.iter().product();
+        // Source dims and strides with the sum dims at the end.
+        let mut dims = vec![];
+        let mut stride = vec![];
+        let mut dst_el: usize = 1;
+        for (dim_idx, &d) in src_dims.iter().enumerate() {
+            if !sum_dims.contains(&dim_idx) {
+                dst_el *= d;
+                dims.push(d);
+                stride.push(src_stride[dim_idx]);
+            }
+        }
+        for &dim_idx in sum_dims.iter() {
+            dims.push(src_dims[dim_idx]);
+            stride.push(src_stride[dim_idx]);
+        }
+
+        let el_to_sum_per_block = src_el / dst_el;
+        // The reduction loop requires the shared array to be properly initialized and for
+        // this we want the number of threads to be a power of two.
+        let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
+        let (name, check_empty, return_index) = match (op, self.dtype) {
+            (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false),
+            (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false),
+            (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
+            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
+            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
+            _ => todo!("Reduce op for non float"),
+        };
+        if check_empty && layout.shape().elem_count() == 0 {
+            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+        }
+        let dtype = if return_index { DType::U32 } else { self.dtype };
+        let mut buffer = device.new_buffer(dst_el, dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_reduce_contiguous(
+            &device.device,
+            &command_buffer,
+            &device.kernels,
+            name,
+            src_el,
+            dst_el,
+            &self.buffer,
+            &mut buffer,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+
+        Ok(Self {
+            buffer,
+            device,
+            dtype,
+        })
+    }
+
+    fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
+        todo!()
+    }
+
+    fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
+        let device = self.device();
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el_count = shape.elem_count();
+        let mut buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_queue.new_command_buffer();
+        if layout.is_contiguous() {
+            use candle_metal_kernels::unary::contiguous;
+
+            let kernel_name = match (self.dtype, dtype) {
+                (DType::U32, DType::F32) => "cast_u32_f32",
+                (left, right) => todo!("to dtype {left:?} - {right:?}"),
+            };
+            candle_metal_kernels::call_cast_contiguous(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                el_count,
+                &self.buffer,
+                &mut buffer,
+            )
+            .map_err(MetalError::from)?;
+        } else {
+            todo!(
+                "TODO Implement the kernel calling cast {:?}-{:?}",
+                self.dtype,
+                dtype
+            );
+        }
+
+        command_buffer.commit();
+        // command_buffer.wait_until_scheduled();
+        debug!(
+            "cast {:?} - {:?} - {:?}",
+            dtype,
+            self.buffer.length(),
+            buffer.length()
+        );
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
+    }
+
+    fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
+        let device = self.device();
+        let dtype = self.dtype;
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el_count = shape.elem_count();
+        let mut buffer = device.new_buffer(el_count, dtype);
+        // TODO remove
+        // return Ok(Self {
+        //     buffer,
+        //     device: device.clone(),
+        //     dtype,
+        // });
+        let command_buffer = device.command_queue.new_command_buffer();
+        if layout.is_contiguous() {
+            use candle_metal_kernels::unary::contiguous;
+
+            let kernel_name = match (B::KERNEL, dtype) {
+                ("ucos", DType::F32) => contiguous::cos::FLOAT,
+                ("usin", DType::F32) => contiguous::sin::FLOAT,
+                ("usqr", DType::F32) => contiguous::sqr::FLOAT,
+                ("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
+                ("uneg", DType::F32) => contiguous::neg::FLOAT,
+                ("uexp", DType::F32) => contiguous::exp::FLOAT,
+                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+            };
+            candle_metal_kernels::call_unary_contiguous(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                el_count,
+                &self.buffer,
+                &mut buffer,
+            )
+            .map_err(MetalError::from)?;
+        } else {
+            todo!("TODO Implement the kernel calling {}", B::KERNEL);
+        }
+
+        let start = std::time::Instant::now();
+        command_buffer.commit();
+        // command_buffer.wait_until_scheduled();
+        debug!(
+            "Unary {:?} - {:?} - {:?} - {:?}",
+            B::KERNEL,
+            start.elapsed(),
+            self.buffer.length(),
+            buffer.length()
+        );
+
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
+    }
+
+    fn binary_impl<B: BinaryOpT>(
+        &self,
+        rhs: &Self,
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+    ) -> Result<Self> {
+        let device = self.device();
+        let dtype = self.dtype;
+        let shape = lhs_l.shape();
+        let dims = shape.dims();
+        let el_count = shape.elem_count();
+        let mut buffer = device.new_buffer(el_count, dtype);
+        let command_buffer = device.command_queue.new_command_buffer();
+        if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
+            use candle_metal_kernels::binary::contiguous;
+
+            let kernel_name = match (B::KERNEL, dtype) {
+                ("add", DType::F32) => contiguous::add::FLOAT,
+                ("badd", DType::F32) => contiguous::add::FLOAT,
+                ("sub", DType::F32) => contiguous::sub::FLOAT,
+                ("bsub", DType::F32) => contiguous::sub::FLOAT,
+                ("mul", DType::F32) => contiguous::mul::FLOAT,
+                ("bmul", DType::F32) => contiguous::mul::FLOAT,
+                ("div", DType::F32) => contiguous::div::FLOAT,
+                ("bdiv", DType::F32) => contiguous::div::FLOAT,
+                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+            };
+            candle_metal_kernels::call_binary_contiguous(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                el_count,
+                &self.buffer,
+                &rhs.buffer,
+                &mut buffer,
+            )
+            .map_err(MetalError::from)?;
+        } else {
+            use candle_metal_kernels::binary::strided;
+
+            let kernel_name = match (B::KERNEL, dtype) {
+                ("badd", DType::F32) => strided::add::FLOAT,
+                ("bsub", DType::F32) => strided::sub::FLOAT,
+                ("bmul", DType::F32) => strided::mul::FLOAT,
+                ("bdiv", DType::F32) => strided::div::FLOAT,
+                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+            };
+            candle_metal_kernels::call_binary_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                lhs_l.dims(),
+                &self.buffer,
+                &lhs_l.stride(),
+                lhs_l.start_offset(),
+                &rhs.buffer,
+                &rhs_l.stride(),
+                rhs_l.start_offset(),
+                &mut buffer,
+            )
+            .map_err(MetalError::from)?;
+        }
+
+        let start = std::time::Instant::now();
+        command_buffer.commit();
+        // command_buffer.wait_until_scheduled();
+        debug!(
+            "Binary {:?} - {:?} - {:?} - {:?}",
+            B::KERNEL,
+            start.elapsed(),
+            self.buffer.length(),
+            buffer.length()
+        );
+
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
+    }
+
+    fn where_cond(
+        &self,
+        layout: &Layout,
+        t: &Self,
+        t_l: &Layout,
+        f: &Self,
+        f_l: &Layout,
+    ) -> Result<Self> {
+        let device = self.device.clone();
+        let shape = t_l.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let dtype = t.dtype;
+        let mut buffer = self.device.new_buffer(el, dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_where_cond_strided(
+            &device.device,
+            &command_buffer,
+            &device.kernels,
+            "where_u8_f32",
+            &dims,
+            &self.buffer,
+            (layout.stride(), layout.start_offset()),
+            &t.buffer,
+            (&t_l.stride(), t_l.start_offset()),
+            &f.buffer,
+            (&f_l.stride(), f_l.start_offset()),
+            &mut buffer,
+        )
+        .map_err(MetalError::from)?;
+        command_buffer.commit();
+        Ok(Self {
+            buffer,
+            device,
+            dtype,
+        })
+    }
+
+    fn conv1d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &ParamsConv1D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
+    fn conv2d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &ParamsConv2D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
+    fn conv_transpose2d(
+        &self,
+        _l: &Layout,
+        _kernel: &Self,
+        _kernel_l: &Layout,
+        _params: &ParamsConvTranspose2D,
+    ) -> Result<Self> {
+        todo!()
+    }
+
+    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        todo!()
+    }
+
+    fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
+        todo!()
+    }
+
+    fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
+        todo!()
+    }
+
+    fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
+        todo!()
+    }
+
+    fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
+        todo!()
+    }
+
+    fn scatter_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
+    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        debug!(
+            "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
+            self.buffer.length(),
+            ids.buffer.length(),
+        );
+        let src = self;
+        let ids_shape = ids_l.shape();
+        let ids_dims = ids_shape.dims();
+        // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
+        // let src = match src_l.contiguous_offsets() {
+        //     Some((o1, o2)) => src.slice(o1..o2),
+        //     None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
+        // };
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_size = src_l.dims()[dim];
+        let ids_dim_size = ids_shape.elem_count();
+        let dst_el = ids_shape.elem_count() * left_size * right_size;
+        let dtype = self.dtype;
+        let device = self.device();
+        let buffer = device.new_buffer(dst_el, dtype);
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
+        // todo!()
+    }
+
+    fn index_add(
+        &self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<Self> {
+        todo!()
+    }
+
+    fn matmul(
+        &self,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+    ) -> Result<Self> {
+        let transpose_left = false;
+        let transpose_right = !rhs_l.is_contiguous();
+        let alpha = 1.0;
+        let beta = 0.0;
+        self.matmul_generic(
+            rhs,
+            (b, m, n, k),
+            lhs_l,
+            rhs_l,
+            transpose_left,
+            transpose_right,
+            alpha,
+            beta,
+        )
+    }
+
+    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
+        let src_shape = src_l.shape();
+        let dims = src_shape.dims();
+        let el_count = src_shape.elem_count();
+        if el_count == 0 {
+            return Ok(());
+        }
+        if src_l.is_contiguous() {
+            let command_buffer = self.device.command_queue.new_command_buffer();
+            let blip = command_buffer.new_blit_command_encoder();
+            blip.copy_from_buffer(
+                &self.buffer,
+                src_l.start_offset() as u64,
+                &dst.buffer,
+                dst_offset as u64,
+                self.buffer.length(),
+            );
+        } else {
+            let command_buffer = self.device.command_queue.new_command_buffer();
+            let kernel_name = match self.dtype {
+                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+                dtype => todo!("copy_strided not implemented for {dtype:?}"),
+            };
+            candle_metal_kernels::call_unary_strided(
+                &self.device.device,
+                &command_buffer,
+                &self.device.kernels,
+                kernel_name,
+                src_l.dims(),
+                &self.buffer,
+                &src_l.stride(),
+                src_l.start_offset(),
+                &mut dst.buffer,
+                dst_offset,
+            )
+            .map_err(MetalError::from)?;
+            command_buffer.commit();
+        }
+        Ok(())
+    }
+}
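
All of these storage methods are reached through candle's `Tensor` API once the `metal` feature is enabled. A hypothetical end-to-end sketch (not part of the patch: `Device::new_metal` is assumed here to route to `MetalDevice::new`, and only the F32 paths implemented above will currently succeed):

    use candle_core::{Device, Tensor};

    fn demo() -> candle_core::Result<()> {
        let device = Device::new_metal(0)?; // assumed constructor for the Metal backend
        let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &device)?;
        // affine() exercises MetalStorage::affine; sum() hits the reduce_op path above.
        let b = a.affine(2.0, 1.0)?.sum(1)?;
        println!("{b}");
        Ok(())
    }
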
+
+impl MetalStorage {
+    pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
+        Self {
+            buffer,
+            device,
+            dtype,
+        }
+    }
+    pub(crate) fn matmul_generic(
+        &self,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+        transpose_left: bool,
+        transpose_right: bool,
+        alpha: f64,
+        beta: f64,
+    ) -> Result<Self> {
+        let elem_count = b * m * n;
+        match (self.dtype, rhs.dtype) {
+            (DType::F32, DType::F32) => {
+                let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
+                if b != 1 {
+                    debug!("TODO implement batched matmul for B={b}");
+                    // bail!("Didn't implemented strided matmul yet");
+                    return Ok(Self {
+                        buffer: out_buffer,
+                        device: self.device.clone(),
+                        dtype: self.dtype(),
+                    });
+                }
+                if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
+                    debug!(
+                        "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
+                        lhs_l.is_contiguous(),
+                        rhs_l.is_contiguous(),
+                        rhs_l
+                    );
+                    return Ok(Self {
+                        buffer: out_buffer,
+                        device: self.device.clone(),
+                        dtype: self.dtype(),
+                    });
+                }
+
+                debug!("TODO GEMM");
+                let command_buffer = self.device.command_queue.new_command_buffer();
+                encode_gemm::<Float32, Float32, Float32>(
+                    &self.device,
+                    &command_buffer,
+                    transpose_left,
+                    transpose_right,
+                    &self.buffer,
+                    &rhs.buffer,
+                    &mut out_buffer,
+                    m as NSUInteger,
+                    n as NSUInteger,
+                    k as NSUInteger,
+                    alpha as f32,
+                    beta as f32,
+                    Some(b as NSUInteger),
+                )
+                .map_err(MetalError::from)?;
+
+                command_buffer.commit();
+                // command_buffer.wait_until_scheduled();
+
+                Ok(Self {
+                    buffer: out_buffer,
+                    device: self.device.clone(),
+                    dtype: self.dtype(),
+                })
+            }
+            _ => todo!("Unimplemented matmul for this pair"),
+        }
+    }
+
+    pub fn buffer(&self) -> &Buffer {
+        &self.buffer
+    }
+}
+
+impl BackendDevice for MetalDevice {
+    type Storage = MetalStorage;
+
+    fn new(ordinal: usize) -> Result<Self> {
+        let device = metal::Device::all().swap_remove(ordinal);
+
+        // let capture = metal::CaptureManager::shared();
+        // let descriptor = metal::CaptureDescriptor::new();
+        // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+        // descriptor.set_capture_device(&device);
+        // let mut dir = std::env::current_dir()?;
+        // dir.push("out.gputrace");
+        // descriptor.set_output_url(dir);
+
+        // capture
+        //     .start_capture(&descriptor)
+        //     .map_err(MetalError::from)?;
+        let command_queue = device.new_command_queue();
+        // let command_buffer = _command_queue.new_owned_command_buffer();
+        let kernels = Arc::new(Kernels::new());
+        Ok(Self {
+            device,
+            command_queue,
+            // command_buffer,
+            kernels,
+        })
+    }
+
+    fn set_seed(&self, _seed: u64) -> Result<()> {
+        todo!("set_seed")
+    }
+
+    fn location(&self) -> crate::DeviceLocation {
+        crate::DeviceLocation::Metal
+    }
+
+    fn same_device(&self, rhs: &Self) -> bool {
+        self.device.registry_id() == rhs.device.registry_id()
+    }
+
+    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+        // TODO Is there a faster way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
+        self.storage_from_cpu_storage(&cpu_storage)
+    }
+
+    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+        // TODO Is there a faster way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+        self.storage_from_cpu_storage(&cpu_storage)
+    }
+
+    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<MetalStorage> {
+        let option = metal::MTLResourceOptions::StorageModeManaged;
+        let buffer = match storage {
+            CpuStorage::U8(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<u8>()) as u64,
+                option,
+            ),
+            CpuStorage::U32(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<u32>()) as u64,
+                option,
+            ),
+            CpuStorage::I64(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<i64>()) as u64,
+                option,
+            ),
+            CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<bf16>()) as u64,
+                option,
+            ),
+            CpuStorage::F16(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<f16>()) as u64,
+                option,
+            ),
+            CpuStorage::F32(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<f32>()) as u64,
+                option,
+            ),
+            CpuStorage::F64(storage) => self.device.new_buffer_with_data(
+                storage.as_ptr() as *const core::ffi::c_void,
+                (storage.len() * mem::size_of::<f64>()) as u64,
+                option,
+            ),
+        };
+        // debug!("Allocate 2 - buffer size {}", buffer.length());
+        Ok(Self::Storage {
+            buffer,
+            device: self.clone(),
+            dtype: storage.dtype(),
+        })
+    }
+
+    fn rand_uniform(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        mean: f64,
+        stddev: f64,
+    ) -> Result<MetalStorage> {
+        // TODO is there a better way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
+        self.storage_from_cpu_storage(&cpu_storage)
+    }
+
+    fn rand_normal(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        mean: f64,
+        stddev: f64,
+    ) -> Result<MetalStorage> {
+        // TODO is there a better way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
+        self.storage_from_cpu_storage(&cpu_storage)
+    }
+}
diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
new file mode 100644
index 00000000..ff5ede1a
--- /dev/null
+++ b/candle-metal-kernels/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "candle-metal-kernels"
+version = "0.3.0"
+edition = "2021"
+
+description = "Metal kernels for Candle"
+repository = "https://github.com/huggingface/candle"
+keywords = ["blas", "tensor", "machine-learning"]
+categories = ["science"]
+license = "MIT OR Apache-2.0"
+
+[dependencies]
+metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+once_cell = "1.18.0"
+thiserror = "1"
+tracing = "0.1.37"
+
+[dev-dependencies]
+half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
diff --git a/candle-metal-kernels/README.md b/candle-metal-kernels/README.md
new file mode 100644
index 00000000..ec923e9a
--- /dev/null
+++ b/candle-metal-kernels/README.md
@@ -0,0 +1,3 @@
+# candle-metal-kernels
+
+This crate contains Metal kernels used from candle.
\ No newline at end of file
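
To make the new crate's surface concrete before the kernel sources below, here is a minimal host-side sketch of a full round trip (illustrative only; it mirrors the unit tests at the bottom of lib.rs, and `read_to_vec` comes from the forked metal-rs used above):

    use candle_metal_kernels::{call_affine, Kernels};
    use metal::{Device, MTLResourceOptions};

    fn affine_roundtrip() {
        let device = Device::system_default().expect("no Metal device");
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();

        // Upload 4 floats, compute y = 2*x + 1 on the GPU, read them back.
        let v = [1.0f32, 2.0, 3.0, 4.0];
        let options = MTLResourceOptions::StorageModeManaged;
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            std::mem::size_of_val(&v) as u64,
            options,
        );
        let mut output = device.new_buffer(std::mem::size_of_val(&v) as u64, options);
        call_affine(&device, command_buffer, &kernels, v.len(), &input, &mut output, 2.0, 1.0)
            .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();
        assert_eq!(output.read_to_vec::<f32>(v.len()), vec![3.0, 5.0, 7.0, 9.0]);
    }
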
diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal
new file mode 100644
index 00000000..c388c04e
--- /dev/null
+++ b/candle-metal-kernels/src/affine.metal
@@ -0,0 +1,46 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+using namespace metal;
+
+#define AFFINE(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &dim, \
+    constant float &mul, \
+    constant float &add, \
+    device const TYPENAME *input, \
+    device TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const TYPENAME m = TYPENAME(mul); \
+    const TYPENAME a = TYPENAME(add); \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        output[i] = input[i] * m + a; \
+    } \
+} \
+
+AFFINE(affine_float, float)
+AFFINE(affine_half, half)
+
+
+#if __METAL_VERSION__ >= 310
+AFFINE(affine_bfloat, bfloat);
+#endif
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
new file mode 100644
index 00000000..cfd34416
--- /dev/null
+++ b/candle-metal-kernels/src/binary.metal
@@ -0,0 +1,78 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+using namespace metal;
+
+#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
+kernel void FN_NAME( \
+    constant size_t &dim, \
+    device const TYPENAME *left, \
+    device const TYPENAME *right, \
+    device TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        TYPENAME x = left[i]; \
+        TYPENAME y = right[i]; \
+        output[i] = OUT_TYPENAME(FN); \
+    } \
+}\
+kernel void FN_NAME_STRIDED( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *left_strides, \
+    constant size_t *right_strides, \
+    device const TYPENAME *left, \
+    device const TYPENAME *right, \
+    device TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
+        TYPENAME y = right[get_strided_index(i, num_dims, dims, right_strides)]; \
+        output[i] = OUT_TYPENAME(FN); \
+    } \
+}
+
+#define BINARY_OP(FN, NAME) \
+BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
+BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
+
+#define BFLOAT_BINARY_OP(FN, NAME) \
+BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+
+
+BINARY_OP(x + y, add)
+BINARY_OP(x - y, sub)
+BINARY_OP(x * y, mul)
+BINARY_OP(x / y, div)
+
+#if __METAL_VERSION__ >= 310
+BFLOAT_BINARY_OP(x + y, add)
+BFLOAT_BINARY_OP(x - y, sub)
+BFLOAT_BINARY_OP(x * y, mul)
+BFLOAT_BINARY_OP(x / y, div)
+#endif
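
A note on the execution model shared by affine.metal, binary.metal, and the other kernels in this patch: the host launches a single threadgroup, and each thread walks one contiguous chunk of ceil(dim / threadgroup_size) elements. The same arithmetic in Rust, for checking on the CPU (illustrative, not part of the patch):

    /// Returns the [start, stop) range of elements that `thread_index`
    /// processes, mirroring the loop bounds in the kernels above.
    fn thread_range(dim: usize, threadgroup_size: usize, thread_index: usize) -> (usize, usize) {
        let length = (dim + threadgroup_size - 1) / threadgroup_size; // ceiling division
        let start = thread_index * length;
        let stop = usize::min(start + length, dim); // the last thread may get a short chunk
        (start, stop)
    }
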
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
new file mode 100644
index 00000000..52e63662
--- /dev/null
+++ b/candle-metal-kernels/src/cast.metal
@@ -0,0 +1,58 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+
+using namespace metal;
+
+#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &dim, \
+    device const LEFT_TYPENAME *input, \
+    device RIGHT_TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        output[i] = RIGHT_TYPENAME(input[i]); \
+    } \
+} \
+kernel void FN_NAME_STRIDED( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const LEFT_TYPENAME *input, \
+    device RIGHT_TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+    } \
+}
+
+
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+
+#if __METAL_VERSION__ >= 310
+#endif
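
The `get_strided_index` helper repeated at the top of each of these files maps a linear element index into a strided offset, peeling off coordinates from the innermost dimension outwards. A direct Rust transcription, useful for checking the arithmetic on the CPU (a sketch, not part of the patch):

    /// CPU transcription of the Metal helper above: convert a linear index
    /// over a row-major shape into an offset through arbitrary strides.
    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided_i = 0;
        for d in (0..dims.len()).rev() {
            strided_i += (idx % dims[d]) * strides[d]; // coordinate along dimension d
            idx /= dims[d];
        }
        strided_i
    }

For a [3, 2] view with transposed strides [1, 3], linear index 1 (row 0, column 1) maps to offset 3, which is what the cos_f32_strided test at the end of lib.rs relies on.
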
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
new file mode 100644
index 00000000..528c109d
--- /dev/null
+++ b/candle-metal-kernels/src/indexing.metal
@@ -0,0 +1,75 @@
+#include <metal_stdlib>
+using namespace metal;
+
+template <typename T, typename I>
+void index_add(
+    device I *ids [[buffer(0)]],
+    device T *inp [[buffer(1)]],
+    device T *out [[buffer(2)]],
+
+    constant uint &ids_dim_size,
+    constant uint &left_size,
+    constant uint &dst_dim_size,
+    constant uint &right_size,
+
+    uint threadgroup_size [[threads_per_threadgroup]],
+    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
+    uint thread_index [[thread_index_in_threadgroup]]
+) {
+
+    const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
+    if (gid >= left_size * right_size) {
+        return;
+    }
+
+    const uint i = gid;
+    const uint pre = i / right_size;
+    const uint post = i % right_size;
+
+    for (uint j = 0; j < ids_dim_size; j++) {
+        const uint idx = ids[j];
+        const uint src_i = (pre * ids_dim_size + j) * right_size + post;
+        const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
+        out[dst_i] += inp[src_i];
+    }
+}
+
+#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+kernel void FN_NAME( \
+    device INDEX_TYPENAME *ids [[buffer(0)]], \
+    device TYPENAME *inp [[buffer(1)]], \
+    device TYPENAME *out [[buffer(2)]], \
+    constant uint &ids_dim_size, \
+    constant uint &left_size, \
+    constant uint &dst_dim_size, \
+    constant uint &right_size, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
+
+
+
+#if __METAL_VERSION__ >= 310
+IA_OP(bfloat, int64_t, ia_i64_bf16)
+IA_OP(bfloat, uint32_t, ia_u32_bf16)
+IA_OP(bfloat, uint8_t, ia_u8_bf16)
+#endif
+
+IA_OP(half, uint32_t, ia_u32_f16)
+IA_OP(half, uint8_t, ia_u8_f16)
+
+IA_OP(float, int64_t, ia_i64_f32)
+IA_OP(uint8_t, int64_t, ia_i64_u8)
+IA_OP(int64_t, int64_t, ia_i64_i64)
+IA_OP(uint32_t, int64_t, ia_i64_u32)
+
+IA_OP(float, uint32_t, ia_u32_f32)
+IA_OP(uint8_t, uint32_t, ia_u32_u8)
+IA_OP(int64_t, uint32_t, ia_u32_i64)
+IA_OP(uint32_t, uint32_t, ia_u32_u32)
+
+IA_OP(float, uint8_t, ia_u8_f32)
+IA_OP(uint8_t, uint8_t, ia_u8_u8)
+IA_OP(uint32_t, uint8_t, ia_u8_u32)
+IA_OP(int64_t, uint8_t, ia_u8_i64)
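
For reference, the scatter-add that `index_add` implements, written out on the CPU (a sketch, not part of the patch): `inp` is viewed as [left_size, ids_dim_size, right_size], `out` as [left_size, dst_dim_size, right_size], and row `ids[j]` of the destination accumulates row `j` of the source.

    /// CPU reference for the index_add kernel above.
    fn index_add_ref(
        ids: &[u32],
        inp: &[f32],
        out: &mut [f32],
        left_size: usize,
        dst_dim_size: usize,
        right_size: usize,
    ) {
        let ids_dim_size = ids.len();
        for pre in 0..left_size {
            for post in 0..right_size {
                for (j, &idx) in ids.iter().enumerate() {
                    let src_i = (pre * ids_dim_size + j) * right_size + post;
                    let dst_i = (pre * dst_dim_size + idx as usize) * right_size + post;
                    out[dst_i] += inp[src_i];
                }
            }
        }
    }

With ids = [0, 4, 2], a 3x3 input of 1..=9 and a 5x3 destination of ones (left_size = 1, dst_dim_size = 5, right_size = 3), this produces exactly the expected vector checked in the index_add test at the bottom of lib.rs.
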
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
new file mode 100644
index 00000000..d2c63115
--- /dev/null
+++ b/candle-metal-kernels/src/lib.rs
@@ -0,0 +1,1246 @@
+#![allow(clippy::too_many_arguments)]
+use metal::{
+    Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
+    MTLSize,
+};
+use std::collections::HashMap;
+use std::ffi::c_void;
+use std::sync::RwLock;
+
+const AFFINE: &str = include_str!("affine.metal");
+const INDEXING: &str = include_str!("indexing.metal");
+const UNARY: &str = include_str!("unary.metal");
+const BINARY: &str = include_str!("binary.metal");
+const TERNARY: &str = include_str!("ternary.metal");
+const CAST: &str = include_str!("cast.metal");
+const REDUCE: &str = include_str!("reduce.metal");
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Source {
+    Affine,
+    Indexing,
+    Unary,
+    Binary,
+    Ternary,
+    Cast,
+    Reduce,
+}
+
+macro_rules! ops{
+    ($($name:ident),+) => {
+
+        pub mod contiguous {
+            pub struct Kernel(pub(crate) &'static str);
+            $(
+            pub mod $name {
+                use super::Kernel;
+                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
+                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
+                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
+            }
+            )+
+        }
+
+        pub mod strided {
+            pub struct Kernel(pub(crate) &'static str);
+            $(
+            pub mod $name {
+                use super::Kernel;
+                pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
+                pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
+                pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
+            }
+            )+
+        }
+    };
+}
+
+pub mod unary {
+    ops!(cos, sin, exp, sqr, sqrt, neg, copy);
+}
+pub mod binary {
+    ops!(add, sub, mul, div);
+}
+
+// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
+//     let mut l = HashMap::new();
+//     l.insert("affine", AFFINE);
+//     l.insert("indexing", INDEXING);
+//     l.insert("unary", UNARY);
+//     l
+// });
+//
+#[derive(thiserror::Error, Debug)]
+pub enum MetalKernelError {
+    #[error("Could not lock kernel map: {0}")]
+    LockError(String),
+    #[error("Error while loading library: {0}")]
+    LoadLibraryError(String),
+    #[error("Error while loading function: {0}")]
+    LoadFunctionError(String),
+}
+
+impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
+    fn from(e: std::sync::PoisonError<T>) -> Self {
+        Self::LockError(e.to_string())
+    }
+}
+
+type KernelMap<T> = HashMap<&'static str, T>;
+type Libraries = HashMap<Source, Library>;
+type Functions = KernelMap<Function>;
+
+#[derive(Debug, Default)]
+pub struct Kernels {
+    libraries: RwLock<Libraries>,
+    funcs: RwLock<Functions>,
+}
+
+impl Kernels {
+    pub fn new() -> Self {
+        let libraries = RwLock::new(Libraries::new());
+        let funcs = RwLock::new(Functions::new());
+        Self { libraries, funcs }
+    }
+
+    // pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
+    //     let kernels = Self::new();
+    //     kernels.load_libraries(device)?;
+    //     Ok(kernels)
+    // }
+
+    // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
+    //     for name in LIBRARY_SOURCES.keys() {
+    //         self.load_library(device, name)?;
+    //     }
+    //     Ok(())
+    // }
+
+    fn get_library_source(&self, source: Source) -> &'static str {
+        // LIBRARY_SOURCES.get(name).cloned()
+        match source {
+            Source::Affine => AFFINE,
+            Source::Unary => UNARY,
+            Source::Binary => BINARY,
+            Source::Ternary => TERNARY,
+            Source::Indexing => INDEXING,
+            Source::Cast => CAST,
+            Source::Reduce => REDUCE,
+        }
+    }
+
+    pub fn load_library(
+        &self,
+        device: &Device,
+        source: Source,
+    ) -> Result<Library, MetalKernelError> {
+        let mut libraries = self.libraries.write()?;
+        if let Some(lib) = libraries.get(&source) {
+            Ok(lib.clone())
+        } else {
+            let source_content = self.get_library_source(source);
+            let lib = device
+                .new_library_with_source(source_content, &CompileOptions::new())
+                .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
+            libraries.insert(source, lib.clone());
+            Ok(lib)
+        }
+    }
+
+    pub fn load_function(
+        &self,
+        device: &Device,
+        source: Source,
+        name: &'static str,
+    ) -> Result<Function, MetalKernelError> {
+        let mut funcs = self.funcs.write()?;
+        if let Some(func) = funcs.get(name) {
+            Ok(func.clone())
+        } else {
+            let func = self
+                .load_library(device, source)?
+                .get_function(name, None)
+                .map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
+            funcs.insert(name, func.clone());
+            Ok(func)
+        }
+    }
+}
+
+pub fn call_unary_contiguous(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: unary::contiguous::Kernel,
+    length: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    // println!("Kernel {:?}", kernel_name.0);
+    // assert_eq!(input.length(), output.length());
+    let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_buffer(1, Some(input), 0);
+    encoder.set_buffer(2, Some(output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+pub fn call_unary_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: unary::strided::Kernel,
+    shape: &[usize],
+    input: &Buffer,
+    strides: &[usize],
+    offset: usize,
+    output: &mut Buffer,
+    output_offset: usize,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Unary, name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let num_dims: usize = shape.len();
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let length: usize = shape.iter().product();
+    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
+    encoder.set_bytes(
+        2,
+        std::mem::size_of_val(shape) as u64,
+        shape.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        3,
+        std::mem::size_of_val(strides) as u64,
+        strides.as_ptr() as *const c_void,
+    );
+
+    encoder.set_buffer(4, Some(input), offset as u64);
+    encoder.set_buffer(5, Some(output), output_offset as u64);
+
+    let width = output.length();
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let thread_group_size = MTLSize {
+        width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
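
Note the two-level cache above: `load_library` compiles each .metal source once per `Source`, and `load_function` memoizes the resulting `Function` per kernel name, while compute pipelines are still rebuilt on every dispatch. After the first call, lookups are cheap (a sketch, assuming a Metal-capable host):

    use candle_metal_kernels::{Kernels, MetalKernelError, Source};

    fn warm_up(device: &metal::Device, kernels: &Kernels) -> Result<(), MetalKernelError> {
        // First call compiles unary.metal and caches both the Library and the Function.
        let _f = kernels.load_function(device, Source::Unary, "cos_float")?;
        // Subsequent calls for the same name are plain hash-map lookups.
        let _f = kernels.load_function(device, Source::Unary, "cos_float")?;
        Ok(())
    }
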
+pub fn call_binary_contiguous(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: binary::contiguous::Kernel,
+    length: usize,
+    left: &Buffer,
+    right: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    // println!("Kernel {:?}", kernel_name.0);
+    // assert_eq!(input.length(), output.length());
+    let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_buffer(1, Some(left), 0);
+    encoder.set_buffer(2, Some(right), 0);
+    encoder.set_buffer(3, Some(output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+pub fn call_binary_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: binary::strided::Kernel,
+    shape: &[usize],
+    left_input: &Buffer,
+    left_strides: &[usize],
+    left_offset: usize,
+    right_input: &Buffer,
+    right_strides: &[usize],
+    right_offset: usize,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Binary, name.0)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let num_dims: usize = shape.len();
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let length: usize = shape.iter().product();
+    encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
+    encoder.set_bytes(
+        2,
+        std::mem::size_of_val(shape) as u64,
+        shape.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        3,
+        std::mem::size_of_val(left_strides) as u64,
+        left_strides.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        4,
+        std::mem::size_of_val(right_strides) as u64,
+        right_strides.as_ptr() as *const c_void,
+    );
+
+    encoder.set_buffer(5, Some(left_input), left_offset as u64);
+    encoder.set_buffer(6, Some(right_input), right_offset as u64);
+    encoder.set_buffer(7, Some(output), 0);
+
+    let width = output.length();
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let thread_group_size = MTLSize {
+        width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
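
The strided entry point is what backs non-contiguous layouts in `binary_impl` in the backend above. For example, adding a contiguous 2x3 buffer to a 3x2 buffer viewed as its transpose only changes the stride arguments (an illustrative wrapper, not part of the patch):

    use candle_metal_kernels::{binary, call_binary_strided, Kernels, MetalKernelError};
    use metal::{Buffer, CommandBufferRef, Device};

    /// Adds a contiguous 2x3 buffer `a` to a 3x2 buffer `b` viewed as its
    /// transpose, writing 6 floats to `out`.
    fn add_transposed(
        device: &Device,
        command_buffer: &CommandBufferRef,
        kernels: &Kernels,
        a: &Buffer,
        b: &Buffer,
        out: &mut Buffer,
    ) -> Result<(), MetalKernelError> {
        call_binary_strided(
            device,
            command_buffer,
            kernels,
            binary::strided::add::FLOAT,
            &[2, 3],       // logical shape of the operation
            a, &[3, 1], 0, // lhs: row-major strides for 2x3
            b, &[1, 2], 0, // rhs: 3x2 storage viewed transposed
            out,
        )
    }
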
+pub fn call_cast_contiguous(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    // println!("Kernel {:?}", kernel_name.0);
+    // assert_eq!(input.length(), output.length());
+    let func = kernels.load_function(device, Source::Cast, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_buffer(1, Some(input), 0);
+    encoder.set_buffer(2, Some(output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+pub fn call_reduce_contiguous(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    out_length: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let elements_to_sum = length / out_length;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_bytes(
+        1,
+        core::mem::size_of::<usize>() as u64,
+        void_ptr(&elements_to_sum),
+    );
+    encoder.set_buffer(2, Some(input), 0);
+    encoder.set_buffer(3, Some(output), 0);
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        (elements_to_sum as u64 + 2 - 1) / 2,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
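
`call_reduce_contiguous` launches one threadgroup per output element; each group reduces `length / out_length` consecutive inputs, and the threadgroup width is rounded to a power of two (with each thread pre-loading about two elements) so the shared-memory halving loop in reduce.metal stays well-formed. For instance, reducing 10,000 elements to 10 gives 1,000 elements per group, and the width becomes min(max_threads, 500) rounded up, i.e. 512. The contract, as a CPU reference for the `fast_sum_float` case (a sketch, not part of the patch):

    /// out[i] = sum of input[i*k .. (i+1)*k] with k = input.len() / out_length.
    fn reduce_contiguous_ref(input: &[f32], out_length: usize) -> Vec<f32> {
        let k = input.len() / out_length;
        (0..out_length)
            .map(|i| input[i * k..(i + 1) * k].iter().sum())
            .collect()
    }
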
+pub fn call_last_softmax(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    elements_to_sum: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Reduce, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
+    encoder.set_bytes(
+        1,
+        core::mem::size_of::<usize>() as u64,
+        void_ptr(&elements_to_sum),
+    );
+    encoder.set_buffer(2, Some(input), 0);
+    encoder.set_buffer(3, Some(output), 0);
+
+    let out_length = length / elements_to_sum;
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        // (elements_to_sum as u64 + 2 - 1) / 2,
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+pub fn void_ptr<T>(v: &T) -> *const c_void {
+    (v as *const T).cast()
+}
+
+pub fn call_affine(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    size: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+    mul: f32,
+    add: f32,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Affine, "affine_float")?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
+    encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
+    encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
+    encoder.set_buffer(3, Some(input), 0);
+    encoder.set_buffer(4, Some(output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+pub fn call_where_cond_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    cond: &Buffer,
+    (cond_stride, cond_offset): (&[usize], usize),
+    left: &Buffer,
+    (left_stride, left_offset): (&[usize], usize),
+    right: &Buffer,
+    (right_stride, right_offset): (&[usize], usize),
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Ternary, name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let size: usize = shape.iter().product();
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
+    encoder.set_bytes(
+        1,
+        core::mem::size_of::<usize>() as u64,
+        void_ptr(&shape.len()),
+    );
+    encoder.set_bytes(
+        2,
+        std::mem::size_of_val(shape) as u64,
+        shape.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        3,
+        std::mem::size_of_val(cond_stride) as u64,
+        cond_stride.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        4,
+        std::mem::size_of_val(left_stride) as u64,
+        left_stride.as_ptr() as *const c_void,
+    );
+    encoder.set_bytes(
+        5,
+        std::mem::size_of_val(right_stride) as u64,
+        right_stride.as_ptr() as *const c_void,
+    );
+    encoder.set_buffer(6, Some(cond), cond_offset as u64);
+    encoder.set_buffer(7, Some(left), left_offset as u64);
+    encoder.set_buffer(8, Some(right), right_offset as u64);
+    encoder.set_buffer(9, Some(output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
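
The ternary path wires up the `where_*` kernels from ternary.metal (added later in this patch), which select elementwise between two tensors based on a `u8` condition, each operand addressed through its own strides. A contiguous CPU equivalent of `where_u8_f32`, assuming the usual where-cond semantics of candle (a sketch, strides elided):

    /// out[i] = if cond[i] != 0 { on_true[i] } else { on_false[i] }
    fn where_cond_ref(cond: &[u8], on_true: &[f32], on_false: &[f32]) -> Vec<f32> {
        cond.iter()
            .zip(on_true.iter().zip(on_false.iter()))
            .map(|(&c, (&t, &f))| if c != 0 { t } else { f })
            .collect()
    }
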
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use half::f16;
+    use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
+    use std::mem;
+
+    fn device() -> Device {
+        Device::system_default().unwrap()
+    }
+
+    fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
+        let b = 10f32.powi(digits);
+        v.iter().map(|t| f32::round(t * b) / b).collect()
+    }
+
+    fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
+        let b = 10f32.powi(digits);
+        v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+    }
+
+    fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
+        let device = device();
+        let kernels = Kernels::new();
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let options = MTLResourceOptions::StorageModeManaged;
+        let input = device.new_buffer_with_data(
+            v.as_ptr() as *const core::ffi::c_void,
+            std::mem::size_of_val(v) as u64,
+            options,
+        );
+        let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
+        call_unary_contiguous(
+            &device,
+            command_buffer,
+            &kernels,
+            name,
+            v.len(),
+            &input,
+            &mut output,
+        )
+        .unwrap();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        output.read_to_vec::<T>(v.len())
+    }
+
+    fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
+        let device = device();
+        let kernels = Kernels::new();
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let options = MTLResourceOptions::StorageModeManaged;
+        let left = device.new_buffer_with_data(
+            x.as_ptr() as *const core::ffi::c_void,
+            std::mem::size_of_val(x) as u64,
+            options,
+        );
+        let right = device.new_buffer_with_data(
+            y.as_ptr() as *const core::ffi::c_void,
+            std::mem::size_of_val(y) as u64,
+            options,
+        );
+        let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options);
+        call_binary_contiguous(
+            &device,
+            command_buffer,
+            &kernels,
+            name,
+            x.len(),
+            &left,
+            &right,
+            &mut output,
+        )
+        .unwrap();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        output.read_to_vec::<T>(x.len())
+    }
+
+    fn run_strided<T: Clone>(
+        v: &[T],
+        kernel: unary::strided::Kernel,
+        shape: &[usize],
+        strides: &[usize],
+        offset: usize,
+    ) -> Vec<T> {
+        let device = device();
+        let options = MTLResourceOptions::StorageModeManaged;
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let input = device.new_buffer_with_data(
+            v.as_ptr() as *const core::ffi::c_void,
+            std::mem::size_of_val(v) as u64,
+            options,
+        );
+        let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options);
+        let kernels = Kernels::new();
+        call_unary_strided(
+            &device,
+            command_buffer,
+            &kernels,
+            kernel,
+            shape,
+            &input,
+            strides,
+            offset,
+            &mut output,
+            0,
+        )
+        .unwrap();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        output.read_to_vec::<T>(v.len())
+    }
+
+    #[test]
+    fn cos_f32() {
+        let v = vec![1.0f32, 2.0, 3.0];
+        let results = run(&v, unary::contiguous::cos::FLOAT);
+        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+        assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
+        assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
+
+        let v = vec![1.0f32; 10_000];
+        let results = run(&v, unary::contiguous::cos::FLOAT);
+        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+        assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
+        assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+    }
10_000]); + } + + #[test] + fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + // Shape = [6], strides = [1]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Transposed + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![1, 3]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Very large + let v = vec![1.0f32; 10_000]; + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + } + + #[test] + fn binary_add_f32() { + let left = vec![1.0f32, 2.0, 3.0]; + let right = vec![2.0f32, 3.1, 4.2]; + let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); + let expected: Vec<_> = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| x + y) + .collect(); + assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); + assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); + } + + fn cast(v: &[T], name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, options); + call_cast_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) + } + + #[test] + fn cast_u32_f32() { + let v = vec![1u32, 2, 3]; + let results = cast(&v, "cast_u32_f32"); + let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); + assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); + assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + 
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); + } + + fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + + let size = v.len(); + + call_affine( + &device, + command_buffer, + &kernels, + size, + &input, + &mut output, + mul as f32, + add as f32, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) + } + + #[test] + fn affine() { + let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); + + let input = [1.0f32; 40_000]; + let mul = 1.5; + let add = 1.1; + let result = run_affine(&input, mul, add); + assert_eq!(result, vec![2.6; 40_000]); + } + + #[test] + fn index_add() { + let device = Device::system_default().expect("no device found"); + + let options = CompileOptions::new(); + let library = device.new_library_with_source(INDEXING, &options).unwrap(); + + let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let right = [1.0f32; 15]; + let index = [0u32, 4, 2]; + let ids_dim_size = index.len() as u32; + let dst_dim_size: u32 = 15; + let left_size: u32 = 3; + let right_size: u32 = 3; + + let function = library.get_function("ia_u32_f32", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + let options = MTLResourceOptions::StorageModeManaged; + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + let ids_size = (index.len() * mem::size_of::()) as NSUInteger; + let input_size = (left.len() * mem::size_of::()) as NSUInteger; + let output_size = (right.len() * mem::size_of::()) as NSUInteger; + + encoder.set_compute_pipeline_state(&pipeline); + encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); + + let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options); + let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options); + let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options); + + encoder.set_buffer(0, Some(&index_buffer), 0); + encoder.set_buffer(1, Some(&inputs_buffer), 0); + encoder.set_buffer(2, Some(&outputs_buffer), 0); + + encoder.set_bytes(3, 4, void_ptr(&ids_dim_size)); + encoder.set_bytes(4, 4, void_ptr(&left_size)); + encoder.set_bytes(5, 4, void_ptr(&dst_dim_size)); + encoder.set_bytes(6, 4, void_ptr(&right_size)); + + let grid_size = MTLSize { + width: right.len() as NSUInteger, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: pipeline.max_total_threads_per_threadgroup(), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = vec![ + 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, + ]; + let result = 
outputs_buffer.read_to_vec::(right.len()); + assert_eq!(result, expected); + } + + #[test] + fn cos_f16() { + let v: Vec = [1.0f32, 2.0, 3.0] + .iter() + .map(|v| f16::from_f32(*v)) + .collect(); + let results = run(&v, unary::contiguous::cos::HALF); + let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); + assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]); + assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); + } + + fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = + device.new_buffer((out_length * core::mem::size_of::()) as u64, options); + call_reduce_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + out_length, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(out_length) + } + + fn run_softmax( + v: &[T], + last_dim: usize, + name: &'static str, + ) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + call_last_softmax( + &device, + command_buffer, + &kernels, + name, + v.len(), + last_dim, + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(v.len()) + } + + #[test] + fn reduce_sum() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 1; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![21.0]); + } + + #[test] + fn reduce_sum2() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let out_length = 2; + + let results = run_reduce(&v, out_length, "fast_sum_float"); + assert_eq!(approx(results, 4), vec![6.0, 15.0]); + } + + #[test] + fn softmax() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; + let last_dim = 6; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] + ); + + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let last_dim = 3; + let results = run_softmax(&v, last_dim, "softmax_float"); + assert_eq!( + approx(results, 4), + vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] + ); + } + + fn run_where_cond( + shape: &[usize], + cond: &[I], + (cond_stride, cond_offset): (Vec, usize), + left_true: &[T], + (left_stride, left_offset): (Vec, usize), + right_false: &[T], + (_right_stride, _right_offset): (Vec, usize), + name: &'static str, + ) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let 
command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + + let length = cond.len(); + let cond = device.new_buffer_with_data( + cond.as_ptr() as *const core::ffi::c_void, + std::mem::size_of_val(cond) as u64, + options, + ); + let left = device.new_buffer_with_data( + left_true.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + let right = device.new_buffer_with_data( + right_false.as_ptr() as *const core::ffi::c_void, + (length * core::mem::size_of::()) as u64, + options, + ); + + let mut output = device.new_buffer((length * core::mem::size_of::()) as u64, options); + call_where_cond_strided( + &device, + command_buffer, + &kernels, + name, + shape, + &cond, + (&cond_stride, cond_offset), + &left, + (&left_stride, left_offset), + &right, + (&cond_stride, cond_offset), + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + output.read_to_vec::(length) + } + + #[test] + fn where_cond() { + let shape = vec![6]; + let cond = vec![0u8, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u8_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); + } +} diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal new file mode 100644 index 00000000..4dfc46c2 --- /dev/null +++ b/candle-metal-kernels/src/reduce.metal @@ -0,0 +1,124 @@ +#include +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +constant int THREADGROUP_SIZE = 256; + +kernel void fast_sum_float( + constant size_t &src_numel, + constant size_t &el_to_sum_per_block, + device const float *src, + device float *dst, + uint id [[ thread_position_in_grid ]], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint blockDim [[ threads_per_threadgroup ]] +) { + + threadgroup float shared_memory[THREADGROUP_SIZE]; + + shared_memory[tid] = 0; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. 
+ // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); + shared_memory[tid] += src[idx]; + idx += blockDim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + dst[dst_id] = shared_memory[0]; +} + +kernel void softmax_float( + constant size_t &src_numel, + constant size_t &el_to_sum_per_block, + device const float *src, + device float *dst, + uint id [[ thread_position_in_grid ]], + uint tid [[ thread_index_in_threadgroup ]], + uint dst_id [[ threadgroup_position_in_grid ]], + uint blockDim [[ threads_per_threadgroup ]] +) { + + threadgroup float shared_memory[THREADGROUP_SIZE]; + + shared_memory[tid] = -INFINITY; + // Elements summed in this block range from dst_id * el_to_sum_per_block + // to (dst_id + 1) * el_to_sum_per_block. + size_t start_idx = dst_id * el_to_sum_per_block; + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); + size_t idx = start_idx + tid; + + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + shared_memory[tid] = max(shared_memory[tid], src[idx]); + idx += blockDim; + } + + threadgroup_barrier(mem_flags::mem_none); + + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] = max(shared_memory[tid], shared_memory[tid + s]); + } + threadgroup_barrier(mem_flags::mem_none); + } + + float max = shared_memory[0]; + + shared_memory[tid] = 0; + + // Restart + idx = start_idx + tid; + while (idx < stop_idx) { + // TODO: Fast version for the contiguous case. + const float val = exp(src[idx] - max); + dst[idx] = val; + shared_memory[tid] += val; + idx += blockDim; + } + // reduction in shared memory + for (uint s = blockDim / 2; s > 0; s >>= 1) { + if (tid < s) { + shared_memory[tid] += shared_memory[tid + s]; + } + threadgroup_barrier(mem_flags::mem_none); + } + + const float inv_acc = 1/shared_memory[0]; + idx = start_idx + tid; + while (idx < stop_idx) { + dst[idx] *= inv_acc; + idx += blockDim; + } +} diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal new file mode 100644 index 00000000..0945b355 --- /dev/null +++ b/candle-metal-kernels/src/ternary.metal @@ -0,0 +1,57 @@ +#include +# +using namespace metal; + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + + +#define WHERE_OP(TYPENAME, ID_TYPENAME, FN_NAME) \ +kernel void FN_NAME( \ + constant size_t &numel, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + constant size_t *strides_t, \ + constant size_t *strides_f, \ + device const ID_TYPENAME *ids, \ + device const TYPENAME *t, \ + device const TYPENAME *f, \ + device TYPENAME *out ,\ + uint i [[ thread_position_in_grid ]] \ +) { \ + uint strided_i = get_strided_index(i, num_dims, dims, strides); \ + uint strided_i_t = get_strided_index(i, num_dims, dims, strides_t); \ + uint strided_i_f = get_strided_index(i, num_dims, dims, strides_f); \ + out[i] = ids[strided_i] ? 
t[strided_i_t] : f[strided_i_f]; \ +} \ + +// WHERE_OP(float, int64_t, where_i64_f32) +// WHERE_OP(double, int64_t, where_i64_f64) +// WHERE_OP(uint8_t, int64_t, where_i64_u8) +// WHERE_OP(uint32_t, int64_t, where_i64_u32) +// WHERE_OP(int64_t, int64_t, where_i64_i64) +// +// WHERE_OP(float, uint32_t, where_u32_f32) +// WHERE_OP(double, uint32_t, where_u32_f64) +// WHERE_OP(uint8_t, uint32_t, where_u32_u8) +// WHERE_OP(uint32_t, uint32_t, where_u32_u32) +// WHERE_OP(int64_t, uint32_t, where_u32_i64) + +WHERE_OP(float, uint8_t, where_u8_f32) +// WHERE_OP(double, uint8_t, where_u8_f64) +// WHERE_OP(uint8_t, uint8_t, where_u8_u8) +// WHERE_OP(uint32_t, uint8_t, where_u8_u32) +// WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal new file mode 100644 index 00000000..77de214e --- /dev/null +++ b/candle-metal-kernels/src/unary.metal @@ -0,0 +1,82 @@ +#include + +METAL_FUNC uint get_strided_index( + uint idx, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides +) { + uint strided_i = 0; + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} + +template METAL_FUNC T sqr(T in){ return in * in; } +template METAL_FUNC T neg(T in){ return -in; } +template METAL_FUNC T id(T in){ return in; } + + +using namespace metal; + +#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ +kernel void FN_NAME( \ + constant size_t &dim, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = TYPENAME(FN(input[i])); \ + } \ +}\ +kernel void FN_NAME_STRIDED( \ + constant size_t &dim, \ + constant size_t &num_dims, \ + constant size_t *dims, \ + constant size_t *strides, \ + device const TYPENAME *input, \ + device TYPENAME *output, \ + uint threadgroup_size [[threads_per_threadgroup]], \ + uint thread_index [[thread_index_in_threadgroup]] \ +) { \ + const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ + const size_t start = thread_index * length; \ + const size_t stop = min(start + length, dim); \ + for (size_t i = start; i < stop; i++){ \ + output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ + } \ +} + +#define UNARY_OP(NAME) \ +UNARY(NAME, float, NAME##_float, NAME##_float_strided); \ +UNARY(NAME, half, NAME##_half, NAME##_half_strided); + +#define BFLOAT_UNARY_OP(NAME) \ +UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided); + + +UNARY_OP(cos) +UNARY_OP(sin) +UNARY_OP(sqr) +UNARY_OP(sqrt) +UNARY_OP(neg) +UNARY_OP(exp) +UNARY(id, float, copy_float, copy_float_strided) +UNARY(id, half, copy_half, copy_half_strided) + +#if __METAL_VERSION__ >= 310 +BFLOAT_UNARY_OP(cos) +BFLOAT_UNARY_OP(sin) +BFLOAT_UNARY_OP(sqr) +BFLOAT_UNARY_OP(sqrt) +BFLOAT_UNARY_OP(neg) +BFLOAT_UNARY_OP(exp) +#endif From df6814f34ef8cbbf0b5e9e98fc8a71690cf8e8a4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 01:24:49 +0100 Subject: [PATCH 04/15] Refactor to simplify our lives for settings the params in the encoder. 
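The repeated set_bytes / set_buffer calls with hand-numbered argument
indices are replaced by an EncoderParam trait plus a set_params! macro
that binds each argument at the next index in order. Below is a
self-contained sketch of the dispatch pattern; the dummy Encoder stands
in for Metal's ComputeCommandEncoderRef, so the names in it are
illustrative only:

    struct Encoder;

    impl Encoder {
        fn set_bytes(&self, index: u64, len: usize) {
            println!("arg {index}: {len} bytes by value");
        }
        fn set_buffer(&self, index: u64) {
            println!("arg {index}: bound as buffer");
        }
    }

    trait EncoderParam {
        fn set_param(encoder: &Encoder, index: u64, data: Self);
    }

    // Scalars are copied in by value through set_bytes.
    impl EncoderParam for usize {
        fn set_param(encoder: &Encoder, index: u64, _data: Self) {
            encoder.set_bytes(index, core::mem::size_of::<usize>());
        }
    }

    // Slices stand in for device buffers and go through set_buffer.
    impl EncoderParam for &[f32] {
        fn set_param(encoder: &Encoder, index: u64, _data: Self) {
            encoder.set_buffer(index);
        }
    }

    // Each parameter is bound at the next index, in the order written.
    macro_rules! set_params {
        ($encoder:ident, ($($param:expr),+)) => {
            let mut _index: u64 = 0;
            $(
                EncoderParam::set_param(&$encoder, _index, $param);
                _index += 1;
            )*
        };
    }

    fn main() {
        let encoder = Encoder;
        let input = [1.0f32, 2.0, 3.0];
        set_params!(encoder, (input.len(), &input[..]));
    }

In the actual change, scalar impls call set_bytes with their size_of,
&Buffer goes through set_buffer, and (&Buffer, usize) tuples carry a
byte offset, so call sites shrink to a single set_params! invocation.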
--- candle-core/src/device.rs | 1 + candle-core/src/lib.rs | 2 + candle-core/src/metal_backend.rs | 130 +++----- candle-core/src/tensor.rs | 4 + candle-metal-kernels/src/indexing.metal | 33 ++ candle-metal-kernels/src/lib.rs | 424 ++++++++++++++---------- 6 files changed, 339 insertions(+), 255 deletions(-) diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index de57c03a..73eb9640 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -146,6 +146,7 @@ impl Device { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), + (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs), _ => false, } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index da61bdb5..36f5f6b1 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -53,6 +53,8 @@ mod dummy_metal_backend; pub mod error; mod indexer; pub mod layout; +#[cfg(feature = "metal")] +pub mod metal_backend; #[cfg(feature = "mkl")] mod mkl; pub mod npy; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 04a2c3dd..68a96672 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1,17 +1,16 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; +use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::{void_ptr, Kernels, Source}; +use candle_metal_kernels::Kernels; use core::mem; use half::{bf16, f16}; use metal; use metal::mps::matrix::encode_gemm; use metal::mps::Float32; -use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger}; +use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::Arc; -use tracing::debug; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -113,7 +112,6 @@ impl BackendStorage for MetalStorage { let device = self.device().clone(); let shape = layout.shape(); - let dims = shape.dims(); let el = shape.elem_count(); let dtype = self.dtype; @@ -174,10 +172,8 @@ impl BackendStorage for MetalStorage { stride.push(src_stride[dim_idx]); } - let el_to_sum_per_block = src_el / dst_el; // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. 
- let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two(); let (name, check_empty, return_index) = match (op, self.dtype) { (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), @@ -219,13 +215,10 @@ impl BackendStorage for MetalStorage { fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { let device = self.device(); let shape = layout.shape(); - let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); let command_buffer = device.command_queue.new_command_buffer(); if layout.is_contiguous() { - use candle_metal_kernels::unary::contiguous; - let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", (left, right) => todo!("to dtype {left:?} - {right:?}"), @@ -250,12 +243,12 @@ impl BackendStorage for MetalStorage { command_buffer.commit(); // command_buffer.wait_until_scheduled(); - debug!( - "cast {:?} - {:?} - {:?}", - dtype, - self.buffer.length(), - buffer.length() - ); + // debug!( + // "cast {:?} - {:?} - {:?}", + // dtype, + // self.buffer.length(), + // buffer.length() + // ); Ok(Self { buffer, device: device.clone(), @@ -267,15 +260,8 @@ impl BackendStorage for MetalStorage { let device = self.device(); let dtype = self.dtype; let shape = layout.shape(); - let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - // TODO remove - // return Ok(Self { - // buffer, - // device: device.clone(), - // dtype, - // }); let command_buffer = device.command_queue.new_command_buffer(); if layout.is_contiguous() { use candle_metal_kernels::unary::contiguous; @@ -302,17 +288,7 @@ impl BackendStorage for MetalStorage { } else { todo!("TODO Implement the kernel calling {}", B::KERNEL); } - - let start = std::time::Instant::now(); command_buffer.commit(); - // command_buffer.wait_until_scheduled(); - debug!( - "Unary {:?} - {:?} - {:?} - {:?}", - B::KERNEL, - start.elapsed(), - self.buffer.length(), - buffer.length() - ); Ok(Self { buffer, @@ -330,7 +306,6 @@ impl BackendStorage for MetalStorage { let device = self.device(); let dtype = self.dtype; let shape = lhs_l.shape(); - let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); let command_buffer = device.command_queue.new_command_buffer(); @@ -385,17 +360,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - - let start = std::time::Instant::now(); command_buffer.commit(); - // command_buffer.wait_until_scheduled(); - debug!( - "Binary {:?} - {:?} - {:?} - {:?}", - B::KERNEL, - start.elapsed(), - self.buffer.length(), - buffer.length() - ); Ok(Self { buffer, @@ -452,6 +417,16 @@ impl BackendStorage for MetalStorage { todo!() } + fn conv_transpose1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConvTranspose1D, + ) -> Result { + todo!() + } + fn conv2d( &self, _l: &Layout, @@ -504,34 +479,28 @@ impl BackendStorage for MetalStorage { todo!() } - fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - debug!( - "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}", - self.buffer.length(), - ids.buffer.length(), - ); - let src = self; - let ids_shape = ids_l.shape(); - let ids_dims = ids_shape.dims(); - // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; - // let src = match src_l.contiguous_offsets() { - // Some((o1, o2)) => 
src.slice(o1..o2), - // None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?, - // }; - let left_size: usize = src_l.dims()[..dim].iter().product(); - let right_size: usize = src_l.dims()[dim + 1..].iter().product(); - let src_dim_size = src_l.dims()[dim]; - let ids_dim_size = ids_shape.elem_count(); - let dst_el = ids_shape.elem_count() * left_size * right_size; - let dtype = self.dtype; - let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) - // todo!() + fn index_select( + &self, + _ids: &Self, + _src_l: &Layout, + _ids_l: &Layout, + _dim: usize, + ) -> Result { + todo!("Index select"); + // let ids_shape = ids_l.shape(); + // let left_size: usize = src_l.dims()[..dim].iter().product(); + // let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + // let src_dim_size = src_l.dims()[dim]; + // let ids_dim_size = ids_shape.elem_count(); + // let dst_el = ids_shape.elem_count() * left_size * right_size; + // let dtype = self.dtype; + // let device = self.device(); + // let buffer = device.new_buffer(dst_el, dtype); + // Ok(Self { + // buffer, + // device: device.clone(), + // dtype, + // }) } fn index_add( @@ -571,7 +540,6 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); - let dims = src_shape.dims(); let el_count = src_shape.elem_count(); if el_count == 0 { return Ok(()); @@ -637,7 +605,7 @@ impl MetalStorage { (DType::F32, DType::F32) => { let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); if b != 1 { - debug!("TODO implement batched matmul for B={b}"); + // debug!("TODO implement batched matmul for B={b}"); // bail!("Didn't implemented strided matmul yet"); return Ok(Self { buffer: out_buffer, @@ -646,12 +614,12 @@ impl MetalStorage { }); } if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { - debug!( - "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", - lhs_l.is_contiguous(), - rhs_l.is_contiguous(), - rhs_l - ); + // debug!( + // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", + // lhs_l.is_contiguous(), + // rhs_l.is_contiguous(), + // rhs_l + // ); return Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -659,7 +627,7 @@ impl MetalStorage { }); } - debug!("TODO GEMM"); + // debug!("TODO GEMM"); let command_buffer = self.device.command_queue.new_command_buffer(); encode_gemm::( &self.device, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 2a0924b6..f7f66668 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1859,7 +1859,11 @@ impl Tensor { (Storage::Cpu(storage), Device::Cuda(cuda)) => { Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) } + (Storage::Cpu(storage), Device::Metal(metal)) => { + Storage::Metal(metal.storage_from_cpu_storage(storage)?) + } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), + (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. 
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 528c109d..eefaef34 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,6 +1,39 @@ #include using namespace metal; +kernel void is_u32_f32( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + + const device float *input, + const device uint *input_ids, + device float *output, + + uint gid [[ thread_position_in_grid ]] +) { + + if (gid >= dst_size) { + return; + } + + const size_t id_i = gid / right_size / left_size; + const size_t right_rank_i = gid % right_size; + const size_t left_rank_i = gid % left_size; + + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1)); + const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; + + output[gid] = input[src_i]; + +} + + template void index_add( device I *ids [[buffer(0)]], diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d2c63115..1bcd56d1 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, - MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, + Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -15,6 +15,70 @@ const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { +
+    <P as EncoderParam>
::set_param(encoder, position, data) +} +trait EncoderParam { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); +} +macro_rules! primitive { + ($type:ty) => { + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } + }; +} +primitive!(usize); +primitive!(u32); +primitive!(f32); + +impl EncoderParam for &[T] { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + (core::mem::size_of::() * data.len()) as u64, + data.as_ptr() as *const T as *const c_void, + ); + } +} + +impl EncoderParam for &Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} +impl EncoderParam for (&Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} +impl EncoderParam for &mut Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} +impl EncoderParam for (&mut Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +macro_rules! set_params { + ($encoder:ident, ($($param:expr),+)) => ( + let mut _index = 0; + $( + set_param($encoder, _index, $param); + _index += 1; + )* + ); +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -191,9 +255,7 @@ pub fn call_unary_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, 4, void_ptr(&length)); - encoder.set_buffer(1, Some(input), 0); - encoder.set_buffer(2, Some(output), 0); + set_params!(encoder, (length, input, output)); let thread_group_count = MTLSize { width: 1, @@ -239,24 +301,19 @@ pub fn call_unary_strided( encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); - encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); - encoder.set_bytes( - 2, - std::mem::size_of_val(shape) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - std::mem::size_of_val(strides) as u64, - strides.as_ptr() as *const c_void, + set_params!( + encoder, + ( + length, + num_dims, + shape, + strides, + (input, offset), + (output, output_offset) + ) ); - encoder.set_buffer(4, Some(input), offset as u64); - encoder.set_buffer(5, Some(output), output_offset as u64); - - let width = output.length(); - + let width: usize = shape.iter().product(); let thread_group_count = MTLSize { width: 1, height: 1, @@ -264,7 +321,7 @@ pub fn call_unary_strided( }; let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), height: 1, depth: 1, }; @@ -299,10 +356,7 @@ pub fn call_binary_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, 4, void_ptr(&length)); - encoder.set_buffer(1, Some(left), 0); - encoder.set_buffer(2, Some(right), 0); - encoder.set_buffer(3, 
Some(output), 0); + set_params!(encoder, (length, left, right, output)); let thread_group_count = MTLSize { width: 1, @@ -348,32 +402,24 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + let width: usize = shape.iter().product(); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); - encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); - encoder.set_bytes( - 2, - std::mem::size_of_val(shape) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - std::mem::size_of_val(left_strides) as u64, - left_strides.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 4, - std::mem::size_of_val(right_strides) as u64, - right_strides.as_ptr() as *const c_void, - ); - encoder.set_buffer(5, Some(left_input), left_offset as u64); - encoder.set_buffer(6, Some(right_input), right_offset as u64); - encoder.set_buffer(7, Some(output), 0); - - let width = output.length(); + set_params!( + encoder, + ( + length, + num_dims, + shape, + left_strides, + right_strides, + (left_input, left_offset), + (right_input, right_offset), + output + ) + ); let thread_group_count = MTLSize { width: 1, @@ -382,7 +428,7 @@ pub fn call_binary_strided( }; let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), height: 1, depth: 1, }; @@ -416,9 +462,7 @@ pub fn call_cast_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, 4, void_ptr(&length)); - encoder.set_buffer(1, Some(input), 0); - encoder.set_buffer(2, Some(output), 0); + set_params!(encoder, (length, input, output)); let thread_group_count = MTLSize { width: 1, @@ -463,14 +507,7 @@ pub fn call_reduce_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes( - 1, - core::mem::size_of::() as u64, - void_ptr(&elements_to_sum), - ); - encoder.set_buffer(2, Some(input), 0); - encoder.set_buffer(3, Some(output), 0); + set_params!(encoder, (length, elements_to_sum, input, output)); let thread_group_count = MTLSize { width: out_length as u64, @@ -518,14 +555,7 @@ pub fn call_last_softmax( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes( - 1, - core::mem::size_of::() as u64, - void_ptr(&elements_to_sum), - ); - encoder.set_buffer(2, Some(input), 0); - encoder.set_buffer(3, Some(output), 0); + set_params!(encoder, (length, elements_to_sum, input, output)); let out_length = length / elements_to_sum; @@ -553,10 +583,6 @@ pub fn call_last_softmax( Ok(()) } -pub fn void_ptr(v: &T) -> *const c_void { - (v as *const T).cast() -} - pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, @@ -580,11 +606,7 @@ pub fn call_affine( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); - encoder.set_bytes(1, core::mem::size_of::() as u64, void_ptr(&mul)); - encoder.set_bytes(2, core::mem::size_of::() as 
u64, void_ptr(&add)); - encoder.set_buffer(3, Some(input), 0); - encoder.set_buffer(4, Some(output), 0); + set_params!(encoder, (size, mul, add, input, output)); let thread_group_count = MTLSize { width: 1, @@ -632,36 +654,23 @@ pub fn call_where_cond_strided( encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); - encoder.set_bytes( - 1, - core::mem::size_of::() as u64, - void_ptr(&shape.len()), + let rank = shape.len(); + + set_params!( + encoder, + ( + size, + rank, + shape, + cond_stride, + left_stride, + right_stride, + (cond, cond_offset), + (left, left_offset), + (right, right_offset), + output + ) ); - encoder.set_bytes( - 2, - std::mem::size_of_val(shape) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - std::mem::size_of_val(cond_stride) as u64, - cond_stride.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 4, - std::mem::size_of_val(left_stride) as u64, - left_stride.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 5, - std::mem::size_of_val(right_stride) as u64, - right_stride.as_ptr() as *const c_void, - ); - encoder.set_buffer(6, Some(cond), cond_offset as u64); - encoder.set_buffer(7, Some(left), left_offset as u64); - encoder.set_buffer(8, Some(right), right_offset as u64); - encoder.set_buffer(9, Some(output), 0); let thread_group_count = MTLSize { width: 1, @@ -686,7 +695,13 @@ mod tests { use super::*; use half::f16; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; - use std::mem; + + fn new_buffer(device: &Device, data: &[T]) -> Buffer { + let options = MTLResourceOptions::StorageModeManaged; + let ptr = data.as_ptr() as *const core::ffi::c_void; + let size = (data.len() * std::mem::size_of::()) as u64; + device.new_buffer_with_data(ptr, size, options) + } fn device() -> Device { Device::system_default().unwrap() @@ -707,13 +722,8 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -735,16 +745,8 @@ mod tests { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; - let left = device.new_buffer_with_data( - x.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(x) as u64, - options, - ); - let right = device.new_buffer_with_data( - y.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(y) as u64, - options, - ); + let left = new_buffer(&device, x); + let right = new_buffer(&device, y); let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, @@ -770,15 +772,10 @@ mod tests { offset: usize, ) -> Vec { let device = device(); - let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = 
device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); let kernels = Kernels::new(); call_unary_strided( &device, @@ -893,13 +890,9 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + call_cast_contiguous( &device, command_buffer, @@ -935,14 +928,9 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); let size = v.len(); @@ -978,6 +966,104 @@ mod tests { assert_eq!(result, vec![2.6; 40_000]); } + #[test] + fn index_select() { + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [2, 5]; + let ids = [0u32, 1, 0]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 1, 0]; + let dim = 1; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); + } + + fn run_index_select( + embeddings: &[T], + shape: &[usize], + ids: &[I], + dim: usize, + ) -> Vec { + let device = Device::system_default().expect("no device found"); + let options = CompileOptions::new(); + let library = device.new_library_with_source(INDEXING, &options).unwrap(); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids.len() * left_size * right_size; + let ids_size = ids.len(); + + let function = library.get_function("is_u32_f32", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + let embeddings_buffer = new_buffer(&device, &embeddings); + let ids_buffer = new_buffer(&device, &ids); + let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + &embeddings_buffer, + &ids_buffer, + 
&mut dst_buffer + ) + ); + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64); + let grid_size = MTLSize { + width: (dst_el as u64 + width - 1) / width, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + println!("{width:?} - {:?}", grid_size); + + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + dst_buffer.read_to_vec::(dst_el) + } + #[test] fn index_add() { let device = Device::system_default().expect("no device found"); @@ -997,31 +1083,29 @@ mod tests { let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); - let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let encoder = command_buffer.new_compute_command_encoder(); - let ids_size = (index.len() * mem::size_of::()) as NSUInteger; - let input_size = (left.len() * mem::size_of::()) as NSUInteger; - let output_size = (right.len() * mem::size_of::()) as NSUInteger; - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); - let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options); - let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options); - let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options); + let index_buffer = new_buffer(&device, &index); + let inputs_buffer = new_buffer(&device, &left); + let outputs_buffer = new_buffer(&device, &right); - encoder.set_buffer(0, Some(&index_buffer), 0); - encoder.set_buffer(1, Some(&inputs_buffer), 0); - encoder.set_buffer(2, Some(&outputs_buffer), 0); - - encoder.set_bytes(3, 4, void_ptr(&ids_dim_size)); - encoder.set_bytes(4, 4, void_ptr(&left_size)); - encoder.set_bytes(5, 4, void_ptr(&dst_dim_size)); - encoder.set_bytes(6, 4, void_ptr(&right_size)); + set_params!( + encoder, + ( + &index_buffer, + &inputs_buffer, + &outputs_buffer, + ids_dim_size, + left_size, + dst_dim_size, + right_size + ) + ); let grid_size = MTLSize { width: right.len() as NSUInteger, @@ -1064,12 +1148,9 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); let mut output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); call_reduce_contiguous( @@ -1098,13 +1179,8 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); call_last_softmax( &device, command_buffer, From f82bf2d915cf4494fb137d9ec2c0566621f8007b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 01:58:51 +0100 Subject: [PATCH 05/15] Adding indexing. 
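The new call_index_select launches the is_u32_f32 kernel: it gathers
slices of a contiguous source along `dim` according to a u32 id buffer.
For reference, a minimal CPU model of the dim-0 case exercised by the
new tests (the helper name and row-major layout here are ours, for
illustration):

    // Output row i is source row ids[i]; src is row-major, cols wide.
    fn index_select_dim0(src: &[f32], cols: usize, ids: &[u32]) -> Vec<f32> {
        let mut out = Vec::with_capacity(ids.len() * cols);
        for &id in ids {
            let row = id as usize;
            out.extend_from_slice(&src[row * cols..(row + 1) * cols]);
        }
        out
    }

    fn main() {
        // A [5, 2] source with ids [0, 4, 2] selects rows 0, 4 and 2.
        let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        assert_eq!(
            index_select_dim0(&embedding, 2, &[0, 4, 2]),
            vec![1.0, 2.0, 9.0, 10.0, 5.0, 6.0]
        );
    }

Unlike this sketch, the kernel clamps every id to src_dim_size - 1
instead of reading out of bounds, since there is no good way to force a
crash from the GPU side.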
Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> --- candle-core/src/metal_backend.rs | 56 ++++++----- candle-metal-kernels/src/indexing.metal | 58 ++++++----- candle-metal-kernels/src/lib.rs | 122 +++++++++++++++--------- candle-metal-kernels/src/reduce.metal | 93 ++++++++++-------- 4 files changed, 191 insertions(+), 138 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 68a96672..ed592240 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -479,28 +479,40 @@ impl BackendStorage for MetalStorage { todo!() } - fn index_select( - &self, - _ids: &Self, - _src_l: &Layout, - _ids_l: &Layout, - _dim: usize, - ) -> Result { - todo!("Index select"); - // let ids_shape = ids_l.shape(); - // let left_size: usize = src_l.dims()[..dim].iter().product(); - // let right_size: usize = src_l.dims()[dim + 1..].iter().product(); - // let src_dim_size = src_l.dims()[dim]; - // let ids_dim_size = ids_shape.elem_count(); - // let dst_el = ids_shape.elem_count() * left_size * right_size; - // let dtype = self.dtype; - // let device = self.device(); - // let buffer = device.new_buffer(dst_el, dtype); - // Ok(Self { - // buffer, - // device: device.clone(), - // dtype, - // }) + fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { + assert!(src_l.is_contiguous()); + assert!(ids_l.is_contiguous()); + let left_size: usize = src_l.dims()[..dim].iter().product(); + let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + let ids_el = ids_l.shape().elem_count(); + let dst_el = ids_el * left_size * right_size; + let dtype = self.dtype; + let device = self.device(); + let mut buffer = device.new_buffer(dst_el, dtype); + let name = match (ids.dtype, self.dtype) { + (DType::U32, DType::F32) => "is_u32_f32", + (left, right) => todo!("index select metal {left:?} {right:?}"), + }; + let command_buffer = self.device.command_queue.new_command_buffer(); + candle_metal_kernels::call_index_select( + &device.device, + &command_buffer, + &self.device.kernels, + name, + src_l.dims(), + ids_el, + dim, + &self.buffer, + &ids.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) } fn index_add( diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index eefaef34..c077cc48 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,39 +1,36 @@ #include using namespace metal; -kernel void is_u32_f32( - constant size_t &dst_size, - constant size_t &left_size, - constant size_t &src_dim_size, - constant size_t &right_size, - constant size_t &ids_size, - - const device float *input, - const device uint *input_ids, - device float *output, - - uint gid [[ thread_position_in_grid ]] -) { - - if (gid >= dst_size) { - return; - } - - const size_t id_i = gid / right_size / left_size; - const size_t right_rank_i = gid % right_size; - const size_t left_rank_i = gid % left_size; - - // Force prevent out of bounds indexing - // since there doesn't seem to be a good way to force crash - // No need to check for zero we're only allowing unsized. 
- const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1)); - const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; - - output[gid] = input[src_i]; - +# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &dst_size, \ + constant size_t &left_size, \ + constant size_t &src_dim_size, \ + constant size_t &right_size, \ + constant size_t &ids_size, \ + const device TYPENAME *input, \ + const device INDEX_TYPENAME *input_ids, \ + device TYPENAME *output, \ + uint gid [[ thread_position_in_grid ]] \ +) { \ + if (gid >= dst_size) { \ + return; \ + } \ + const size_t id_i = gid / right_size / left_size; \ + const size_t right_rank_i = gid % right_size; \ + const size_t left_rank_i = gid % left_size; \ + /* \ + // Force prevent out of bounds indexing \ + // since there doesn't seem to be a good way to force crash \ + // No need to check for zero we're only allowing unsized. \ + */ \ + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \ + const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; \ + output[gid] = input[src_i]; \ } + template void index_add( device I *ids [[buffer(0)]], @@ -82,6 +79,7 @@ kernel void FN_NAME( \ ) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \ +INDEX_OP(is_u32_f32, uint, float) #if __METAL_VERSION__ >= 310 IA_OP(bfloat, int64_t, ia_i64_bf16) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1bcd56d1..6a01107c 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -690,6 +690,63 @@ pub fn call_where_cond_strided( Ok(()) } +pub fn call_index_select( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + ids_size: usize, + dim: usize, + input: &Buffer, + ids: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids_size * left_size * right_size; + + let func = kernels.load_function(device, Source::Indexing, name)?; + let pipeline = device + .new_compute_pipeline_state_with_function(&func) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + input, + ids, + output + ) + ); + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64); + let grid_size = MTLSize { + width: (dst_el as u64 + width - 1) / width, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -1003,61 +1060,32 @@ mod tests { dim: usize, ) -> Vec { let device = Device::system_default().expect("no device found"); - let options = CompileOptions::new(); - let library = device.new_library_with_source(INDEXING, &options).unwrap(); - - let left_size: usize = shape[..dim].iter().product(); - let right_size: usize = shape[dim + 1..].iter().product(); - let src_dim_size = shape[dim]; - let dst_el = ids.len() * left_size * right_size; - let ids_size = ids.len(); - - let 
function = library.get_function("is_u32_f32", None).unwrap(); - let pipeline = device - .new_compute_pipeline_state_with_function(&function) - .unwrap(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let encoder = command_buffer.new_compute_command_encoder(); - - encoder.set_compute_pipeline_state(&pipeline); - let embeddings_buffer = new_buffer(&device, &embeddings); let ids_buffer = new_buffer(&device, &ids); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let dst_el = ids.len() * left_size * right_size; let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - set_params!( - encoder, - ( - dst_el, - left_size, - src_dim_size, - right_size, - ids_size, - &embeddings_buffer, - &ids_buffer, - &mut dst_buffer - ) - ); + let kernels = Kernels::new(); + call_index_select( + &device, + &command_buffer, + &kernels, + "is_u32_f32", + shape, + ids.len(), + dim, + &embeddings_buffer, + &ids_buffer, + &mut dst_buffer, + ) + .unwrap(); - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64); - let grid_size = MTLSize { - width: (dst_el as u64 + width - 1) / width, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - - println!("{width:?} - {:?}", grid_size); - - encoder.dispatch_thread_groups(grid_size, thread_group_size); - encoder.end_encoding(); command_buffer.commit(); command_buffer.wait_until_completed(); diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index 4dfc46c2..c6984474 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -18,45 +18,55 @@ METAL_FUNC uint get_strided_index( constant int THREADGROUP_SIZE = 256; -kernel void fast_sum_float( - constant size_t &src_numel, - constant size_t &el_to_sum_per_block, - device const float *src, - device float *dst, - uint id [[ thread_position_in_grid ]], - uint tid [[ thread_index_in_threadgroup ]], - uint dst_id [[ threadgroup_position_in_grid ]], - uint blockDim [[ threads_per_threadgroup ]] -) { - - threadgroup float shared_memory[THREADGROUP_SIZE]; - - shared_memory[tid] = 0; - // Elements summed in this block range from dst_id * el_to_sum_per_block - // to (dst_id + 1) * el_to_sum_per_block. - size_t start_idx = dst_id * el_to_sum_per_block; - size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); - size_t idx = start_idx + tid; - - while (idx < stop_idx) { - // TODO: Fast version for the contiguous case. 
- // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); - shared_memory[tid] += src[idx]; - idx += blockDim; - } - - threadgroup_barrier(mem_flags::mem_none); - - // reduction in shared memory - for (uint s = blockDim / 2; s > 0; s >>= 1) { - if (tid < s) { - shared_memory[tid] += shared_memory[tid + s]; - } - threadgroup_barrier(mem_flags::mem_none); - } - - dst[dst_id] = shared_memory[0]; -} +# define REDUCE(FN, NAME, TYPENAME) \ +kernel void NAME( \ + constant size_t &src_numel, \ + constant size_t &el_to_sum_per_block, \ + device const TYPENAME *src, \ + device TYPENAME *dst, \ + uint id [[ thread_position_in_grid ]], \ + uint tid [[ thread_index_in_threadgroup ]], \ + uint dst_id [[ threadgroup_position_in_grid ]], \ + uint blockDim [[ threads_per_threadgroup ]] \ +) { \ + \ + threadgroup float shared_memory[THREADGROUP_SIZE]; \ + \ + shared_memory[tid] = 0; \ + /* \ + // Elements summed in this block range from dst_id * el_to_sum_per_block \ + // to (dst_id + 1) * el_to_sum_per_block. \ + */ \ + size_t start_idx = dst_id * el_to_sum_per_block; \ + size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \ + size_t idx = start_idx + tid; \ + while (idx < stop_idx) { \ + /* \ + // TODO: Fast version for the contiguous case. \ + // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \ + */ \ + TYPENAME x = shared_memory[tid]; \ + TYPENAME y = src[idx]; \ + shared_memory[tid] = FN; \ + idx += blockDim; \ + } \ + \ + threadgroup_barrier(mem_flags::mem_none); \ + \ + /* \ + // reduction in shared memory \ + */ \ + for (uint s = blockDim / 2; s > 0; s >>= 1) { \ + if (tid < s) { \ + TYPENAME x = shared_memory[tid]; \ + TYPENAME y = shared_memory[tid + s]; \ + shared_memory[tid] = FN; \ + } \ + threadgroup_barrier(mem_flags::mem_none); \ + } \ + \ + dst[dst_id] = shared_memory[0]; \ +} \ kernel void softmax_float( constant size_t &src_numel, @@ -122,3 +132,8 @@ kernel void softmax_float( idx += blockDim; } } + + +REDUCE(x + y, fast_sum_float, float) +REDUCE(x * y, fast_mul_float, float) +REDUCE(max(x, y), fast_max_float, float) From f710fab02e911b31ab15af80d42923f6a56317b2 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 11:14:51 +0100 Subject: [PATCH 06/15] Fixing the kernels + launches to make them faster. 
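The change, in one line: kernels used to be launched as a single threadgroup
whose threads each looped over a chunk of the buffer; they now launch one
thread per element, with an early return guarding the tail. The dispatch math
is factored into a `linear_split` helper in lib.rs below; a minimal Rust
sketch of it, with the two MTLSize results flattened to plain tuples for
illustration:

fn linear_split(max_threads_per_threadgroup: u64, length: usize) -> ((u64, u64, u64), (u64, u64, u64)) {
    let size = length as u64;
    // Cap the threadgroup width at the pipeline limit and at the work size.
    let width = max_threads_per_threadgroup.min(size);
    // Enough threadgroups to cover every element: ceil(size / width).
    let count = (size + width - 1) / width;
    ((count, 1, 1), (width, 1, 1)) // (threadgroup count, threads per group)
}

fn main() {
    // 10_000 elements under a 1024-thread limit -> 10 groups of 1024 threads.
    assert_eq!(linear_split(1024, 10_000), ((10, 1, 1), (1024, 1, 1)));
}
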
Cool work by @ivarflakstad Co-authored-by: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> --- candle-metal-kernels/src/affine.metal | 13 +-- candle-metal-kernels/src/binary.metal | 30 +++--- candle-metal-kernels/src/cast.metal | 25 ++--- candle-metal-kernels/src/indexing.metal | 14 +-- candle-metal-kernels/src/lib.rs | 125 ++++++------------------ candle-metal-kernels/src/unary.metal | 24 ++--- 6 files changed, 69 insertions(+), 162 deletions(-) diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index c388c04e..e5f0a841 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -24,17 +24,14 @@ kernel void FN_NAME( \ constant float &add, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint id [[ thread_position_in_grid ]] \ ) { \ + if (id >= dim) { \ + return; \ + } \ const TYPENAME m = TYPENAME(mul); \ const TYPENAME a = TYPENAME(add); \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = input[i] * m + a; \ - } \ + output[id] = input[id] * m + a; \ } \ AFFINE(affine_float, float) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cfd34416..37bc0bae 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -23,17 +23,14 @@ kernel void FN_NAME( \ device const TYPENAME *left, \ device const TYPENAME *right, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - TYPENAME x = left[i]; \ - TYPENAME y = right[i]; \ - output[i] = OUT_TYPENAME(FN); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + TYPENAME x = left[thread_position_in_grid]; \ + TYPENAME y = right[thread_position_in_grid]; \ + output[thread_position_in_grid] = OUT_TYPENAME(FN); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -44,17 +41,14 @@ kernel void FN_NAME_STRIDED( \ device const TYPENAME *left, \ device const TYPENAME *right, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \ - TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \ - output[i] = OUT_TYPENAME(FN); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ + output[thread_position_in_grid] = OUT_TYPENAME(FN); \ } #define BINARY_OP(FN, NAME) \ diff 
--git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 52e63662..d1788253 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,15 +23,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = RIGHT_TYPENAME(input[i]); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + output[thread_position_in_grid] = RIGHT_TYPENAME(input[thread_position_in_grid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -40,17 +37,13 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint i [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + if (i >= dim) { \ + return; \ } \ -} - + output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ +} \ CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index c077cc48..444fa322 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -2,7 +2,7 @@ using namespace metal; # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ -kernel void NAME( \ +kernel void NAME( \ constant size_t &dst_size, \ constant size_t &left_size, \ constant size_t &src_dim_size, \ @@ -42,12 +42,9 @@ void index_add( constant uint &dst_dim_size, constant uint &right_size, - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], - uint thread_index [[thread_index_in_threadgroup]] + uint gid [[ thread_position_in_grid ]] \ ) { - const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size); if (gid >= left_size * right_size) { return; } @@ -73,14 +70,13 @@ kernel void FN_NAME( \ constant uint &left_size, \ constant uint &dst_dim_size, \ constant uint &right_size, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ - uint thread_index [[thread_index_in_threadgroup]] \ -) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \ + uint gid [[ thread_position_in_grid ]] \ +) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \ INDEX_OP(is_u32_f32, uint, float) + #if __METAL_VERSION__ >= 310 IA_OP(bfloat, int64_t, ia_i64_bf16) IA_OP(bfloat, uint32_t, ia_u32_bf16) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6a01107c..83fbe833 100644 --- a/candle-metal-kernels/src/lib.rs +++ 
b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - Device, Function, Library, MTLSize, + ComputePipelineState, Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -15,6 +15,24 @@ const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { + let size = length as u64; + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size); + let count = (size + width - 1) / width; + let thread_group_count = MTLSize { + width: count, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + (thread_group_count, thread_group_size) +} + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>
::set_param(encoder, position, data) } @@ -257,19 +275,7 @@ pub fn call_unary_contiguous( set_params!(encoder, (length, input, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) @@ -314,17 +320,7 @@ pub fn call_unary_strided( ); let width: usize = shape.iter().product(); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -358,18 +354,7 @@ pub fn call_binary_contiguous( set_params!(encoder, (length, left, right, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -421,17 +406,7 @@ pub fn call_binary_strided( ) ); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -464,18 +439,7 @@ pub fn call_cast_contiguous( set_params!(encoder, (length, input, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -608,19 +572,7 @@ pub fn call_affine( set_params!(encoder, (size, mul, add, input, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) @@ -672,18 +624,7 @@ pub fn call_where_cond_strided( ) ); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); 
encoder.end_encoding(); @@ -730,19 +671,9 @@ pub fn call_index_select( ) ); - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64); - let grid_size = MTLSize { - width: (dst_el as u64 + width - 1) / width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 77de214e..dd137599 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -27,15 +27,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = TYPENAME(FN(input[i])); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -44,15 +41,12 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ } #define UNARY_OP(NAME) \ @@ -79,4 +73,6 @@ BFLOAT_UNARY_OP(sqr) BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) + +UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif From d46670f7c0a64ae3824546364b05614e49ecb70a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 15:35:46 +0100 Subject: [PATCH 07/15] Tmp state. 
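Work in progress: every command buffer now waits until completed before its
output is read, `DeviceLocation::Metal` gains a `gpu_id`, and a strided cos
test lands in lib.rs. For reading that test, here is a plain-Rust
transcription of the `get_strided_index` helper the strided Metal kernels
rely on (assumed faithful; the .metal source only shows up as hunk context in
these diffs):

fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    // Peel the flat index into per-dimension coordinates, innermost first,
    // and dot them with the strides to get the storage offset.
    let mut strided = 0;
    for (dim, stride) in dims.iter().zip(strides.iter()).rev() {
        strided += (idx % dim) * stride;
        idx /= dim;
    }
    strided
}

fn main() {
    // Shape [5000, 2] with strides [1, 5000] is a transposed view of a
    // contiguous [2, 5000] buffer, so flat index 1 maps to element 5000 and
    // flat index 5000 maps to element 2500 -- exactly the pairs the
    // cos_strided_random test below asserts.
    assert_eq!(get_strided_index(1, &[5000, 2], &[1, 5000]), 5000);
    assert_eq!(get_strided_index(5000, &[5000, 2], &[1, 5000]), 2500);
}
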
--- candle-core/src/device.rs | 2 +- candle-core/src/display.rs | 8 +- candle-core/src/metal_backend.rs | 157 ++++++++++------ candle-core/src/tensor.rs | 8 +- candle-examples/examples/llama2-c/main.rs | 6 +- candle-metal-kernels/Cargo.toml | 1 + candle-metal-kernels/examples/affine.rs | 75 ++++++++ candle-metal-kernels/examples/binary.rs | 182 +++++++++++++++++++ candle-metal-kernels/examples/cast.rs | 84 +++++++++ candle-metal-kernels/examples/unary.rs | 197 +++++++++++++++++++++ candle-metal-kernels/src/binary.metal | 2 +- candle-metal-kernels/src/lib.rs | 36 ++++ candle-nn/src/embedding.rs | 1 + candle-transformers/src/models/llama2_c.rs | 3 + 14 files changed, 699 insertions(+), 63 deletions(-) create mode 100644 candle-metal-kernels/examples/affine.rs create mode 100644 candle-metal-kernels/examples/binary.rs create mode 100644 candle-metal-kernels/examples/cast.rs create mode 100644 candle-metal-kernels/examples/unary.rs diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 73eb9640..3eb7f8b7 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -8,7 +8,7 @@ use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType}; pub enum DeviceLocation { Cpu, Cuda { gpu_id: usize }, - Metal, + Metal { gpu_id: usize }, } #[derive(Debug, Clone)] diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 215c28f6..4f5a390e 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -14,7 +14,9 @@ impl Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } - _ => todo!(), + crate::DeviceLocation::Metal { gpu_id } => { + format!(", metal:{}", gpu_id) + } }; write!(f, "Tensor[")?; @@ -477,7 +479,9 @@ impl std::fmt::Display for Tensor { crate::DeviceLocation::Cuda { gpu_id } => { format!(", cuda:{}", gpu_id) } - crate::DeviceLocation::Metal => todo!(), + crate::DeviceLocation::Metal { gpu_id } => { + format!(", metal:{}", gpu_id) + } }; write!( diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index ed592240..6687534d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -100,11 +100,30 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + // TODO Is this necessary + // self.buffer.synchronize(); match self.dtype { + DType::U8 => Ok(CpuStorage::U8( + self.buffer.read_to_vec(self.buffer.length() as usize / 1), + )), + DType::U32 => Ok(CpuStorage::U32( + self.buffer.read_to_vec(self.buffer.length() as usize / 4), + )), + DType::I64 => Ok(CpuStorage::I64( + self.buffer.read_to_vec(self.buffer.length() as usize / 8), + )), + DType::F16 => Ok(CpuStorage::F16( + self.buffer.read_to_vec(self.buffer.length() as usize / 2), + )), + DType::BF16 => Ok(CpuStorage::BF16( + self.buffer.read_to_vec(self.buffer.length() as usize / 2), + )), DType::F32 => Ok(CpuStorage::F32( self.buffer.read_to_vec(self.buffer.length() as usize / 4), )), - dtype => todo!("Unsupported dtype {dtype:?}"), + DType::F64 => Ok(CpuStorage::F64( + self.buffer.read_to_vec(self.buffer.length() as usize / 8), + )), } } @@ -132,6 +151,7 @@ impl BackendStorage for MetalStorage { ) .unwrap(); command_buffer.commit(); + command_buffer.wait_until_completed(); return Ok(Self { buffer, device: device.clone(), @@ -200,6 +220,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; command_buffer.commit(); + command_buffer.wait_until_completed(); Ok(Self { buffer, @@ -242,6 +263,7 @@ impl BackendStorage for MetalStorage { } 
command_buffer.commit(); + command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled(); // debug!( // "cast {:?} - {:?} - {:?}", @@ -289,6 +311,7 @@ impl BackendStorage for MetalStorage { todo!("TODO Implement the kernel calling {}", B::KERNEL); } command_buffer.commit(); + command_buffer.wait_until_completed(); Ok(Self { buffer, @@ -361,6 +384,7 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } command_buffer.commit(); + command_buffer.wait_until_completed(); Ok(Self { buffer, @@ -400,6 +424,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; command_buffer.commit(); + command_buffer.wait_until_completed(); Ok(Self { buffer, device, @@ -489,6 +514,7 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let device = self.device(); let mut buffer = device.new_buffer(dst_el, dtype); + let out = self.to_cpu_storage().unwrap(); let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", (left, right) => todo!("index select metal {left:?} {right:?}"), @@ -508,6 +534,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; command_buffer.commit(); + command_buffer.wait_until_completed(); Ok(Self { buffer, device: device.clone(), @@ -556,39 +583,42 @@ impl BackendStorage for MetalStorage { if el_count == 0 { return Ok(()); } - if src_l.is_contiguous() { - let command_buffer = self.device.command_queue.new_command_buffer(); - let blip = command_buffer.new_blit_command_encoder(); - blip.copy_from_buffer( - &self.buffer, - src_l.start_offset() as u64, - &dst.buffer, - dst_offset as u64, - self.buffer.length(), - ); - } else { - let command_buffer = self.device.command_queue.new_command_buffer(); - let kernel_name = match self.dtype { - DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, - DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, - DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => todo!("copy_strided not implemented for {dtype:?}"), - }; - candle_metal_kernels::call_unary_strided( - &self.device.device, - &command_buffer, - &self.device.kernels, - kernel_name, - src_l.dims(), - &self.buffer, - &src_l.stride(), - src_l.start_offset(), - &mut dst.buffer, - dst_offset, - ) - .map_err(MetalError::from)?; - command_buffer.commit(); - } + // todo!("Copy strided {:?}", src_l.is_contiguous()); + // if src_l.is_contiguous() { + // let command_buffer = self.device.command_queue.new_command_buffer(); + // let blip = command_buffer.new_blit_command_encoder(); + // blip.copy_from_buffer( + // &self.buffer, + // src_l.start_offset() as u64, + // &dst.buffer, + // dst_offset as u64, + // self.buffer.length(), + // ); + // } else { + let command_buffer = self.device.command_queue.new_command_buffer(); + let kernel_name = match self.dtype { + DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, + DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, + DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + dtype => todo!("copy_strided not implemented for {dtype:?}"), + }; + candle_metal_kernels::call_unary_strided( + &self.device.device, + &command_buffer, + &self.device.kernels, + kernel_name, + src_l.dims(), + &self.buffer, + &src_l.stride(), + src_l.start_offset(), + &mut dst.buffer, + dst_offset, + ) + .map_err(MetalError::from)?; + command_buffer.commit(); + command_buffer.wait_until_completed(); + // todo!("Output {:?}", dst.buffer.read_to_vec::(10)); + // } Ok(()) } } @@ -616,28 +646,29 @@ impl 
MetalStorage { match (self.dtype, rhs.dtype) { (DType::F32, DType::F32) => { let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); - if b != 1 { - // debug!("TODO implement batched matmul for B={b}"); - // bail!("Didn't implemented strided matmul yet"); - return Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }); - } - if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { - // debug!( - // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", - // lhs_l.is_contiguous(), - // rhs_l.is_contiguous(), - // rhs_l - // ); - return Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }); - } + // if b != 1 { + // // debug!("TODO implement batched matmul for B={b}"); + // crate::bail!("Didn't implemented strided matmul yet"); + // return Ok(Self { + // buffer: out_buffer, + // device: self.device.clone(), + // dtype: self.dtype(), + // }); + //} + // if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { + // // debug!( + // // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", + // // lhs_l.is_contiguous(), + // // rhs_l.is_contiguous(), + // // rhs_l + // // ); + // crate::bail!("No not contiguous matmul"); + // return Ok(Self { + // buffer: out_buffer, + // device: self.device.clone(), + // dtype: self.dtype(), + // }); + // } // debug!("TODO GEMM"); let command_buffer = self.device.command_queue.new_command_buffer(); @@ -659,7 +690,15 @@ impl MetalStorage { .map_err(MetalError::from)?; command_buffer.commit(); + command_buffer.wait_until_completed(); // command_buffer.wait_until_scheduled(); + // + let left = self.buffer.read_to_vec::(10); + let right = rhs.buffer.read_to_vec::(10); + let out = out_buffer.read_to_vec::(10); + + println!("{b} {m} {n} {k} "); + println!("{left:?} {right:?} {out:?}"); Ok(Self { buffer: out_buffer, @@ -709,7 +748,9 @@ impl BackendDevice for MetalDevice { } fn location(&self) -> crate::DeviceLocation { - crate::DeviceLocation::Metal + crate::DeviceLocation::Metal { + gpu_id: self.registry_id() as usize, + } } fn same_device(&self, rhs: &Self) -> bool { @@ -767,6 +808,8 @@ impl BackendDevice for MetalDevice { option, ), }; + // TODO is that necessary ? + // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); // debug!("Allocate 2 - buffer size {}", buffer.length()); Ok(Self::Storage { buffer, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f7f66668..3965a2ed 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -157,6 +157,8 @@ pub(crate) fn from_storage>( ) -> Tensor { let dtype = storage.dtype(); let device = storage.device(); + let shape = shape.into(); + // println!("{:?} {storage:?}", shape); let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(RwLock::new(storage)), @@ -166,7 +168,11 @@ pub(crate) fn from_storage>( dtype, device, }; - Tensor(Arc::new(tensor_)) + let result = Tensor(Arc::new(tensor_)); + // todo!(" from_storage"); + // let result = result.to_device(&Device::Cpu).unwrap(); + // todo!(" {result}"); + result } impl Tensor { diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 0ceb27af..11381fbc 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -329,14 +329,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .get_ids() .to_vec(); + println!("{tokens:?}"); + let start_gen = std::time::Instant::now(); - for index in 0.. 
{ + for index in 0..1 { if tokens.len() >= config.seq_len { break; } let context_size = if index > 0 { 1 } else { tokens.len() }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; + // println!("Input {}", input); + // println!("Input {}", input.to_device(&candle::Device::Cpu)?); let logits = model.forward(&input, index_pos)?; let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index ff5ede1a..2585ca62 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -17,3 +17,4 @@ tracing = "0.1.37" [dev-dependencies] half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } +rand = "0.8.5" diff --git a/candle-metal-kernels/examples/affine.rs b/candle-metal-kernels/examples/affine.rs new file mode 100644 index 00000000..b8005dc0 --- /dev/null +++ b/candle-metal-kernels/examples/affine.rs @@ -0,0 +1,75 @@ +use candle_metal_kernels::{call_affine, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_affine_bench(&device, &kernels, &f32_1k); + run_affine_bench(&device, &kernels, &f32_10k); + run_affine_bench(&device, &kernels, &f32_100k); +} + +fn run_affine_bench(device: &Device, kernels: &Kernels, v: &[T]) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 10000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + let mul: f32 = 1.2345; + let add: f32 = 2.3456; + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_affine( + &device, + command_buffer, + &kernels, + v.len(), + &input, + &mut output, + mul, + add, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + "affine", + v.len(), + iterations, + total_time, + total_time / iterations + ); +} diff --git a/candle-metal-kernels/examples/binary.rs b/candle-metal-kernels/examples/binary.rs new file mode 100644 index 00000000..af5a8bdc --- /dev/null +++ b/candle-metal-kernels/examples/binary.rs @@ -0,0 +1,182 @@ +use candle_metal_kernels::{binary, call_binary_contiguous, call_binary_strided, Kernels}; +use half::{bf16, f16}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = 
(0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::>(); + let f16_1k = f16_map(&f32_1k); + let f16_10k = f16_map(&f32_10k); + let f16_100k = f16_map(&f32_100k); + + let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let bf16_1k = bf16_map(&f32_1k); + let bf16_10k = bf16_map(&f32_10k); + let bf16_100k = bf16_map(&f32_100k); + + let f32_ckernels = [ + binary::contiguous::add::FLOAT, + binary::contiguous::sub::FLOAT, + binary::contiguous::mul::FLOAT, + binary::contiguous::div::FLOAT, + ]; + let f32_skernels = [ + binary::strided::add::FLOAT, + binary::strided::sub::FLOAT, + binary::strided::mul::FLOAT, + binary::strided::div::FLOAT, + ]; + let f16_ckernels = [ + binary::contiguous::add::HALF, + binary::contiguous::sub::HALF, + binary::contiguous::mul::HALF, + binary::contiguous::div::HALF, + ]; + let f16_skernels = [ + binary::strided::add::HALF, + binary::strided::sub::HALF, + binary::strided::mul::HALF, + binary::strided::div::HALF, + ]; + let bf16_ckernels = [ + binary::contiguous::add::BFLOAT, + binary::contiguous::sub::BFLOAT, + binary::contiguous::mul::BFLOAT, + binary::contiguous::div::BFLOAT, + ]; + let bf16_skernels = [ + binary::strided::add::BFLOAT, + binary::strided::sub::BFLOAT, + binary::strided::mul::BFLOAT, + binary::strided::div::BFLOAT, + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_binary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels); + run_binary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels); + run_binary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels); + + // f16 + run_binary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels); + run_binary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels); + run_binary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels); + + // bf16 + run_binary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels); + run_binary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels); + run_binary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels); +} + +fn run_binary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: [binary::contiguous::Kernel; 4], + strided: [binary::strided::Kernel; 4], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_binary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + 
iterations, + total_time, + total_time / iterations + ); + } + + // Strided + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + for kernel_name in strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_binary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &input, + &strides, + offset, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/candle-metal-kernels/examples/cast.rs b/candle-metal-kernels/examples/cast.rs new file mode 100644 index 00000000..090f510d --- /dev/null +++ b/candle-metal-kernels/examples/cast.rs @@ -0,0 +1,84 @@ +use candle_metal_kernels::{call_cast_contiguous, Kernels}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let contiguous_kernels = ["cast_u32_f32"]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_cast_bench(&device, &kernels, &f32_1k, &contiguous_kernels); + run_cast_bench(&device, &kernels, &f32_10k, &contiguous_kernels); + run_cast_bench(&device, &kernels, &f32_100k, &contiguous_kernels); +} + +fn run_cast_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: &[&'static str], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 1000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_cast_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided? 
+} diff --git a/candle-metal-kernels/examples/unary.rs b/candle-metal-kernels/examples/unary.rs new file mode 100644 index 00000000..7039c098 --- /dev/null +++ b/candle-metal-kernels/examples/unary.rs @@ -0,0 +1,197 @@ +use candle_metal_kernels::{call_unary_contiguous, call_unary_strided, unary, Kernels}; +use half::{bf16, f16}; +use metal::objc::rc::autoreleasepool; +use metal::{Device, MTLResourceOptions}; +use rand; +use std::any::type_name; +use std::time::Instant; + +fn main() { + let device = Device::system_default().unwrap(); + let kernels = Kernels::new(); + + let f32_1k = (0..1000).map(|_| rand::random::()).collect::>(); + let f32_10k = (0..10000) + .map(|_| rand::random::()) + .collect::>(); + let f32_100k = (0..100000) + .map(|_| rand::random::()) + .collect::>(); + + let f16_map = |v: &[f32]| v.iter().map(|v| f16::from_f32(*v)).collect::>(); + let f16_1k = f16_map(&f32_1k); + let f16_10k = f16_map(&f32_10k); + let f16_100k = f16_map(&f32_100k); + + let bf16_map = |v: &[f32]| v.iter().map(|v| bf16::from_f32(*v)).collect::>(); + let bf16_1k = bf16_map(&f32_1k); + let bf16_10k = bf16_map(&f32_10k); + let bf16_100k = bf16_map(&f32_100k); + + let f32_ckernels = [ + unary::contiguous::sin::FLOAT, + unary::contiguous::cos::FLOAT, + unary::contiguous::exp::FLOAT, + unary::contiguous::sqr::FLOAT, + unary::contiguous::sqrt::FLOAT, + unary::contiguous::neg::FLOAT, + unary::contiguous::copy::FLOAT, + ]; + let f32_skernels = [ + unary::strided::sin::FLOAT, + unary::strided::cos::FLOAT, + unary::strided::exp::FLOAT, + unary::strided::sqr::FLOAT, + unary::strided::sqrt::FLOAT, + unary::strided::neg::FLOAT, + unary::strided::copy::FLOAT, + ]; + let f16_ckernels = [ + unary::contiguous::sin::HALF, + unary::contiguous::cos::HALF, + unary::contiguous::exp::HALF, + unary::contiguous::sqr::HALF, + unary::contiguous::sqrt::HALF, + unary::contiguous::neg::HALF, + unary::contiguous::copy::HALF, + ]; + let f16_skernels = [ + unary::strided::sin::HALF, + unary::strided::cos::HALF, + unary::strided::exp::HALF, + unary::strided::sqr::HALF, + unary::strided::sqrt::HALF, + unary::strided::neg::HALF, + unary::strided::copy::HALF, + ]; + let bf16_ckernels = [ + unary::contiguous::sin::BFLOAT, + unary::contiguous::cos::BFLOAT, + unary::contiguous::exp::BFLOAT, + unary::contiguous::sqr::BFLOAT, + unary::contiguous::sqrt::BFLOAT, + unary::contiguous::neg::BFLOAT, + unary::contiguous::copy::BFLOAT, + ]; + let bf16_skernels = [ + unary::strided::sin::BFLOAT, + unary::strided::cos::BFLOAT, + unary::strided::exp::BFLOAT, + unary::strided::sqr::BFLOAT, + unary::strided::sqrt::BFLOAT, + unary::strided::neg::BFLOAT, + unary::strided::copy::BFLOAT, + ]; + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11} | {5: <11}", + "dtype", "kernel", "size", "runs", "total time", "avg time" + ); + + // f32 + run_unary_bench(&device, &kernels, &f32_1k, f32_ckernels, f32_skernels); + run_unary_bench(&device, &kernels, &f32_10k, f32_ckernels, f32_skernels); + run_unary_bench(&device, &kernels, &f32_100k, f32_ckernels, f32_skernels); + + // f16 + run_unary_bench(&device, &kernels, &f16_1k, f16_ckernels, f16_skernels); + run_unary_bench(&device, &kernels, &f16_10k, f16_ckernels, f16_skernels); + run_unary_bench(&device, &kernels, &f16_100k, f16_ckernels, f16_skernels); + + // bf16 + run_unary_bench(&device, &kernels, &bf16_1k, bf16_ckernels, bf16_skernels); + run_unary_bench(&device, &kernels, &bf16_10k, bf16_ckernels, bf16_skernels); + run_unary_bench(&device, &kernels, &bf16_100k, bf16_ckernels, bf16_skernels); +} + +fn 
run_unary_bench( + device: &Device, + kernels: &Kernels, + v: &[T], + contiguous: [unary::contiguous::Kernel; 7], + strided: [unary::strided::Kernel; 7], +) { + let command_queue = device.new_command_queue(); + let options = MTLResourceOptions::StorageModeManaged; + + let iterations = 10000; + let input = device.new_buffer_with_data( + v.as_ptr() as *const core::ffi::c_void, + core::mem::size_of_val(v) as u64, + options, + ); + let mut output = device.new_buffer(core::mem::size_of_val(v) as u64, options); + + // Contiguous + for kernel_name in contiguous { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_unary_contiguous( + device, + &command_buffer, + kernels, + kernel_name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } + + // Strided + let shape = vec![2, 5_000]; + let strides = vec![2, 1]; + let offset = 0; + for kernel_name in strided { + let total_time = autoreleasepool(|| { + let command_buffer = command_queue.new_command_buffer(); + let start = Instant::now(); + for _ in 0..iterations { + call_unary_strided( + device, + command_buffer, + &kernels, + kernel_name, + &shape, + &input, + &strides, + offset, + &mut output, + 0, + ) + .unwrap(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + + start.elapsed() + }); + + println!( + "{0: <5} | {1: <19} | {2: <6} | {3: <5} | {4: <11?} | {5: <11?}", + type_name::().split("::").last().unwrap(), + kernel_name.to_string(), + v.len(), + iterations, + total_time, + total_time / iterations + ); + } +} diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index 37bc0bae..f18cdbb0 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -47,7 +47,7 @@ kernel void FN_NAME_STRIDED( \ return; \ } \ TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ - TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \ output[thread_position_in_grid] = OUT_TYPENAME(FN); \ } diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 83fbe833..e5c9fbae 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -112,7 +112,13 @@ macro_rules! ops{ ($($name:ident),+) => { pub mod contiguous { + #[derive(Clone, Copy)] pub struct Kernel(pub(crate) &'static str); + impl std::fmt::Display for Kernel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } $( pub mod $name { use super::Kernel; @@ -124,7 +130,13 @@ macro_rules! 
ops{ } pub mod strided { + #[derive(Clone, Copy)] pub struct Kernel(pub(crate) &'static str); + impl std::fmt::Display for Kernel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } $( pub mod $name { use super::Kernel; @@ -859,6 +871,30 @@ mod tests { assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); } + #[test] + fn cos_strided_random() { + let v: Vec<_> = (0..10_000).map(|i| rand::random::()).collect(); + let shape = vec![5_000, 2]; + let strides = vec![1, 5_000]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4)); + assert_eq!( + approx(vec![results[1]], 4), + approx(vec![expected[5_000]], 4) + ); + assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4)); + assert_eq!( + approx(vec![results[3]], 4), + approx(vec![expected[5_001]], 4) + ); + assert_eq!( + approx(vec![results[5_000]], 4), + approx(vec![expected[2_500]], 4) + ); + } + #[test] fn binary_add_f32() { let left = vec![1.0f32, 2.0, 3.0]; diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index 52968bc2..2daac224 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -9,6 +9,7 @@ pub struct Embedding { impl Embedding { pub fn new(embeddings: Tensor, hidden_size: usize) -> Self { + // todo!("Embedding {embeddings}"); Self { embeddings, hidden_size, diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 753770fb..24182b72 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -165,6 +165,7 @@ impl CausalSelfAttention { fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; let q = self.q_proj.forward(x)?; + todo!("X {q}"); let k = self.k_proj.forward(x)?; let v = self.v_proj.forward(x)?; @@ -295,6 +296,7 @@ impl Block { let residual = x; let x = self.rms_1.forward(x)?; let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; + todo!("---X {}", x); let residual = &x; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; Ok(x) @@ -327,6 +329,7 @@ impl Llama { pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { let (_b_sz, _seq_len) = x.dims2()?; let mut x = self.wte.forward(x)?; + //println!("Embeddings {}", self.wte.embeddings()); for (block_idx, block) in self.blocks.iter().enumerate() { x = block.forward(&x, index_pos, block_idx)?; } From 38de52bc4b9ad4c0bf59c2d9863409af1e25c541 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 20:09:25 +0100 Subject: [PATCH 08/15] Fixed matmul (display still broken without casting back to CPU first? 
) --- Cargo.toml | 3 +- candle-core/src/metal_backend.rs | 229 +++++++++++---------- candle-metal-kernels/Cargo.toml | 3 +- candle-transformers/src/models/llama2_c.rs | 3 +- 4 files changed, 127 insertions(+), 111 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c37bd75b..9c965f94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,8 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { path = "../metal-rs", features = ["mps"] } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6687534d..f267165d 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -19,6 +19,13 @@ pub enum MetalError { Message(String), #[error(transparent)] KernelError(#[from] candle_metal_kernels::MetalKernelError), + + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] + MatMulNonContiguous { + lhs_stride: Vec, + rhs_stride: Vec, + mnk: (usize, usize, usize), + }, } impl From for MetalError { @@ -53,7 +60,7 @@ impl MetalDevice { // self.device.as_ref() // } - pub fn id(&self) -> u64 { + pub fn id(&self) -> NSUInteger { self.registry_id() } @@ -70,7 +77,7 @@ impl MetalDevice { } pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { - let size = (element_count * dtype.size_in_bytes()) as u64; + let size = (element_count * dtype.size_in_bytes()) as NSUInteger; // debug!("Allocate 1 - buffer size {size}"); self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) @@ -561,20 +568,116 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - let transpose_left = false; - let transpose_right = !rhs_l.is_contiguous(); - let alpha = 1.0; - let beta = 0.0; - self.matmul_generic( - rhs, - (b, m, n, k), - lhs_l, - rhs_l, + // Create descriptors + use metal::mps::matrix::*; + let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32; + let size = core::mem::size_of::() as NSUInteger; + + let elem_count = b * m * n; + + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); + let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; + let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; + let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; + let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; + // The a tensor has dims batching, k, n (rhs) + let transpose_left = if lhs_m1 == 1 && lhs_m2 == k { + false + } else if lhs_m1 == m && lhs_m2 == 1 { + true + } else { + Err(MetalError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })? + }; + let transpose_right = if rhs_m1 == 1 && rhs_m2 == n { + false + } else if rhs_m1 == k && rhs_m2 == 1 { + true + } else { + Err(MetalError::MatMulNonContiguous { + lhs_stride: lhs_stride.to_vec(), + rhs_stride: rhs_stride.to_vec(), + mnk: (m, n, k), + })? 
+ }; + + let b = b as NSUInteger; + let m = m as NSUInteger; + let n = n as NSUInteger; + let k = k as NSUInteger; + + let left_descriptor = if transpose_left { + MatrixDescriptor::init_single(k, m, m * size, type_id) + } else { + MatrixDescriptor::init_single(m, k, k * size, type_id) + }; + let right_descriptor = if transpose_right { + MatrixDescriptor::init_single(n, k, k * size, type_id) + } else { + MatrixDescriptor::init_single(k, n, n * size, type_id) + }; + let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id); + + // Create matrix objects + let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + + let out_buffer = self.device.new_buffer(elem_count, self.dtype); + let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + + let alpha = 1.0f64; + let beta = 0.0f64; + // Create kernel + let matrix_multiplication = MatrixMultiplication::init( + &self.device, transpose_left, transpose_right, + m, + n, + k, alpha, beta, ) + .ok_or_else(|| { + MetalError::from("Failed to create matrix multiplication kernel".to_string()) + })?; + + matrix_multiplication.set_batch_size(b); + + // Encode kernel to command buffer + let command_buffer = self.device.command_queue.new_command_buffer(); + matrix_multiplication.encode_to_command_buffer( + command_buffer, + &left_matrix, + &right_matrix, + &result_matrix, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + // let left = self.buffer.read_to_vec::(10); + // let right = rhs.buffer.read_to_vec::(10); + // let out = out_buffer.read_to_vec::(40); + // todo!("Out {left:?} {right:?} {out:?}"); + + Ok(Self { + buffer: out_buffer, + device: self.device.clone(), + dtype: self.dtype(), + }) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { @@ -583,18 +686,6 @@ impl BackendStorage for MetalStorage { if el_count == 0 { return Ok(()); } - // todo!("Copy strided {:?}", src_l.is_contiguous()); - // if src_l.is_contiguous() { - // let command_buffer = self.device.command_queue.new_command_buffer(); - // let blip = command_buffer.new_blit_command_encoder(); - // blip.copy_from_buffer( - // &self.buffer, - // src_l.start_offset() as u64, - // &dst.buffer, - // dst_offset as u64, - // self.buffer.length(), - // ); - // } else { let command_buffer = self.device.command_queue.new_command_buffer(); let kernel_name = match self.dtype { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, @@ -631,84 +722,6 @@ impl MetalStorage { dtype, } } - pub(crate) fn matmul_generic( - &self, - rhs: &Self, - (b, m, n, k): (usize, usize, usize, usize), - lhs_l: &Layout, - rhs_l: &Layout, - transpose_left: bool, - transpose_right: bool, - alpha: f64, - beta: f64, - ) -> Result { - let elem_count = b * m * n; - match (self.dtype, rhs.dtype) { - (DType::F32, DType::F32) => { - let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); - // if b != 1 { - // // debug!("TODO implement batched matmul for B={b}"); - // crate::bail!("Didn't implemented strided matmul yet"); - // return Ok(Self { - // buffer: 
out_buffer, - // device: self.device.clone(), - // dtype: self.dtype(), - // }); - //} - // if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { - // // debug!( - // // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", - // // lhs_l.is_contiguous(), - // // rhs_l.is_contiguous(), - // // rhs_l - // // ); - // crate::bail!("No not contiguous matmul"); - // return Ok(Self { - // buffer: out_buffer, - // device: self.device.clone(), - // dtype: self.dtype(), - // }); - // } - - // debug!("TODO GEMM"); - let command_buffer = self.device.command_queue.new_command_buffer(); - encode_gemm::( - &self.device, - &command_buffer, - transpose_left, - transpose_right, - &self.buffer, - &rhs.buffer, - &mut out_buffer, - m as NSUInteger, - n as NSUInteger, - k as NSUInteger, - alpha as f32, - beta as f32, - Some(b as NSUInteger), - ) - .map_err(MetalError::from)?; - - command_buffer.commit(); - command_buffer.wait_until_completed(); - // command_buffer.wait_until_scheduled(); - // - let left = self.buffer.read_to_vec::(10); - let right = rhs.buffer.read_to_vec::(10); - let out = out_buffer.read_to_vec::(10); - - println!("{b} {m} {n} {k} "); - println!("{left:?} {right:?} {out:?}"); - - Ok(Self { - buffer: out_buffer, - device: self.device.clone(), - dtype: self.dtype(), - }) - } - _ => todo!("Unimplemented matmul for this pair"), - } - } pub fn buffer(&self) -> &Buffer { &self.buffer @@ -774,37 +787,37 @@ impl BackendDevice for MetalDevice { let buffer = match storage { CpuStorage::U8(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), CpuStorage::U32(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), CpuStorage::I64(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), CpuStorage::BF16(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), CpuStorage::F16(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), CpuStorage::F32(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), CpuStorage::F64(storage) => self.device.new_buffer_with_data( storage.as_ptr() as *const core::ffi::c_void, - (storage.len() * mem::size_of::()) as u64, + (storage.len() * mem::size_of::()) as NSUInteger, option, ), }; diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 2585ca62..2d2742ab 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -10,7 +10,8 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { path = "../../metal-rs", features = ["mps"] } once_cell = "1.18.0" thiserror 
= "1" tracing = "0.1.37" diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 24182b72..aba9a547 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -156,6 +156,7 @@ impl CausalSelfAttention { let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?; let x0 = x.narrow(D::Minus1, 0, 1)?; let x1 = x.narrow(D::Minus1, 1, 1)?; + todo!("X {x1}"); let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?; let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?; @@ -165,7 +166,6 @@ impl CausalSelfAttention { fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result { let (b_sz, seq_len, n_embd) = x.dims3()?; let q = self.q_proj.forward(x)?; - todo!("X {q}"); let k = self.k_proj.forward(x)?; let v = self.v_proj.forward(x)?; @@ -174,6 +174,7 @@ impl CausalSelfAttention { let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let q = self.apply_rotary_emb(&q, index_pos)?; + todo!("X {q}"); let mut k = self.apply_rotary_emb(&k, index_pos)?; if self.cache.use_kv_cache { From 7cfffcac10ce076667c76693a6b19f7231652c41 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 20:25:24 +0100 Subject: [PATCH 09/15] Debugging rope. --- Cargo.toml | 3 +-- candle-core/src/metal_backend.rs | 2 -- candle-metal-kernels/Cargo.toml | 3 +-- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9c965f94..c37bd75b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,8 +61,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } -metal = { path = "../metal-rs", features = ["mps"] } +metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index f267165d..3f58bb9b 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -7,8 +7,6 @@ use candle_metal_kernels::Kernels; use core::mem; use half::{bf16, f16}; use metal; -use metal::mps::matrix::encode_gemm; -use metal::mps::Float32; use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::Arc; diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 2d2742ab..2585ca62 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -10,8 +10,7 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } -metal = { path = "../../metal-rs", features = ["mps"] } +metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" From 2813fb5dbc404db927dab20b59ef3f2b9dbfc389 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 10 Nov 2023 23:00:32 +0100 Subject: [PATCH 10/15] Cleanup fixed a few ops removed debugging scaffolding. 
--- candle-core/src/metal_backend.rs | 55 ++++++++-------------- candle-core/src/tensor.rs | 13 ++--- candle-examples/examples/llama2-c/main.rs | 6 +-- candle-metal-kernels/src/lib.rs | 2 +- candle-metal-kernels/src/unary.metal | 2 + candle-nn/src/embedding.rs | 1 - candle-transformers/src/models/llama2_c.rs | 4 -- 7 files changed, 28 insertions(+), 55 deletions(-) diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 3f58bb9b..597c2f01 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -105,8 +105,6 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { - // TODO Is this necessary - // self.buffer.synchronize(); match self.dtype { DType::U8 => Ok(CpuStorage::U8( self.buffer.read_to_vec(self.buffer.length() as usize / 1), @@ -140,6 +138,7 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; assert!(layout.is_contiguous()); + assert!(layout.start_offset() == 0); assert_eq!(dtype, DType::F32); let mut buffer = device.new_buffer(el, self.dtype); @@ -173,10 +172,10 @@ impl BackendStorage for MetalStorage { } fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result { - // debug!("TODO reduce_op {op:?} {sum_dims:?}"); assert!(sum_dims.len() == 1); assert!(sum_dims[0] == layout.shape().rank() - 1); assert!(layout.is_contiguous()); + assert!(layout.start_offset() == 0); let device = self.device.clone(); let src_stride = layout.stride(); let src_dims = layout.shape().dims(); @@ -269,13 +268,6 @@ impl BackendStorage for MetalStorage { command_buffer.commit(); command_buffer.wait_until_completed(); - // command_buffer.wait_until_scheduled(); - // debug!( - // "cast {:?} - {:?} - {:?}", - // dtype, - // self.buffer.length(), - // buffer.length() - // ); Ok(Self { buffer, device: device.clone(), @@ -290,7 +282,7 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); let command_buffer = device.command_queue.new_command_buffer(); - if layout.is_contiguous() { + if layout.is_contiguous() && layout.start_offset() == 0 { use candle_metal_kernels::unary::contiguous; let kernel_name = match (B::KERNEL, dtype) { @@ -300,6 +292,7 @@ impl BackendStorage for MetalStorage { ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, ("uneg", DType::F32) => contiguous::neg::FLOAT, ("uexp", DType::F32) => contiguous::exp::FLOAT, + ("ulog", DType::F32) => contiguous::log::FLOAT, (name, dtype) => todo!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_unary_contiguous( @@ -337,7 +330,9 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); let command_buffer = device.command_queue.new_command_buffer(); - if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0) + && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0) + { use candle_metal_kernels::binary::contiguous; let kernel_name = match (B::KERNEL, dtype) { @@ -380,10 +375,10 @@ impl BackendStorage for MetalStorage { lhs_l.dims(), &self.buffer, &lhs_l.stride(), - lhs_l.start_offset(), + lhs_l.start_offset() * self.dtype.size_in_bytes(), &rhs.buffer, &rhs_l.stride(), - rhs_l.start_offset(), + rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &mut buffer, ) .map_err(MetalError::from)?; @@ -420,11 +415,14 @@ impl BackendStorage for MetalStorage { "where_u8_f32", &dims, &self.buffer, - (layout.stride(), layout.start_offset()), + ( + 
layout.stride(), + layout.start_offset() * self.dtype.size_in_bytes(), + ), &t.buffer, - (&t_l.stride(), t_l.start_offset()), + (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), &f.buffer, - (&f_l.stride(), f_l.start_offset()), + (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), &mut buffer, ) .map_err(MetalError::from)?; @@ -511,7 +509,9 @@ impl BackendStorage for MetalStorage { fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { assert!(src_l.is_contiguous()); + assert!(src_l.start_offset() == 0); assert!(ids_l.is_contiguous()); + assert!(ids_l.start_offset() == 0); let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); let ids_el = ids_l.shape().elem_count(); @@ -681,6 +681,7 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let el_count = src_shape.elem_count(); + // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}"); if el_count == 0 { return Ok(()); } @@ -699,15 +700,13 @@ impl BackendStorage for MetalStorage { src_l.dims(), &self.buffer, &src_l.stride(), - src_l.start_offset(), + src_l.start_offset() * self.dtype.size_in_bytes(), &mut dst.buffer, dst_offset, ) .map_err(MetalError::from)?; command_buffer.commit(); command_buffer.wait_until_completed(); - // todo!("Output {:?}", dst.buffer.read_to_vec::(10)); - // } Ok(()) } } @@ -732,24 +731,11 @@ impl BackendDevice for MetalDevice { fn new(ordinal: usize) -> Result { let device = metal::Device::all().swap_remove(ordinal); - // let capture = metal::CaptureManager::shared(); - // let descriptor = metal::CaptureDescriptor::new(); - // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); - // descriptor.set_capture_device(&device); - // let mut dir = std::env::current_dir()?; - // dir.push("out.gputrace"); - // descriptor.set_output_url(dir); - - // capture - // .start_capture(&descriptor) - // .map_err(MetalError::from)?; let command_queue = device.new_command_queue(); - // let command_buffer = _command_queue.new_owned_command_buffer(); let kernels = Arc::new(Kernels::new()); Ok(Self { device, command_queue, - // command_buffer, kernels, }) } @@ -819,9 +805,6 @@ impl BackendDevice for MetalDevice { option, ), }; - // TODO is that necessary ? - // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); - // debug!("Allocate 2 - buffer size {}", buffer.length()); Ok(Self::Storage { buffer, device: self.clone(), diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 3965a2ed..ce5858fa 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -157,8 +157,6 @@ pub(crate) fn from_storage>( ) -> Tensor { let dtype = storage.dtype(); let device = storage.device(); - let shape = shape.into(); - // println!("{:?} {storage:?}", shape); let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(RwLock::new(storage)), @@ -168,11 +166,7 @@ pub(crate) fn from_storage>( dtype, device, }; - let result = Tensor(Arc::new(tensor_)); - // todo!(" from_storage"); - // let result = result.to_device(&Device::Cpu).unwrap(); - // todo!(" {result}"); - result + Tensor(Arc::new(tensor_)) } impl Tensor { @@ -1869,7 +1863,10 @@ impl Tensor { Storage::Metal(metal.storage_from_cpu_storage(storage)?) 
} (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), - (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), + (Storage::Metal(storage), Device::Cpu) => { + println!("{storage:?} - {:?}", storage.to_cpu_storage()?); + Storage::Cpu(storage.to_cpu_storage()?) + } (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 11381fbc..0ceb27af 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -329,18 +329,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .get_ids() .to_vec(); - println!("{tokens:?}"); - let start_gen = std::time::Instant::now(); - for index in 0..1 { + for index in 0.. { if tokens.len() >= config.seq_len { break; } let context_size = if index > 0 { 1 } else { tokens.len() }; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; - // println!("Input {}", input); - // println!("Input {}", input.to_device(&candle::Device::Cpu)?); let logits = model.forward(&input, index_pos)?; let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e5c9fbae..7288216a 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -150,7 +150,7 @@ macro_rules! ops{ } pub mod unary { - ops!(cos, sin, exp, sqr, sqrt, neg, copy); + ops!(cos, sin, exp, sqr, sqrt, neg, copy, log); } pub mod binary { ops!(add, sub, mul, div); diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index dd137599..eb6424e8 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -63,6 +63,7 @@ UNARY_OP(sqr) UNARY_OP(sqrt) UNARY_OP(neg) UNARY_OP(exp) +UNARY_OP(log) UNARY(id, float, copy_float, copy_float_strided) UNARY(id, half, copy_half, copy_half_strided) @@ -73,6 +74,7 @@ BFLOAT_UNARY_OP(sqr) BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) +BFLOAT_UNARY_OP(log) UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index 2daac224..52968bc2 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -9,7 +9,6 @@ pub struct Embedding { impl Embedding { pub fn new(embeddings: Tensor, hidden_size: usize) -> Self { - // todo!("Embedding {embeddings}"); Self { embeddings, hidden_size, diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index aba9a547..753770fb 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -156,7 +156,6 @@ impl CausalSelfAttention { let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?; let x0 = x.narrow(D::Minus1, 0, 1)?; let x1 = x.narrow(D::Minus1, 1, 1)?; - todo!("X {x1}"); let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?; let dst1 = (x0.broadcast_mul(&sin)? 
+ x1.broadcast_mul(&cos)?)?; let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?; @@ -174,7 +173,6 @@ impl CausalSelfAttention { let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let q = self.apply_rotary_emb(&q, index_pos)?; - todo!("X {q}"); let mut k = self.apply_rotary_emb(&k, index_pos)?; if self.cache.use_kv_cache { @@ -297,7 +295,6 @@ impl Block { let residual = x; let x = self.rms_1.forward(x)?; let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?; - todo!("---X {}", x); let residual = &x; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; Ok(x) @@ -330,7 +327,6 @@ impl Llama { pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result { let (_b_sz, _seq_len) = x.dims2()?; let mut x = self.wte.forward(x)?; - //println!("Embeddings {}", self.wte.embeddings()); for (block_idx, block) in self.blocks.iter().enumerate() { x = block.forward(&x, index_pos, block_idx)?; } From bd3b24372593796fbe4039ec4cc1fca04394eb16 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 13 Nov 2023 18:53:16 +0100 Subject: [PATCH 11/15] Update candle-metal-kernels/Cargo.toml --- candle-metal-kernels/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index 2585ca62..f164dc2f 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -3,7 +3,7 @@ name = "candle-metal-kernels" version = "0.3.0" edition = "2021" -description = "CUDA kernels for Candle" +description = "Metal kernels for Candle" repository = "https://github.com/huggingface/candle" keywords = ["blas", "tensor", "machine-learning"] categories = ["science"] From c66e5d4716f957a5b0cc049c88b8eeae51a8fee0 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 20 Nov 2023 14:00:39 +0100 Subject: [PATCH 12/15] Fix comments. 
--- Cargo.toml | 2 +- candle-core/Cargo.toml | 2 +- candle-core/src/metal_backend.rs | 120 +++++++++++++------------------ candle-metal-kernels/Cargo.toml | 4 +- candle-metal-kernels/src/lib.rs | 42 +++-------- 5 files changed, 66 insertions(+), 104 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c37bd75b..ba09b1d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7" wav = "1.0.0" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "0.6.6", default-features = false } -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } [profile.release-with-debug] inherits = "release" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 592f5bdf..e7d3ab6a 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -13,7 +13,7 @@ readme = "README.md" accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true } -candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true } +candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true } metal = { workspace = true, optional = true} cudarc = { workspace = true, optional = true } gemm = { workspace = true } diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 597c2f01..27475efe 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -54,10 +54,6 @@ impl std::ops::Deref for MetalDevice { } impl MetalDevice { - // pub fn metal_device(&self) -> &metal::DeviceRef { - // self.device.as_ref() - // } - pub fn id(&self) -> NSUInteger { self.registry_id() } @@ -76,7 +72,6 @@ impl MetalDevice { pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer { let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - // debug!("Allocate 1 - buffer size {size}"); self.device .new_buffer(size, MTLResourceOptions::StorageModeManaged) } @@ -105,28 +100,22 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { + let length = self.buffer.length() as usize; + let size = self.dtype.size_in_bytes(); + if length % size != 0 { + crate::bail!( + "The Metal buffer length is not aligned with dtype {:?}", + self.dtype + ); + } match self.dtype { - DType::U8 => Ok(CpuStorage::U8( - self.buffer.read_to_vec(self.buffer.length() as usize / 1), - )), - DType::U32 => Ok(CpuStorage::U32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), - )), - DType::I64 => Ok(CpuStorage::I64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), - )), - DType::F16 => Ok(CpuStorage::F16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), - )), - DType::BF16 => Ok(CpuStorage::BF16( - self.buffer.read_to_vec(self.buffer.length() as usize / 2), - )), - DType::F32 => Ok(CpuStorage::F32( - self.buffer.read_to_vec(self.buffer.length() as usize / 4), - )), - DType::F64 => Ok(CpuStorage::F64( - self.buffer.read_to_vec(self.buffer.length() as usize / 8), - )), + DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))), + DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))), + DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))), + DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))), + DType::BF16 => 
Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
+            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
+            DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
         }
     }

@@ -137,9 +126,9 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;

-        assert!(layout.is_contiguous());
-        assert!(layout.start_offset() == 0);
-        assert_eq!(dtype, DType::F32);
+        if !layout.is_contiguous() || layout.start_offset() != 0|| dtype != DType::F32{
+            crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
+        }

         let mut buffer = device.new_buffer(el, self.dtype);
         let command_buffer = self.device.command_queue.new_command_buffer();
@@ -153,7 +142,7 @@ impl BackendStorage for MetalStorage {
             mul as f32,
             add as f32,
         )
-        .unwrap();
+        .map_err(MetalError::from)?;
         command_buffer.commit();
         command_buffer.wait_until_completed();
         return Ok(Self {
@@ -164,18 +153,18 @@ impl BackendStorage for MetalStorage {
     }

     fn powf(&self, _: &Layout, _: f64) -> Result {
-        todo!()
+        crate::bail!("powf metal")
     }

     fn elu(&self, _: &Layout, _: f64) -> Result {
-        todo!()
+        crate::bail!("elu metal")
     }

     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result {
-        assert!(sum_dims.len() == 1);
-        assert!(sum_dims[0] == layout.shape().rank() - 1);
-        assert!(layout.is_contiguous());
-        assert!(layout.start_offset() == 0);
+
+        if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 && layout.is_contiguous() && layout.start_offset() == 0){
+            crate::bail!("Non contiguous reduce op not supported yet");
+        }
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
@@ -204,7 +193,7 @@ impl BackendStorage for MetalStorage {
             (ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
             (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
             (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
-            _ => todo!("Reduce op for non float"),
+            _ => crate::bail!("Reduce op for non float"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
@@ -234,7 +223,7 @@ impl BackendStorage for MetalStorage {
     }

     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result {
-        todo!()
+        crate::bail!("cmp metal")
     }

     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result {
@@ -246,7 +235,7 @@ impl BackendStorage for MetalStorage {
         if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
-                (left, right) => todo!("to dtype {left:?} - {right:?}"),
+                (left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
             };
             candle_metal_kernels::call_cast_contiguous(
                 &device.device,
@@ -259,7 +248,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         } else {
-            todo!(
+            crate::bail!(
                 "TODO Implement the kernel calling cast {:?}-{:?}",
                 self.dtype,
                 dtype
@@ -293,7 +282,7 @@ impl BackendStorage for MetalStorage {
                 ("uneg", DType::F32) => contiguous::neg::FLOAT,
                 ("uexp", DType::F32) => contiguous::exp::FLOAT,
                 ("ulog", DType::F32) => contiguous::log::FLOAT,
-                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_unary_contiguous(
                 &device.device,
@@ -306,7 +295,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         } else {
-            todo!("TODO Implement the kernel calling {}", B::KERNEL);
+            crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
         }
         command_buffer.commit();
         command_buffer.wait_until_completed();
@@ -344,7 +333,7 @@ impl BackendStorage for MetalStorage {
                 ("bmul", DType::F32) => contiguous::mul::FLOAT,
                 ("div", DType::F32) => contiguous::div::FLOAT,
                 ("bdiv", DType::F32) => contiguous::div::FLOAT,
-                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_binary_contiguous(
                 &device.device,
@@ -365,7 +354,7 @@ impl BackendStorage for MetalStorage {
                 ("bsub", DType::F32) => strided::sub::FLOAT,
                 ("bmul", DType::F32) => strided::mul::FLOAT,
                 ("bdiv", DType::F32) => strided::div::FLOAT,
-                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+                (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_binary_strided(
                 &device.device,
@@ -442,7 +431,7 @@ impl BackendStorage for MetalStorage {
         _kernel_l: &Layout,
         _params: &ParamsConv1D,
     ) -> Result {
-        todo!()
+        crate::bail!("conv1d metal")
     }

     fn conv_transpose1d(
@@ -452,7 +441,7 @@ impl BackendStorage for MetalStorage {
         _kernel_l: &Layout,
         _params: &ParamsConvTranspose1D,
     ) -> Result {
-        todo!()
+        crate::bail!("conv_transpose1d metal")
     }

     fn conv2d(
@@ -462,7 +451,7 @@ impl BackendStorage for MetalStorage {
         _kernel_l: &Layout,
         _params: &ParamsConv2D,
     ) -> Result {
-        todo!()
+        crate::bail!("conv2d metal")
     }

     fn conv_transpose2d(
@@ -472,27 +461,27 @@ impl BackendStorage for MetalStorage {
         _kernel_l: &Layout,
         _params: &ParamsConvTranspose2D,
     ) -> Result {
-        todo!()
+        crate::bail!("conv_transpose2d metal")
     }

     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result {
-        todo!()
+        crate::bail!("avg_pool2d metal")
     }

     fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result {
-        todo!()
+        crate::bail!("max_pool2d metal")
     }

     fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result {
-        todo!()
+        crate::bail!("upsample_nearest1d metal")
     }

     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result {
-        todo!()
+        crate::bail!("upsample_nearest2d metal")
     }

     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result {
-        todo!()
+
crate::bail!("gather metal") } fn scatter_add( @@ -504,14 +493,13 @@ impl BackendStorage for MetalStorage { _: &Layout, _: usize, ) -> Result { - todo!() + crate::bail!("scatter_add metal") } fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - assert!(src_l.is_contiguous()); - assert!(src_l.start_offset() == 0); - assert!(ids_l.is_contiguous()); - assert!(ids_l.start_offset() == 0); + if !(src_l.is_contiguous() && src_l.start_offset() == 0 && ids_l.is_contiguous() && ids_l.start_offset() == 0){ + crate::bail!("Non contiguous index select not implemented"); + } let left_size: usize = src_l.dims()[..dim].iter().product(); let right_size: usize = src_l.dims()[dim + 1..].iter().product(); let ids_el = ids_l.shape().elem_count(); @@ -519,10 +507,10 @@ impl BackendStorage for MetalStorage { let dtype = self.dtype; let device = self.device(); let mut buffer = device.new_buffer(dst_el, dtype); - let out = self.to_cpu_storage().unwrap(); + let out = self.to_cpu_storage()?; let name = match (ids.dtype, self.dtype) { (DType::U32, DType::F32) => "is_u32_f32", - (left, right) => todo!("index select metal {left:?} {right:?}"), + (left, right) => crate::bail!("index select metal {left:?} {right:?}"), }; let command_buffer = self.device.command_queue.new_command_buffer(); candle_metal_kernels::call_index_select( @@ -556,7 +544,7 @@ impl BackendStorage for MetalStorage { _: &Layout, _: usize, ) -> Result { - todo!() + crate::bail!("index_add metal") } fn matmul( @@ -666,11 +654,6 @@ impl BackendStorage for MetalStorage { command_buffer.commit(); command_buffer.wait_until_completed(); - // let left = self.buffer.read_to_vec::(10); - // let right = rhs.buffer.read_to_vec::(10); - // let out = out_buffer.read_to_vec::(40); - // todo!("Out {left:?} {right:?} {out:?}"); - Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -681,7 +664,6 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); let el_count = src_shape.elem_count(); - // todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}"); if el_count == 0 { return Ok(()); } @@ -690,7 +672,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, - dtype => todo!("copy_strided not implemented for {dtype:?}"), + dtype => crate::bail!("copy_strided not implemented for {dtype:?}"), }; candle_metal_kernels::call_unary_strided( &self.device.device, @@ -741,7 +723,7 @@ impl BackendDevice for MetalDevice { } fn set_seed(&self, _seed: u64) -> Result<()> { - todo!("set_seed") + crate::bail!("set_seed") } fn location(&self) -> crate::DeviceLocation { diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml index f164dc2f..186f3209 100644 --- a/candle-metal-kernels/Cargo.toml +++ b/candle-metal-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-metal-kernels" -version = "0.3.0" +version = "0.3.1" edition = "2021" description = "Metal kernels for Candle" @@ -10,7 +10,7 @@ categories = ["science"] license = "MIT OR Apache-2.0" [dependencies] -metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] } +metal = { version = "0.27.1", features = ["mps"], package="candle-metal" } once_cell = "1.18.0" thiserror = "1" tracing = "0.1.37" diff --git 
a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 7288216a..6c2e5f2b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,4 +1,3 @@ -#![allow(clippy::too_many_arguments)] use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, ComputePipelineState, Device, Function, Library, MTLSize, @@ -156,14 +155,6 @@ pub mod binary { ops!(add, sub, mul, div); } -// static LIBRARY_SOURCES: Lazy> = Lazy::new(|| { -// let mut l = HashMap::new(); -// l.insert("affine", AFFINE); -// l.insert("indexing", INDEXING); -// l.insert("unary", UNARY); -// l -// }); -// #[derive(thiserror::Error, Debug)] pub enum MetalKernelError { #[error("Could not lock kernel map: {0}")] @@ -197,21 +188,7 @@ impl Kernels { Self { libraries, funcs } } - // pub fn init(device: &Device) -> Result { - // let kernels = Self::new(); - // kernels.load_libraries(device)?; - // Ok(kernels) - // } - - // fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> { - // for name in LIBRARY_SOURCES.keys() { - // self.load_library(device, name)?; - // } - // Ok(()) - // } - fn get_library_source(&self, source: Source) -> &'static str { - // LIBRARY_SOURCES.get(name).cloned() match source { Source::Affine => AFFINE, Source::Unary => UNARY, @@ -261,6 +238,7 @@ impl Kernels { } } +#[allow(clippy::too_many_arguments)] pub fn call_unary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -270,8 +248,6 @@ pub fn call_unary_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Unary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -292,6 +268,8 @@ pub fn call_unary_contiguous( encoder.end_encoding(); Ok(()) } + +#[allow(clippy::too_many_arguments)] pub fn call_unary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -339,6 +317,7 @@ pub fn call_unary_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_binary_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -349,8 +328,6 @@ pub fn call_binary_contiguous( right: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -373,6 +350,7 @@ pub fn call_binary_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_binary_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -425,6 +403,7 @@ pub fn call_binary_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_cast_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -434,8 +413,6 @@ pub fn call_cast_contiguous( input: &Buffer, output: &mut Buffer, ) -> Result<(), MetalKernelError> { - // println!("Kernel {:?}", kernel_name.0); - // assert_eq!(input.length(), output.length()); let func = kernels.load_function(device, Source::Cast, kernel_name)?; let pipeline_state_descriptor = ComputePipelineDescriptor::new(); pipeline_state_descriptor.set_compute_function(Some(&func)); @@ -458,6 +435,7 @@ pub fn call_cast_contiguous( Ok(()) } 
+#[allow(clippy::too_many_arguments)] pub fn call_reduce_contiguous( device: &Device, command_buffer: &CommandBufferRef, @@ -508,6 +486,7 @@ pub fn call_reduce_contiguous( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_last_softmax( device: &Device, command_buffer: &CommandBufferRef, @@ -543,7 +522,6 @@ pub fn call_last_softmax( let width = std::cmp::min( pipeline.max_total_threads_per_threadgroup(), - // (elements_to_sum as u64 + 2 - 1) / 2, elements_to_sum as u64, ) .next_power_of_two(); @@ -559,6 +537,7 @@ pub fn call_last_softmax( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, @@ -590,6 +569,7 @@ pub fn call_affine( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_where_cond_strided( device: &Device, command_buffer: &CommandBufferRef, @@ -643,6 +623,7 @@ pub fn call_where_cond_strided( Ok(()) } +#[allow(clippy::too_many_arguments)] pub fn call_index_select( device: &Device, command_buffer: &CommandBufferRef, @@ -813,7 +794,6 @@ mod tests { #[test] fn cos_f32_strided() { let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - // Shape = [6], strides = [1]; let shape = vec![6]; let strides = vec![1]; let offset = 0; From dc64adb8e4e2ad64555fc7bbc82e867178ff6cff Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 20 Nov 2023 14:17:07 +0100 Subject: [PATCH 13/15] Fixing cos_f16 test. --- candle-metal-kernels/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6c2e5f2b..cff8e763 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -853,7 +853,7 @@ mod tests { #[test] fn cos_strided_random() { - let v: Vec<_> = (0..10_000).map(|i| rand::random::()).collect(); + let v: Vec<_> = (0..10_000).map(|_| rand::random::()).collect(); let shape = vec![5_000, 2]; let strides = vec![1, 5_000]; let offset = 0; @@ -1114,7 +1114,7 @@ mod tests { .collect(); let results = run(&v, unary::contiguous::cos::HALF); let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]); + assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); } From 671fc29b36f9f9de2679a58a5fe0756b000a5bf7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 20 Nov 2023 14:38:20 +0100 Subject: [PATCH 14/15] Fmt. 
--- candle-core/src/metal_backend.rs | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 27475efe..52cde1b7 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -126,7 +126,7 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;

-        if !layout.is_contiguous() || layout.start_offset() != 0|| dtype != DType::F32{
+        if !layout.is_contiguous() || layout.start_offset() != 0 || dtype != DType::F32 {
             crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
         }

@@ -161,8 +161,11 @@ impl BackendStorage for MetalStorage {
     }

     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result {
-
-        if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 && layout.is_contiguous() && layout.start_offset() == 0){
+        if !(sum_dims.len() == 1
+            && sum_dims[0] == layout.shape().rank() - 1
+            && layout.is_contiguous()
+            && layout.start_offset() == 0)
+        {
             crate::bail!("Non contiguous reduce op not supported yet");
         }
         let device = self.device.clone();
@@ -497,7 +500,11 @@ impl BackendStorage for MetalStorage {
     }

     fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result {
-        if !(src_l.is_contiguous() && src_l.start_offset() == 0 && ids_l.is_contiguous() && ids_l.start_offset() == 0){
+        if !(src_l.is_contiguous()
+            && src_l.start_offset() == 0
+            && ids_l.is_contiguous()
+            && ids_l.start_offset() == 0)
+        {
             crate::bail!("Non contiguous index select not implemented");
         }
         let left_size: usize = src_l.dims()[..dim].iter().product();

From 60f624a9029df61ef4af980b910e53b8d09618f0 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 20 Nov 2023 16:17:19 +0100
Subject: [PATCH 15/15] Moving tests around.
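
The unit tests move out of lib.rs into their own src/tests.rs; lib.rs now
ends with the usual test-gated module declaration that pulls the new file
in:

    #[cfg(test)]
    mod tests;

The ~600 lines of tests are carried over verbatim, so no test logic
changes.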
--- candle-metal-kernels/src/lib.rs | 624 +----------------------------- candle-metal-kernels/src/tests.rs | 616 +++++++++++++++++++++++++++++ 2 files changed, 617 insertions(+), 623 deletions(-) create mode 100644 candle-metal-kernels/src/tests.rs diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index cff8e763..5a6bd41b 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -672,626 +672,4 @@ pub fn call_index_select( } #[cfg(test)] -mod tests { - use super::*; - use half::f16; - use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; - - fn new_buffer(device: &Device, data: &[T]) -> Buffer { - let options = MTLResourceOptions::StorageModeManaged; - let ptr = data.as_ptr() as *const core::ffi::c_void; - let size = (data.len() * std::mem::size_of::()) as u64; - device.new_buffer_with_data(ptr, size, options) - } - - fn device() -> Device { - Device::system_default().unwrap() - } - - fn approx(v: Vec, digits: i32) -> Vec { - let b = 10f32.powi(digits); - v.iter().map(|t| f32::round(t * b) / b).collect() - } - - fn approx_f16(v: Vec, digits: i32) -> Vec { - let b = 10f32.powi(digits); - v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() - } - - fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - call_unary_contiguous( - &device, - command_buffer, - &kernels, - name, - v.len(), - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) - } - - fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let left = new_buffer(&device, x); - let right = new_buffer(&device, y); - let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); - call_binary_contiguous( - &device, - command_buffer, - &kernels, - name, - x.len(), - &left, - &right, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::(x.len()) - } - - fn run_strided( - v: &[T], - kernel: unary::strided::Kernel, - shape: &[usize], - strides: &[usize], - offset: usize, - ) -> Vec { - let device = device(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - let kernels = Kernels::new(); - call_unary_strided( - &device, - command_buffer, - &kernels, - kernel, - shape, - &input, - strides, - offset, - &mut output, - 0, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) - } - - #[test] - fn cos_f32() { - let v = vec![1.0f32, 2.0, 3.0]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); - assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); - - let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = 
v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); - } - - #[test] - fn cos_f32_strided() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let shape = vec![6]; - let strides = vec![1]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!( - approx(results, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - assert_eq!( - approx(expected, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - - // Contiguous - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let shape = vec![3, 2]; - let strides = vec![2, 1]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!( - approx(results, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - assert_eq!( - approx(expected, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - - // Transposed - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let shape = vec![3, 2]; - let strides = vec![1, 3]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!( - approx(results, 4), - vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602] - ); - assert_eq!( - approx(expected, 4), - vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] - ); - - // Very large - let v = vec![1.0f32; 10_000]; - let shape = vec![2, 5_000]; - let strides = vec![2, 1]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); - } - - #[test] - fn cos_strided_random() { - let v: Vec<_> = (0..10_000).map(|_| rand::random::()).collect(); - let shape = vec![5_000, 2]; - let strides = vec![1, 5_000]; - let offset = 0; - let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4)); - assert_eq!( - approx(vec![results[1]], 4), - approx(vec![expected[5_000]], 4) - ); - assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4)); - assert_eq!( - approx(vec![results[3]], 4), - approx(vec![expected[5_001]], 4) - ); - assert_eq!( - approx(vec![results[5_000]], 4), - approx(vec![expected[2_500]], 4) - ); - } - - #[test] - fn binary_add_f32() { - let left = vec![1.0f32, 2.0, 3.0]; - let right = vec![2.0f32, 3.1, 4.2]; - let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); - let expected: Vec<_> = left - .iter() - .zip(right.iter()) - .map(|(&x, &y)| x + y) - .collect(); - assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); - assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); - } - - fn cast(v: &[T], name: &'static str) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - - call_cast_contiguous( - &device, - command_buffer, - &kernels, - name, - v.len(), - &input, - 
&mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - output.read_to_vec::(v.len()) - } - - #[test] - fn cast_u32_f32() { - let v = vec![1u32, 2, 3]; - let results = cast(&v, "cast_u32_f32"); - let expected: Vec<_> = v.iter().map(|&v| v as f32).collect(); - assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]); - assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]); - - let v = vec![1.0f32; 10_000]; - let results = run(&v, unary::contiguous::cos::FLOAT); - let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); - assert_eq!(approx(results, 4), vec![0.5403; 10_000]); - assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); - } - - fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - - let size = v.len(); - - call_affine( - &device, - command_buffer, - &kernels, - size, - &input, - &mut output, - mul as f32, - add as f32, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::(v.len()) - } - - #[test] - fn affine() { - let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; - let mul = 1.5; - let add = 1.1; - let result = run_affine(&input, mul, add); - assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]); - - let input = [1.0f32; 40_000]; - let mul = 1.5; - let add = 1.1; - let result = run_affine(&input, mul, add); - assert_eq!(result, vec![2.6; 40_000]); - } - - #[test] - fn index_select() { - let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - let shape = [5, 2]; - let ids = [0u32, 4, 2]; - let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); - assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); - - let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - let shape = [2, 5]; - let ids = [0u32, 1, 0]; - let dim = 0; - let result = run_index_select(&embedding, &shape, &ids, dim); - assert_eq!( - result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] - ); - - let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - let shape = [5, 2]; - let ids = [0u32, 1, 0]; - let dim = 1; - let result = run_index_select(&embedding, &shape, &ids, dim); - assert_eq!( - result, - vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] - ); - } - - fn run_index_select( - embeddings: &[T], - shape: &[usize], - ids: &[I], - dim: usize, - ) -> Vec { - let device = Device::system_default().expect("no device found"); - - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let embeddings_buffer = new_buffer(&device, &embeddings); - let ids_buffer = new_buffer(&device, &ids); - - let left_size: usize = shape[..dim].iter().product(); - let right_size: usize = shape[dim + 1..].iter().product(); - let dst_el = ids.len() * left_size * right_size; - let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); - - let kernels = Kernels::new(); - call_index_select( - &device, - &command_buffer, - &kernels, - "is_u32_f32", - shape, - ids.len(), - dim, - &embeddings_buffer, - &ids_buffer, - &mut dst_buffer, - ) - .unwrap(); - - command_buffer.commit(); - command_buffer.wait_until_completed(); - - dst_buffer.read_to_vec::(dst_el) - } - - 
#[test] - fn index_add() { - let device = Device::system_default().expect("no device found"); - - let options = CompileOptions::new(); - let library = device.new_library_with_source(INDEXING, &options).unwrap(); - - let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; - let right = [1.0f32; 15]; - let index = [0u32, 4, 2]; - let ids_dim_size = index.len() as u32; - let dst_dim_size: u32 = 15; - let left_size: u32 = 3; - let right_size: u32 = 3; - - let function = library.get_function("ia_u32_f32", None).unwrap(); - let pipeline = device - .new_compute_pipeline_state_with_function(&function) - .unwrap(); - - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let encoder = command_buffer.new_compute_command_encoder(); - - encoder.set_compute_pipeline_state(&pipeline); - - let index_buffer = new_buffer(&device, &index); - let inputs_buffer = new_buffer(&device, &left); - let outputs_buffer = new_buffer(&device, &right); - - set_params!( - encoder, - ( - &index_buffer, - &inputs_buffer, - &outputs_buffer, - ids_dim_size, - left_size, - dst_dim_size, - right_size - ) - ); - - let grid_size = MTLSize { - width: right.len() as NSUInteger, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: pipeline.max_total_threads_per_threadgroup(), - height: 1, - depth: 1, - }; - - encoder.dispatch_thread_groups(grid_size, thread_group_size); - encoder.end_encoding(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - let expected = vec![ - 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0, - ]; - let result = outputs_buffer.read_to_vec::(right.len()); - assert_eq!(result, expected); - } - - #[test] - fn cos_f16() { - let v: Vec = [1.0f32, 2.0, 3.0] - .iter() - .map(|v| f16::from_f32(*v)) - .collect(); - let results = run(&v, unary::contiguous::cos::HALF); - let expected: Vec = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect(); - assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]); - assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]); - } - - fn run_reduce(v: &[T], out_length: usize, name: &'static str) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - - let options = MTLResourceOptions::StorageModeManaged; - let mut output = - device.new_buffer((out_length * core::mem::size_of::()) as u64, options); - call_reduce_contiguous( - &device, - command_buffer, - &kernels, - name, - v.len(), - out_length, - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::(out_length) - } - - fn run_softmax( - v: &[T], - last_dim: usize, - name: &'static str, - ) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let input = new_buffer(&device, v); - let mut output = new_buffer(&device, v); - call_last_softmax( - &device, - command_buffer, - &kernels, - name, - v.len(), - last_dim, - &input, - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::(v.len()) - } - - #[test] - fn reduce_sum() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 1; - - let results = run_reduce(&v, out_length, "fast_sum_float"); 
- assert_eq!(approx(results, 4), vec![21.0]); - } - - #[test] - fn reduce_sum2() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let out_length = 2; - - let results = run_reduce(&v, out_length, "fast_sum_float"); - assert_eq!(approx(results, 4), vec![6.0, 15.0]); - } - - #[test] - fn softmax() { - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); - assert_eq!( - approx(results, 4), - vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] - ); - - let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]; - let last_dim = 6; - let results = run_softmax(&v, last_dim, "softmax_float"); - assert_eq!( - approx(results, 4), - vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337] - ); - - let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let last_dim = 3; - let results = run_softmax(&v, last_dim, "softmax_float"); - assert_eq!( - approx(results, 4), - vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652] - ); - } - - fn run_where_cond( - shape: &[usize], - cond: &[I], - (cond_stride, cond_offset): (Vec, usize), - left_true: &[T], - (left_stride, left_offset): (Vec, usize), - right_false: &[T], - (_right_stride, _right_offset): (Vec, usize), - name: &'static str, - ) -> Vec { - let device = device(); - let kernels = Kernels::new(); - let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - - let length = cond.len(); - let cond = device.new_buffer_with_data( - cond.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(cond) as u64, - options, - ); - let left = device.new_buffer_with_data( - left_true.as_ptr() as *const core::ffi::c_void, - (length * core::mem::size_of::()) as u64, - options, - ); - let right = device.new_buffer_with_data( - right_false.as_ptr() as *const core::ffi::c_void, - (length * core::mem::size_of::()) as u64, - options, - ); - - let mut output = device.new_buffer((length * core::mem::size_of::()) as u64, options); - call_where_cond_strided( - &device, - command_buffer, - &kernels, - name, - shape, - &cond, - (&cond_stride, cond_offset), - &left, - (&left_stride, left_offset), - &right, - (&cond_stride, cond_offset), - &mut output, - ) - .unwrap(); - command_buffer.commit(); - command_buffer.wait_until_completed(); - - output.read_to_vec::(length) - } - - #[test] - fn where_cond() { - let shape = vec![6]; - let cond = vec![0u8, 1, 0, 0, 1, 1]; - let cond_l = (vec![1], 0); - let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; - let left_l = (vec![1], 0); - let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; - let right_l = (vec![1], 0); - let results = run_where_cond( - &shape, - &cond, - cond_l, - &left_true, - left_l, - &right_false, - right_l, - "where_u8_f32", - ); - assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); - } -} +mod tests; diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs new file mode 100644 index 00000000..2330d48d --- /dev/null +++ b/candle-metal-kernels/src/tests.rs @@ -0,0 +1,616 @@ +use super::*; +use half::f16; +use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; + +fn new_buffer(device: &Device, data: &[T]) -> Buffer { + let options = MTLResourceOptions::StorageModeManaged; + let ptr = data.as_ptr() as *const core::ffi::c_void; + let size = (data.len() * std::mem::size_of::()) as u64; + device.new_buffer_with_data(ptr, size, options) +} + +fn device() -> Device { + 
Device::system_default().unwrap() +} + +fn approx(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t * b) / b).collect() +} + +fn approx_f16(v: Vec, digits: i32) -> Vec { + let b = 10f32.powi(digits); + v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect() +} + +fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + call_unary_contiguous( + &device, + command_buffer, + &kernels, + name, + v.len(), + &input, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) +} + +fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let left = new_buffer(&device, x); + let right = new_buffer(&device, y); + let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); + call_binary_contiguous( + &device, + command_buffer, + &kernels, + name, + x.len(), + &left, + &right, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(x.len()) +} + +fn run_strided( + v: &[T], + kernel: unary::strided::Kernel, + shape: &[usize], + strides: &[usize], + offset: usize, +) -> Vec { + let device = device(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + let kernels = Kernels::new(); + call_unary_strided( + &device, + command_buffer, + &kernels, + kernel, + shape, + &input, + strides, + offset, + &mut output, + 0, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(v.len()) +} + +#[test] +fn cos_f32() { + let v = vec![1.0f32, 2.0, 3.0]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); + assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); + + let v = vec![1.0f32; 10_000]; + let results = run(&v, unary::contiguous::cos::FLOAT); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!(approx(results, 4), vec![0.5403; 10_000]); + assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); +} + +#[test] +fn cos_f32_strided() { + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![6]; + let strides = vec![1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + approx(results, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + assert_eq!( + approx(expected, 4), + vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602] + ); + + // Contiguous + let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shape = vec![3, 2]; + let strides = vec![2, 1]; + let offset = 0; + let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset); + let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); + assert_eq!( + 
+#[test]
+fn cos_strided_random() {
+    let v: Vec<_> = (0..10_000).map(|_| rand::random::<f32>()).collect();
+    let shape = vec![5_000, 2];
+    let strides = vec![1, 5_000];
+    let offset = 0;
+    let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
+    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+    assert_eq!(approx(vec![results[0]], 4), approx(vec![expected[0]], 4));
+    assert_eq!(
+        approx(vec![results[1]], 4),
+        approx(vec![expected[5_000]], 4)
+    );
+    assert_eq!(approx(vec![results[2]], 4), approx(vec![expected[1]], 4));
+    assert_eq!(
+        approx(vec![results[3]], 4),
+        approx(vec![expected[5_001]], 4)
+    );
+    assert_eq!(
+        approx(vec![results[5_000]], 4),
+        approx(vec![expected[2_500]], 4)
+    );
+}
+
+#[test]
+fn binary_add_f32() {
+    let left = vec![1.0f32, 2.0, 3.0];
+    let right = vec![2.0f32, 3.1, 4.2];
+    let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
+    let expected: Vec<_> = left
+        .iter()
+        .zip(right.iter())
+        .map(|(&x, &y)| x + y)
+        .collect();
+    assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
+    assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
+}
+
+fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let input = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);
+
+    call_cast_contiguous(
+        &device,
+        command_buffer,
+        &kernels,
+        name,
+        v.len(),
+        &input,
+        &mut output,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+    output.read_to_vec::<U>(v.len())
+}
+
+#[test]
+fn cast_u32_f32() {
+    let v = vec![1u32, 2, 3];
+    let results = cast(&v, "cast_u32_f32");
+    let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
+    assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
+    assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
+
+    let v = vec![1.0f32; 10_000];
+    let results = run(&v, unary::contiguous::cos::FLOAT);
+    let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+    assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
+    assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+}
+
+fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+
+    let input = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);
+
+    let size = v.len();
+
+    call_affine(
+        &device,
+        command_buffer,
+        &kernels,
+        size,
+        &input,
+        &mut output,
+        mul as f32,
+        add as f32,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    output.read_to_vec::<T>(v.len())
+}
+
+#[test]
+fn affine() {
+    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+    let mul = 1.5;
+    let add = 1.1;
+    let result = run_affine(&input, mul, add);
+    assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
+
+    let input = [1.0f32; 40_000];
+    let mul = 1.5;
+    let add = 1.1;
+    let result = run_affine(&input, mul, add);
+    assert_eq!(result, vec![2.6; 40_000]);
+}
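+
+// index_select semantics checked below: along dim 0 it gathers whole rows,
+// along dim 1 it gathers elements within each row. E.g. for a [5, 2]
+// embedding with ids [0, 4, 2] on dim 0, rows 0, 4 and 2 are concatenated.
+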
+#[test]
+fn index_select() {
+    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+    let shape = [5, 2];
+    let ids = [0u32, 4, 2];
+    let dim = 0;
+    let result = run_index_select(&embedding, &shape, &ids, dim);
+    assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]);
+
+    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+    let shape = [2, 5];
+    let ids = [0u32, 1, 0];
+    let dim = 0;
+    let result = run_index_select(&embedding, &shape, &ids, dim);
+    assert_eq!(
+        result,
+        vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0]
+    );
+
+    let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
+    let shape = [5, 2];
+    let ids = [0u32, 1, 0];
+    let dim = 1;
+    let result = run_index_select(&embedding, &shape, &ids, dim);
+    assert_eq!(
+        result,
+        vec![1.0f32, 2.0, 1.0, 3.0, 4.0, 3.0, 5.0, 6.0, 5.0, 7.0, 8.0, 7.0, 9.0, 10.0, 9.0]
+    );
+}
+
+fn run_index_select<T: Clone, I: Clone>(
+    embeddings: &[T],
+    shape: &[usize],
+    ids: &[I],
+    dim: usize,
+) -> Vec<T> {
+    let device = Device::system_default().expect("no device found");
+
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let embeddings_buffer = new_buffer(&device, &embeddings);
+    let ids_buffer = new_buffer(&device, &ids);
+
+    let left_size: usize = shape[..dim].iter().product();
+    let right_size: usize = shape[dim + 1..].iter().product();
+    let dst_el = ids.len() * left_size * right_size;
+    let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]);
+
+    let kernels = Kernels::new();
+    call_index_select(
+        &device,
+        &command_buffer,
+        &kernels,
+        "is_u32_f32",
+        shape,
+        ids.len(),
+        dim,
+        &embeddings_buffer,
+        &ids_buffer,
+        &mut dst_buffer,
+    )
+    .unwrap();
+
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    dst_buffer.read_to_vec::<T>(dst_el)
+}
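+
+// Unlike the tests above, index_add drives the Metal pipeline by hand:
+// it compiles the INDEXING source, builds a compute pipeline state and an
+// encoder, and dispatches the threadgroups itself rather than going
+// through one of the call_* helpers.
+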
+#[test]
+fn index_add() {
+    let device = Device::system_default().expect("no device found");
+
+    let options = CompileOptions::new();
+    let library = device.new_library_with_source(INDEXING, &options).unwrap();
+
+    let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
+    let right = [1.0f32; 15];
+    let index = [0u32, 4, 2];
+    let ids_dim_size = index.len() as u32;
+    let dst_dim_size: u32 = 15;
+    let left_size: u32 = 3;
+    let right_size: u32 = 3;
+
+    let function = library.get_function("ia_u32_f32", None).unwrap();
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(&function)
+        .unwrap();
+
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let encoder = command_buffer.new_compute_command_encoder();
+
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let index_buffer = new_buffer(&device, &index);
+    let inputs_buffer = new_buffer(&device, &left);
+    let outputs_buffer = new_buffer(&device, &right);
+
+    set_params!(
+        encoder,
+        (
+            &index_buffer,
+            &inputs_buffer,
+            &outputs_buffer,
+            ids_dim_size,
+            left_size,
+            dst_dim_size,
+            right_size
+        )
+    );
+
+    let grid_size = MTLSize {
+        width: right.len() as NSUInteger,
+        height: 1,
+        depth: 1,
+    };
+
+    let thread_group_size = MTLSize {
+        width: pipeline.max_total_threads_per_threadgroup(),
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(grid_size, thread_group_size);
+    encoder.end_encoding();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    let expected = vec![
+        2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
+    ];
+    let result = outputs_buffer.read_to_vec::<f32>(right.len());
+    assert_eq!(result, expected);
+}
+
+#[test]
+fn cos_f16() {
+    let v: Vec<f16> = [1.0f32, 2.0, 3.0]
+        .iter()
+        .map(|v| f16::from_f32(*v))
+        .collect();
+    let results = run(&v, unary::contiguous::cos::HALF);
+    let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
+    assert_eq!(approx_f16(results, 4), vec![0.5405, -0.4163, -0.9902]);
+    assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
+}
+
+fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let input = new_buffer(&device, v);
+
+    let options = MTLResourceOptions::StorageModeManaged;
+    let mut output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+    call_reduce_contiguous(
+        &device,
+        command_buffer,
+        &kernels,
+        name,
+        v.len(),
+        out_length,
+        &input,
+        &mut output,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    output.read_to_vec::<T>(out_length)
+}
+
+fn run_softmax<T: Clone>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let input = new_buffer(&device, v);
+    let mut output = new_buffer(&device, v);
+    call_last_softmax(
+        &device,
+        command_buffer,
+        &kernels,
+        name,
+        v.len(),
+        last_dim,
+        &input,
+        &mut output,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    output.read_to_vec::<T>(v.len())
+}
+
+#[test]
+fn reduce_sum() {
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+    let out_length = 1;
+
+    let results = run_reduce(&v, out_length, "fast_sum_float");
+    assert_eq!(approx(results, 4), vec![21.0]);
+}
+
+#[test]
+fn reduce_sum2() {
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+    let out_length = 2;
+
+    let results = run_reduce(&v, out_length, "fast_sum_float");
+    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
+}
+
+#[test]
+fn softmax() {
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+    let last_dim = 6;
+    let results = run_softmax(&v, last_dim, "softmax_float");
+    assert_eq!(
+        approx(results, 4),
+        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
+    );
+
+    let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
+    let last_dim = 6;
+    let results = run_softmax(&v, last_dim, "softmax_float");
+    assert_eq!(
+        approx(results, 4),
+        vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
+    );
+
+    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+    let last_dim = 3;
+    let results = run_softmax(&v, last_dim, "softmax_float");
+    assert_eq!(
+        approx(results, 4),
+        vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
+    );
+}
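+
+// The first two softmax cases above intentionally produce identical
+// outputs: softmax(x)_i = exp(x_i) / sum_j exp(x_j) is invariant under
+// shifting every input by a constant, and [0..5] is just [1..6] minus 1.
+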
+fn run_where_cond<I: Clone, T: Clone>(
+    shape: &[usize],
+    cond: &[I],
+    (cond_stride, cond_offset): (Vec<usize>, usize),
+    left_true: &[T],
+    (left_stride, left_offset): (Vec<usize>, usize),
+    right_false: &[T],
+    (_right_stride, _right_offset): (Vec<usize>, usize),
+    name: &'static str,
+) -> Vec<T> {
+    let device = device();
+    let kernels = Kernels::new();
+    let command_queue = device.new_command_queue();
+    let command_buffer = command_queue.new_command_buffer();
+    let options = MTLResourceOptions::StorageModeManaged;
+
+    let length = cond.len();
+    let cond = device.new_buffer_with_data(
+        cond.as_ptr() as *const core::ffi::c_void,
+        std::mem::size_of_val(cond) as u64,
+        options,
+    );
+    let left = device.new_buffer_with_data(
+        left_true.as_ptr() as *const core::ffi::c_void,
+        (length * core::mem::size_of::<T>()) as u64,
+        options,
+    );
+    let right = device.new_buffer_with_data(
+        right_false.as_ptr() as *const core::ffi::c_void,
+        (length * core::mem::size_of::<T>()) as u64,
+        options,
+    );
+
+    let mut output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
+    call_where_cond_strided(
+        &device,
+        command_buffer,
+        &kernels,
+        name,
+        shape,
+        &cond,
+        (&cond_stride, cond_offset),
+        &left,
+        (&left_stride, left_offset),
+        &right,
+        // All layouts in these tests are contiguous, so the cond stride
+        // tuple is reused for the right-hand side.
+        (&cond_stride, cond_offset),
+        &mut output,
+    )
+    .unwrap();
+    command_buffer.commit();
+    command_buffer.wait_until_completed();
+
+    output.read_to_vec::<T>(length)
+}
+
+#[test]
+fn where_cond() {
+    let shape = vec![6];
+    let cond = vec![0u8, 1, 0, 0, 1, 1];
+    let cond_l = (vec![1], 0);
+    let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
+    let left_l = (vec![1], 0);
+    let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0];
+    let right_l = (vec![1], 0);
+    let results = run_where_cond(
+        &shape,
+        &cond,
+        cond_l,
+        &left_true,
+        left_l,
+        &right_false,
+        right_l,
+        "where_u8_f32",
+    );
+    assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
+}