diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index de57c03a..73eb9640 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -146,6 +146,7 @@ impl Device { match (self, rhs) { (Self::Cpu, Self::Cpu) => true, (Self::Cuda(lhs), Self::Cuda(rhs)) => lhs.same_device(rhs), + (Self::Metal(lhs), Self::Metal(rhs)) => lhs.same_device(rhs), _ => false, } } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index da61bdb5..36f5f6b1 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -53,6 +53,8 @@ mod dummy_metal_backend; pub mod error; mod indexer; pub mod layout; +#[cfg(feature = "metal")] +pub mod metal_backend; #[cfg(feature = "mkl")] mod mkl; pub mod npy; diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 04a2c3dd..68a96672 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1,17 +1,16 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D}; +use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D}; use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; -use candle_metal_kernels::{void_ptr, Kernels, Source}; +use candle_metal_kernels::Kernels; use core::mem; use half::{bf16, f16}; use metal; use metal::mps::matrix::encode_gemm; use metal::mps::Float32; -use metal::{Buffer, CommandQueue, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger}; +use metal::{Buffer, CommandQueue, MTLResourceOptions, NSUInteger}; use std::sync::Arc; -use tracing::debug; /// Metal related errors #[derive(thiserror::Error, Debug)] @@ -113,7 +112,6 @@ impl BackendStorage for MetalStorage { let device = self.device().clone(); let shape = layout.shape(); - let dims = shape.dims(); let el = shape.elem_count(); let dtype = self.dtype; @@ -174,10 +172,8 @@ impl BackendStorage for MetalStorage { stride.push(src_stride[dim_idx]); } - let el_to_sum_per_block = src_el / dst_el; // The reduction loop requires the shared array to be properly initialized and for // this we want the number of threads to be a power of two. - let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two(); let (name, check_empty, return_index) = match (op, self.dtype) { (ReduceOp::Sum, DType::F32) => ("fast_sum_float", false, false), (ReduceOp::Min, DType::F32) => ("fast_min_float", true, false), @@ -219,13 +215,10 @@ impl BackendStorage for MetalStorage { fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { let device = self.device(); let shape = layout.shape(); - let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); let command_buffer = device.command_queue.new_command_buffer(); if layout.is_contiguous() { - use candle_metal_kernels::unary::contiguous; - let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32", (left, right) => todo!("to dtype {left:?} - {right:?}"), @@ -250,12 +243,12 @@ impl BackendStorage for MetalStorage { command_buffer.commit(); // command_buffer.wait_until_scheduled(); - debug!( - "cast {:?} - {:?} - {:?}", - dtype, - self.buffer.length(), - buffer.length() - ); + // debug!( + // "cast {:?} - {:?} - {:?}", + // dtype, + // self.buffer.length(), + // buffer.length() + // ); Ok(Self { buffer, device: device.clone(), @@ -267,15 +260,8 @@ impl BackendStorage for MetalStorage { let device = self.device(); let dtype = self.dtype; let shape = layout.shape(); - let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); - // TODO remove - // return Ok(Self { - // buffer, - // device: device.clone(), - // dtype, - // }); let command_buffer = device.command_queue.new_command_buffer(); if layout.is_contiguous() { use candle_metal_kernels::unary::contiguous; @@ -302,17 +288,7 @@ impl BackendStorage for MetalStorage { } else { todo!("TODO Implement the kernel calling {}", B::KERNEL); } - - let start = std::time::Instant::now(); command_buffer.commit(); - // command_buffer.wait_until_scheduled(); - debug!( - "Unary {:?} - {:?} - {:?} - {:?}", - B::KERNEL, - start.elapsed(), - self.buffer.length(), - buffer.length() - ); Ok(Self { buffer, @@ -330,7 +306,6 @@ impl BackendStorage for MetalStorage { let device = self.device(); let dtype = self.dtype; let shape = lhs_l.shape(); - let dims = shape.dims(); let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); let command_buffer = device.command_queue.new_command_buffer(); @@ -385,17 +360,7 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - - let start = std::time::Instant::now(); command_buffer.commit(); - // command_buffer.wait_until_scheduled(); - debug!( - "Binary {:?} - {:?} - {:?} - {:?}", - B::KERNEL, - start.elapsed(), - self.buffer.length(), - buffer.length() - ); Ok(Self { buffer, @@ -452,6 +417,16 @@ impl BackendStorage for MetalStorage { todo!() } + fn conv_transpose1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &ParamsConvTranspose1D, + ) -> Result { + todo!() + } + fn conv2d( &self, _l: &Layout, @@ -504,34 +479,28 @@ impl BackendStorage for MetalStorage { todo!() } - fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result { - debug!( - "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}", - self.buffer.length(), - ids.buffer.length(), - ); - let src = self; - let ids_shape = ids_l.shape(); - let ids_dims = ids_shape.dims(); - // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?; - // let src = match src_l.contiguous_offsets() { - // Some((o1, o2)) => src.slice(o1..o2), - // None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?, - // }; - let left_size: usize = src_l.dims()[..dim].iter().product(); - let right_size: usize = src_l.dims()[dim + 1..].iter().product(); - let src_dim_size = src_l.dims()[dim]; - let ids_dim_size = ids_shape.elem_count(); - let dst_el = ids_shape.elem_count() * left_size * right_size; - let dtype = self.dtype; - let device = self.device(); - let buffer = device.new_buffer(dst_el, dtype); - Ok(Self { - buffer, - device: device.clone(), - dtype, - }) - // todo!() + fn index_select( + &self, + _ids: &Self, + _src_l: &Layout, + _ids_l: &Layout, + _dim: usize, + ) -> Result { + todo!("Index select"); + // let ids_shape = ids_l.shape(); + // let left_size: usize = src_l.dims()[..dim].iter().product(); + // let right_size: usize = src_l.dims()[dim + 1..].iter().product(); + // let src_dim_size = src_l.dims()[dim]; + // let ids_dim_size = ids_shape.elem_count(); + // let dst_el = ids_shape.elem_count() * left_size * right_size; + // let dtype = self.dtype; + // let device = self.device(); + // let buffer = device.new_buffer(dst_el, dtype); + // Ok(Self { + // buffer, + // device: device.clone(), + // dtype, + // }) } fn index_add( @@ -571,7 +540,6 @@ impl BackendStorage for MetalStorage { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let src_shape = src_l.shape(); - let dims = src_shape.dims(); let el_count = src_shape.elem_count(); if el_count == 0 { return Ok(()); @@ -637,7 +605,7 @@ impl MetalStorage { (DType::F32, DType::F32) => { let mut out_buffer = self.device.new_buffer(elem_count, self.dtype); if b != 1 { - debug!("TODO implement batched matmul for B={b}"); + // debug!("TODO implement batched matmul for B={b}"); // bail!("Didn't implemented strided matmul yet"); return Ok(Self { buffer: out_buffer, @@ -646,12 +614,12 @@ impl MetalStorage { }); } if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { - debug!( - "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", - lhs_l.is_contiguous(), - rhs_l.is_contiguous(), - rhs_l - ); + // debug!( + // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}", + // lhs_l.is_contiguous(), + // rhs_l.is_contiguous(), + // rhs_l + // ); return Ok(Self { buffer: out_buffer, device: self.device.clone(), @@ -659,7 +627,7 @@ impl MetalStorage { }); } - debug!("TODO GEMM"); + // debug!("TODO GEMM"); let command_buffer = self.device.command_queue.new_command_buffer(); encode_gemm::( &self.device, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 2a0924b6..f7f66668 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1859,7 +1859,11 @@ impl Tensor { (Storage::Cpu(storage), Device::Cuda(cuda)) => { Storage::Cuda(cuda.storage_from_cpu_storage(storage)?) } + (Storage::Cpu(storage), Device::Metal(metal)) => { + Storage::Metal(metal.storage_from_cpu_storage(storage)?) + } (Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), + (Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?), (Storage::Cuda(storage), Device::Cuda(cuda)) => { // TODO: Avoid passing through the cpu storage here, especially if the gpu ids // are the same. diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 528c109d..eefaef34 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -1,6 +1,39 @@ #include using namespace metal; +kernel void is_u32_f32( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + + const device float *input, + const device uint *input_ids, + device float *output, + + uint gid [[ thread_position_in_grid ]] +) { + + if (gid >= dst_size) { + return; + } + + const size_t id_i = gid / right_size / left_size; + const size_t right_rank_i = gid % right_size; + const size_t left_rank_i = gid % left_size; + + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + const uint input_i = min(input_ids[id_i], (uint)(src_dim_size - 1)); + const size_t src_i = ((input_i * right_size) + right_rank_i) * left_size + left_rank_i; + + output[gid] = input[src_i]; + +} + + template void index_add( device I *ids [[buffer(0)]], diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index d2c63115..1bcd56d1 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use metal::{ - Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library, - MTLSize, + Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, + Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -15,6 +15,70 @@ const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) { +

::set_param(encoder, position, data) +} +trait EncoderParam { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); +} +macro_rules! primitive { + ($type:ty) => { + impl EncoderParam for $type { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + core::mem::size_of::<$type>() as u64, + &data as *const $type as *const c_void, + ); + } + } + }; +} +primitive!(usize); +primitive!(u32); +primitive!(f32); + +impl EncoderParam for &[T] { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_bytes( + position, + (core::mem::size_of::() * data.len()) as u64, + data.as_ptr() as *const T as *const c_void, + ); + } +} + +impl EncoderParam for &Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} +impl EncoderParam for (&Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} +impl EncoderParam for &mut Buffer { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data), 0); + } +} +impl EncoderParam for (&mut Buffer, usize) { + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) { + encoder.set_buffer(position, Some(data.0), data.1 as u64); + } +} + +macro_rules! set_params { + ($encoder:ident, ($($param:expr),+)) => ( + let mut _index = 0; + $( + set_param($encoder, _index, $param); + _index += 1; + )* + ); +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum Source { Affine, @@ -191,9 +255,7 @@ pub fn call_unary_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, 4, void_ptr(&length)); - encoder.set_buffer(1, Some(input), 0); - encoder.set_buffer(2, Some(output), 0); + set_params!(encoder, (length, input, output)); let thread_group_count = MTLSize { width: 1, @@ -239,24 +301,19 @@ pub fn call_unary_strided( encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); - encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); - encoder.set_bytes( - 2, - std::mem::size_of_val(shape) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - std::mem::size_of_val(strides) as u64, - strides.as_ptr() as *const c_void, + set_params!( + encoder, + ( + length, + num_dims, + shape, + strides, + (input, offset), + (output, output_offset) + ) ); - encoder.set_buffer(4, Some(input), offset as u64); - encoder.set_buffer(5, Some(output), output_offset as u64); - - let width = output.length(); - + let width: usize = shape.iter().product(); let thread_group_count = MTLSize { width: 1, height: 1, @@ -264,7 +321,7 @@ pub fn call_unary_strided( }; let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), height: 1, depth: 1, }; @@ -299,10 +356,7 @@ pub fn call_binary_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, 4, void_ptr(&length)); - encoder.set_buffer(1, Some(left), 0); - encoder.set_buffer(2, Some(right), 0); - encoder.set_buffer(3, Some(output), 0); + set_params!(encoder, (length, left, right, output)); let thread_group_count = MTLSize { width: 1, @@ -348,32 +402,24 @@ pub fn call_binary_strided( let num_dims: usize = shape.len(); let encoder = command_buffer.new_compute_command_encoder(); + let width: usize = shape.iter().product(); encoder.set_compute_pipeline_state(&pipeline); let length: usize = shape.iter().product(); - encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); - encoder.set_bytes( - 2, - std::mem::size_of_val(shape) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - std::mem::size_of_val(left_strides) as u64, - left_strides.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 4, - std::mem::size_of_val(right_strides) as u64, - right_strides.as_ptr() as *const c_void, - ); - encoder.set_buffer(5, Some(left_input), left_offset as u64); - encoder.set_buffer(6, Some(right_input), right_offset as u64); - encoder.set_buffer(7, Some(output), 0); - - let width = output.length(); + set_params!( + encoder, + ( + length, + num_dims, + shape, + left_strides, + right_strides, + (left_input, left_offset), + (right_input, right_offset), + output + ) + ); let thread_group_count = MTLSize { width: 1, @@ -382,7 +428,7 @@ pub fn call_binary_strided( }; let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), height: 1, depth: 1, }; @@ -416,9 +462,7 @@ pub fn call_cast_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, 4, void_ptr(&length)); - encoder.set_buffer(1, Some(input), 0); - encoder.set_buffer(2, Some(output), 0); + set_params!(encoder, (length, input, output)); let thread_group_count = MTLSize { width: 1, @@ -463,14 +507,7 @@ pub fn call_reduce_contiguous( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes( - 1, - core::mem::size_of::() as u64, - void_ptr(&elements_to_sum), - ); - encoder.set_buffer(2, Some(input), 0); - encoder.set_buffer(3, Some(output), 0); + set_params!(encoder, (length, elements_to_sum, input, output)); let thread_group_count = MTLSize { width: out_length as u64, @@ -518,14 +555,7 @@ pub fn call_last_softmax( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&length)); - encoder.set_bytes( - 1, - core::mem::size_of::() as u64, - void_ptr(&elements_to_sum), - ); - encoder.set_buffer(2, Some(input), 0); - encoder.set_buffer(3, Some(output), 0); + set_params!(encoder, (length, elements_to_sum, input, output)); let out_length = length / elements_to_sum; @@ -553,10 +583,6 @@ pub fn call_last_softmax( Ok(()) } -pub fn void_ptr(v: &T) -> *const c_void { - (v as *const T).cast() -} - pub fn call_affine( device: &Device, command_buffer: &CommandBufferRef, @@ -580,11 +606,7 @@ pub fn call_affine( let encoder = command_buffer.new_compute_command_encoder(); encoder.set_compute_pipeline_state(&pipeline); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); - encoder.set_bytes(1, core::mem::size_of::() as u64, void_ptr(&mul)); - encoder.set_bytes(2, core::mem::size_of::() as u64, void_ptr(&add)); - encoder.set_buffer(3, Some(input), 0); - encoder.set_buffer(4, Some(output), 0); + set_params!(encoder, (size, mul, add, input, output)); let thread_group_count = MTLSize { width: 1, @@ -632,36 +654,23 @@ pub fn call_where_cond_strided( encoder.set_compute_pipeline_state(&pipeline); let size: usize = shape.iter().product(); - encoder.set_bytes(0, core::mem::size_of::() as u64, void_ptr(&size)); - encoder.set_bytes( - 1, - core::mem::size_of::() as u64, - void_ptr(&shape.len()), + let rank = shape.len(); + + set_params!( + encoder, + ( + size, + rank, + shape, + cond_stride, + left_stride, + right_stride, + (cond, cond_offset), + (left, left_offset), + (right, right_offset), + output + ) ); - encoder.set_bytes( - 2, - std::mem::size_of_val(shape) as u64, - shape.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 3, - std::mem::size_of_val(cond_stride) as u64, - cond_stride.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 4, - std::mem::size_of_val(left_stride) as u64, - left_stride.as_ptr() as *const c_void, - ); - encoder.set_bytes( - 5, - std::mem::size_of_val(right_stride) as u64, - right_stride.as_ptr() as *const c_void, - ); - encoder.set_buffer(6, Some(cond), cond_offset as u64); - encoder.set_buffer(7, Some(left), left_offset as u64); - encoder.set_buffer(8, Some(right), right_offset as u64); - encoder.set_buffer(9, Some(output), 0); let thread_group_count = MTLSize { width: 1, @@ -686,7 +695,13 @@ mod tests { use super::*; use half::f16; use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger}; - use std::mem; + + fn new_buffer(device: &Device, data: &[T]) -> Buffer { + let options = MTLResourceOptions::StorageModeManaged; + let ptr = data.as_ptr() as *const core::ffi::c_void; + let size = (data.len() * std::mem::size_of::()) as u64; + device.new_buffer_with_data(ptr, size, options) + } fn device() -> Device { Device::system_default().unwrap() @@ -707,13 +722,8 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); call_unary_contiguous( &device, command_buffer, @@ -735,16 +745,8 @@ mod tests { let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let options = MTLResourceOptions::StorageModeManaged; - let left = device.new_buffer_with_data( - x.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(x) as u64, - options, - ); - let right = device.new_buffer_with_data( - y.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(y) as u64, - options, - ); + let left = new_buffer(&device, x); + let right = new_buffer(&device, y); let mut output = device.new_buffer(std::mem::size_of_val(x) as u64, options); call_binary_contiguous( &device, @@ -770,15 +772,10 @@ mod tests { offset: usize, ) -> Vec { let device = device(); - let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); let kernels = Kernels::new(); call_unary_strided( &device, @@ -893,13 +890,9 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer((v.len() * core::mem::size_of::()) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); + call_cast_contiguous( &device, command_buffer, @@ -935,14 +928,9 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); let size = v.len(); @@ -978,6 +966,104 @@ mod tests { assert_eq!(result, vec![2.6; 40_000]); } + #[test] + fn index_select() { + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 4, 2]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!(result, vec![1.0f32, 2.0, 9.0, 10.0, 5.0, 6.0]); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [2, 5]; + let ids = [0u32, 1, 0]; + let dim = 0; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); + + let embedding = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let shape = [5, 2]; + let ids = [0u32, 1, 0]; + let dim = 1; + let result = run_index_select(&embedding, &shape, &ids, dim); + assert_eq!( + result, + vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 1.0f32, 2.0, 3.0, 4.0, 5.0] + ); + } + + fn run_index_select( + embeddings: &[T], + shape: &[usize], + ids: &[I], + dim: usize, + ) -> Vec { + let device = Device::system_default().expect("no device found"); + let options = CompileOptions::new(); + let library = device.new_library_with_source(INDEXING, &options).unwrap(); + + let left_size: usize = shape[..dim].iter().product(); + let right_size: usize = shape[dim + 1..].iter().product(); + let src_dim_size = shape[dim]; + let dst_el = ids.len() * left_size * right_size; + let ids_size = ids.len(); + + let function = library.get_function("is_u32_f32", None).unwrap(); + let pipeline = device + .new_compute_pipeline_state_with_function(&function) + .unwrap(); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let encoder = command_buffer.new_compute_command_encoder(); + + encoder.set_compute_pipeline_state(&pipeline); + + let embeddings_buffer = new_buffer(&device, &embeddings); + let ids_buffer = new_buffer(&device, &ids); + let mut dst_buffer = new_buffer(&device, &vec![0.0f32; dst_el]); + + set_params!( + encoder, + ( + dst_el, + left_size, + src_dim_size, + right_size, + ids_size, + &embeddings_buffer, + &ids_buffer, + &mut dst_buffer + ) + ); + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64); + let grid_size = MTLSize { + width: (dst_el as u64 + width - 1) / width, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + println!("{width:?} - {:?}", grid_size); + + encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.end_encoding(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + dst_buffer.read_to_vec::(dst_el) + } + #[test] fn index_add() { let device = Device::system_default().expect("no device found"); @@ -997,31 +1083,29 @@ mod tests { let pipeline = device .new_compute_pipeline_state_with_function(&function) .unwrap(); - let options = MTLResourceOptions::StorageModeManaged; let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); let encoder = command_buffer.new_compute_command_encoder(); - let ids_size = (index.len() * mem::size_of::()) as NSUInteger; - let input_size = (left.len() * mem::size_of::()) as NSUInteger; - let output_size = (right.len() * mem::size_of::()) as NSUInteger; - encoder.set_compute_pipeline_state(&pipeline); - encoder.set_threadgroup_memory_length(0, output_size as NSUInteger); - let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options); - let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options); - let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options); + let index_buffer = new_buffer(&device, &index); + let inputs_buffer = new_buffer(&device, &left); + let outputs_buffer = new_buffer(&device, &right); - encoder.set_buffer(0, Some(&index_buffer), 0); - encoder.set_buffer(1, Some(&inputs_buffer), 0); - encoder.set_buffer(2, Some(&outputs_buffer), 0); - - encoder.set_bytes(3, 4, void_ptr(&ids_dim_size)); - encoder.set_bytes(4, 4, void_ptr(&left_size)); - encoder.set_bytes(5, 4, void_ptr(&dst_dim_size)); - encoder.set_bytes(6, 4, void_ptr(&right_size)); + set_params!( + encoder, + ( + &index_buffer, + &inputs_buffer, + &outputs_buffer, + ids_dim_size, + left_size, + dst_dim_size, + right_size + ) + ); let grid_size = MTLSize { width: right.len() as NSUInteger, @@ -1064,12 +1148,9 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); + let input = new_buffer(&device, v); + let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); let mut output = device.new_buffer((out_length * core::mem::size_of::()) as u64, options); call_reduce_contiguous( @@ -1098,13 +1179,8 @@ mod tests { let kernels = Kernels::new(); let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let options = MTLResourceOptions::StorageModeManaged; - let input = device.new_buffer_with_data( - v.as_ptr() as *const core::ffi::c_void, - std::mem::size_of_val(v) as u64, - options, - ); - let mut output = device.new_buffer(std::mem::size_of_val(v) as u64, options); + let input = new_buffer(&device, v); + let mut output = new_buffer(&device, v); call_last_softmax( &device, command_buffer,