From 480a3e22e653be1c6cfd502dd607f920d516c404 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 7 Nov 2023 23:45:53 +0100 Subject: [PATCH] Adding cast + binary kernels. --- candle-core/src/metal_backend.rs | 224 +++++++++++---- candle-examples/examples/quantized/main.rs | 24 +- candle-metal-kernels/src/binary.metal | 78 +++++ candle-metal-kernels/src/cast.metal | 58 ++++ candle-metal-kernels/src/lib.rs | 271 +++++++++++++++++- candle-metal-kernels/src/unary.metal | 14 +- .../src/models/quantized_llama.rs | 16 +- 7 files changed, 601 insertions(+), 84 deletions(-) create mode 100644 candle-metal-kernels/src/binary.metal create mode 100644 candle-metal-kernels/src/cast.metal diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index c400b59c..7056e500 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -113,11 +113,11 @@ impl BackendStorage for MetalStorage { debug!("{shape:?} {el:?} {:?}", layout.stride()); let output_buffer = device.new_buffer(el, self.dtype); - // return Ok(Self { - // buffer: output_buffer, - // device: device.clone(), - // dtype, - // }); + return Ok(Self { + buffer: output_buffer, + device: device.clone(), + dtype, + }); let function = self .device .kernels @@ -185,9 +185,9 @@ impl BackendStorage for MetalStorage { start.elapsed() ); - let capture = metal::CaptureManager::shared(); - capture.stop_capture(); - panic!("Done"); + // let capture = metal::CaptureManager::shared(); + // capture.stop_capture(); + // panic!("Done"); Ok(Self { buffer: output_buffer, @@ -283,7 +283,58 @@ impl BackendStorage for MetalStorage { } fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { - todo!("Implement {:?} {layout:?} - {dtype:?}", self.dtype) + let device = self.device(); + let shape = layout.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + // TODO remove + // return Ok(Self { + // buffer, + // device: device.clone(), + // dtype, + // }); + let command_buffer = device.command_queue.new_command_buffer(); + if layout.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + + let kernel_name = match (self.dtype, dtype) { + (DType::U32, DType::F32) => "cast_u32_f32", + (left, right) => todo!("to dtype {left:?} - {right:?}"), + }; + candle_metal_kernels::call_cast_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + todo!( + "TODO Implement the kernel calling cast {:?}-{:?}", + self.dtype, + dtype + ); + } + + let start = std::time::Instant::now(); + command_buffer.commit(); + // command_buffer.wait_until_scheduled(); + debug!( + "cast {:?} - {:?} - {:?} - {:?}", + dtype, + start.elapsed(), + self.buffer.length(), + buffer.length() + ); + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) } fn unary_impl(&self, layout: &Layout) -> Result { @@ -294,11 +345,11 @@ impl BackendStorage for MetalStorage { let el_count = shape.elem_count(); let mut buffer = device.new_buffer(el_count, dtype); // TODO remove - return Ok(Self { - buffer, - device: device.clone(), - dtype, - }); + // return Ok(Self { + // buffer, + // device: device.clone(), + // dtype, + // }); let command_buffer = device.command_queue.new_command_buffer(); if layout.is_contiguous() { use candle_metal_kernels::unary::contiguous; @@ -328,7 +379,7 @@ impl BackendStorage for MetalStorage { let start = std::time::Instant::now(); 
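+        // NOTE: `commit()` only submits the command buffer; with the wait below
+        // disabled, `start.elapsed()` measures CPU-side encoding and submission
+        // time, not the kernel's GPU execution time.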
command_buffer.commit(); - command_buffer.wait_until_completed(); + // command_buffer.wait_until_scheduled(); debug!( "Unary {:?} - {:?} - {:?} - {:?}", B::KERNEL, @@ -344,10 +395,87 @@ impl BackendStorage for MetalStorage { }) } - fn binary_impl(&self, _: &Self, _: &Layout, _: &Layout) -> Result { - debug!("TODO Binary {:?}", B::NAME); - Ok(self.clone()) - // todo!() + fn binary_impl( + &self, + rhs: &Self, + lhs_l: &Layout, + rhs_l: &Layout, + ) -> Result { + let device = self.device(); + let dtype = self.dtype; + let shape = lhs_l.shape(); + let dims = shape.dims(); + let el_count = shape.elem_count(); + let mut buffer = device.new_buffer(el_count, dtype); + let command_buffer = device.command_queue.new_command_buffer(); + if lhs_l.is_contiguous() && rhs_l.is_contiguous() { + use candle_metal_kernels::binary::contiguous; + + let kernel_name = match (B::KERNEL, dtype) { + ("add", DType::F32) => contiguous::add::FLOAT, + ("badd", DType::F32) => contiguous::add::FLOAT, + ("sub", DType::F32) => contiguous::sub::FLOAT, + ("bsub", DType::F32) => contiguous::sub::FLOAT, + ("mul", DType::F32) => contiguous::mul::FLOAT, + ("bmul", DType::F32) => contiguous::mul::FLOAT, + ("div", DType::F32) => contiguous::div::FLOAT, + ("bdiv", DType::F32) => contiguous::div::FLOAT, + (name, dtype) => todo!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_binary_contiguous( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + el_count, + &self.buffer, + &rhs.buffer, + &mut buffer, + ) + .map_err(MetalError::from)?; + } else { + use candle_metal_kernels::binary::strided; + + let kernel_name = match (B::KERNEL, dtype) { + ("badd", DType::F32) => strided::add::FLOAT, + ("bsub", DType::F32) => strided::sub::FLOAT, + ("bmul", DType::F32) => strided::mul::FLOAT, + ("bdiv", DType::F32) => strided::div::FLOAT, + (name, dtype) => todo!("Match {name} - {dtype:?}"), + }; + candle_metal_kernels::call_binary_strided( + &device.device, + &command_buffer, + &device.kernels, + kernel_name, + lhs_l.dims(), + &self.buffer, + &lhs_l.stride(), + lhs_l.start_offset(), + &rhs.buffer, + &rhs_l.stride(), + rhs_l.start_offset(), + &mut buffer, + ) + .map_err(MetalError::from)?; + } + + let start = std::time::Instant::now(); + command_buffer.commit(); + // command_buffer.wait_until_scheduled(); + debug!( + "Binary {:?} - {:?} - {:?} - {:?}", + B::KERNEL, + start.elapsed(), + self.buffer.length(), + buffer.length() + ); + + Ok(Self { + buffer, + device: device.clone(), + dtype, + }) } fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result { @@ -546,25 +674,25 @@ impl MetalStorage { } debug!("GEMM"); - let command_buffer = self.device.command_queue.new_command_buffer(); - encode_gemm::( - &self.device, - &command_buffer, - transpose_left, - transpose_right, - &self.buffer, - &rhs.buffer, - &mut out_buffer, - m as NSUInteger, - n as NSUInteger, - k as NSUInteger, - alpha, - beta, - ) - .map_err(MetalError::from)?; + // let command_buffer = self.device.command_queue.new_command_buffer(); + // encode_gemm::( + // &self.device, + // &command_buffer, + // transpose_left, + // transpose_right, + // &self.buffer, + // &rhs.buffer, + // &mut out_buffer, + // m as NSUInteger, + // n as NSUInteger, + // k as NSUInteger, + // alpha, + // beta, + // ) + // .map_err(MetalError::from)?; - command_buffer.commit(); - command_buffer.wait_until_scheduled(); + // command_buffer.commit(); + // command_buffer.wait_until_scheduled(); // println!("lhs {:?} {m} {k}", self.buffer.length()); // 
println!("rhs {:?} {k} {n}", rhs.buffer.length()); @@ -588,18 +716,18 @@ impl BackendDevice for MetalDevice { fn new(ordinal: usize) -> Result { let device = metal::Device::all().swap_remove(ordinal); - let capture = metal::CaptureManager::shared(); - let descriptor = metal::CaptureDescriptor::new(); - descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); - println!("{:?}", std::env::current_dir()?); - descriptor.set_capture_device(&device); - let mut dir = std::env::current_dir()?; - dir.push("out.gputrace"); - descriptor.set_output_url(dir); + // let capture = metal::CaptureManager::shared(); + // let descriptor = metal::CaptureDescriptor::new(); + // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + // println!("{:?}", std::env::current_dir()?); + // descriptor.set_capture_device(&device); + // let mut dir = std::env::current_dir()?; + // dir.push("out.gputrace"); + // descriptor.set_output_url(dir); - capture - .start_capture(&descriptor) - .map_err(MetalError::from)?; + // capture + // .start_capture(&descriptor) + // .map_err(MetalError::from)?; let command_queue = device.new_command_queue(); // let command_buffer = _command_queue.new_owned_command_buffer(); let kernels = Arc::new(Kernels::new()); diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index f5812536..05ecf41c 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -239,15 +239,13 @@ fn main() -> anyhow::Result<()> { Some(args.temperature) }; tracing_subscriber::fmt::init(); - // let _guard = if args.tracing { - // // let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - // // tracing_subscriber::registry().with(chrome_layer).init(); - // tracing_subscriber::fmt::init(); - // None - // // Some(guard) - // } else { - // None - // }; + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; println!( "avx: {}, neon: {}, simd128: {}, f16c: {}", @@ -375,7 +373,8 @@ fn main() -> anyhow::Result<()> { let logits = logits.squeeze(0)?; // TODO Remove this once implementation is finished. let logits = logits.ones_like()?; - logits_processor.sample(&logits)? + // logits_processor.sample(&logits)? + 15043 }; let prompt_dt = start_prompt_processing.elapsed(); all_tokens.push(next_token); @@ -399,8 +398,9 @@ fn main() -> anyhow::Result<()> { )? }; // TODO Remove this once implementation is finished. 
-            let logits = logits.ones_like()?;
-            next_token = logits_processor.sample(&logits)?;
+            // let logits = logits.ones_like()?;
+            // next_token = logits_processor.sample(&logits)?;
+            let next_token = 15043;
             all_tokens.push(next_token);
             print_token(next_token, &tokenizer);
             if next_token == eos_token {
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
new file mode 100644
index 00000000..7e0aa5d6
--- /dev/null
+++ b/candle-metal-kernels/src/binary.metal
@@ -0,0 +1,78 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+using namespace metal;
+
+#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
+kernel void FN_NAME( \
+    constant size_t &dim, \
+    device const TYPENAME *left, \
+    device const TYPENAME *right, \
+    device TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        TYPENAME x = left[i]; \
+        TYPENAME y = right[i]; \
+        output[i] = OUT_TYPENAME(FN); \
+    } \
+}\
+kernel void FN_NAME_STRIDED( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *left_strides, \
+    constant size_t *right_strides, \
+    device const TYPENAME *left, \
+    device const TYPENAME *right, \
+    device TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
+        TYPENAME y = right[get_strided_index(i, num_dims, dims, right_strides)]; \
+        output[i] = OUT_TYPENAME(FN); \
+    } \
+}
+
+#define BINARY_OP(FN, NAME) \
+BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
+BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
+
+#define BFLOAT_BINARY_OP(FN, NAME) \
+BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+
+
+BINARY_OP(x + y, add)
+BINARY_OP(x - y, sub)
+BINARY_OP(x * y, mul)
+BINARY_OP(x / y, div)
+
+#if __METAL_VERSION__ >= 310
+BFLOAT_BINARY_OP(x + y, badd)
+BFLOAT_BINARY_OP(x - y, bsub)
+BFLOAT_BINARY_OP(x * y, bmul)
+BFLOAT_BINARY_OP(x / y, bdiv)
+#endif
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
new file mode 100644
index 00000000..52e63662
--- /dev/null
+++ b/candle-metal-kernels/src/cast.metal
@@ -0,0 +1,58 @@
+#include <metal_stdlib>
+
+METAL_FUNC uint get_strided_index(
+    uint idx,
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides
+) {
+    uint strided_i = 0;
+    for (uint d = 0; d < num_dims; d++) {
+        uint dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
+
+
+using namespace metal;
+
+#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &dim, \
+    device const LEFT_TYPENAME *input, \
+    device RIGHT_TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        output[i] = RIGHT_TYPENAME(input[i]); \
+    } \
+} \
+kernel void FN_NAME_STRIDED( \
+    constant size_t &dim, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    device const LEFT_TYPENAME *input, \
+    device RIGHT_TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
+    } \
+}
+
+
+CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
+
+#if __METAL_VERSION__ >= 310
+#endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index ce98334a..ba818819 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -9,15 +9,19 @@ use std::sync::RwLock;
 const AFFINE: &str = include_str!("affine.metal");
 const INDEXING: &str = include_str!("indexing.metal");
 const UNARY: &str = include_str!("unary.metal");
+const BINARY: &str = include_str!("binary.metal");
+const CAST: &str = include_str!("cast.metal");
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum Source {
     Affine,
     Indexing,
     Unary,
+    Binary,
+    Cast,
 }
 
-macro_rules! unary{
+macro_rules! ops{
 ($($name:ident),+) => {
 
         pub mod contiguous {
@@ -47,7 +51,10 @@ macro_rules! 
unary{ } pub mod unary { - unary!(cos, sin, exp, sqr, sqrt, neg); + ops!(cos, sin, exp, sqr, sqrt, neg); +} +pub mod binary { + ops!(add, sub, mul, div); } // static LIBRARY_SOURCES: Lazy> = Lazy::new(|| { @@ -109,7 +116,9 @@ impl Kernels { match source { Source::Affine => AFFINE, Source::Unary => UNARY, + Source::Binary => BINARY, Source::Indexing => INDEXING, + Source::Cast => CAST, } } @@ -234,10 +243,9 @@ pub fn call_unary_strided( (strides.len() * std::mem::size_of::()) as u64, strides.as_ptr() as *const c_void, ); - encoder.set_bytes(4, std::mem::size_of::() as u64, void_ptr(&offset)); - encoder.set_buffer(5, Some(&input), 0); - encoder.set_buffer(6, Some(&output), 0); + encoder.set_buffer(4, Some(&input), offset as u64); + encoder.set_buffer(5, Some(&output), 0); let width = output.length(); @@ -258,6 +266,170 @@ pub fn call_unary_strided( Ok(()) } +pub fn call_binary_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: binary::contiguous::Kernel, + length: usize, + left: &Buffer, + right: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + // println!("Kernel {:?}", kernel_name.0); + // assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, Source::Binary, kernel_name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&length)); + encoder.set_buffer(1, Some(&left), 0); + encoder.set_buffer(2, Some(&right), 0); + encoder.set_buffer(3, Some(&output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_binary_strided( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + name: binary::strided::Kernel, + shape: &[usize], + left_input: &Buffer, + left_strides: &[usize], + left_offset: usize, + right_input: &Buffer, + right_strides: &[usize], + right_offset: usize, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + let func = kernels.load_function(device, Source::Binary, name.0)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let num_dims: usize = shape.len() as usize; + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + let length: usize = shape.iter().product(); + encoder.set_bytes(0, std::mem::size_of::() as u64, void_ptr(&length)); + encoder.set_bytes(1, std::mem::size_of::() as u64, void_ptr(&num_dims)); + encoder.set_bytes( + 2, + (shape.len() * std::mem::size_of::()) as u64, + shape.as_ptr() as *const c_void, + ); + encoder.set_bytes( + 3, + (left_strides.len() * std::mem::size_of::()) as u64, + left_strides.as_ptr() as *const c_void, + ); + encoder.set_bytes( 
+ 4, + (right_strides.len() * std::mem::size_of::()) as u64, + right_strides.as_ptr() as *const c_void, + ); + + encoder.set_buffer(5, Some(&left_input), left_offset as u64); + encoder.set_buffer(6, Some(&right_input), right_offset as u64); + encoder.set_buffer(7, Some(&output), 0); + + let width = output.length(); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width), + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + +pub fn call_cast_contiguous( + device: &Device, + command_buffer: &CommandBufferRef, + kernels: &Kernels, + kernel_name: &'static str, + length: usize, + input: &Buffer, + output: &mut Buffer, +) -> Result<(), MetalKernelError> { + // println!("Kernel {:?}", kernel_name.0); + // assert_eq!(input.length(), output.length()); + let func = kernels.load_function(device, Source::Cast, kernel_name)?; + let pipeline_state_descriptor = ComputePipelineDescriptor::new(); + pipeline_state_descriptor.set_compute_function(Some(&func)); + + let pipeline = device + .new_compute_pipeline_state_with_function( + pipeline_state_descriptor.compute_function().unwrap(), + ) + .unwrap(); + + let encoder = command_buffer.new_compute_command_encoder(); + encoder.set_compute_pipeline_state(&pipeline); + + encoder.set_bytes(0, 4, void_ptr(&length)); + encoder.set_buffer(1, Some(&input), 0); + encoder.set_buffer(2, Some(&output), 0); + + let thread_group_count = MTLSize { + width: 1, + height: 1, + depth: 1, + }; + + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + encoder.end_encoding(); + Ok(()) +} + pub fn void_ptr(v: &T) -> *const c_void { (v as *const T).cast() } @@ -310,6 +482,39 @@ mod tests { output.read_to_vec::(v.len()) } + fn run_binary(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec { + let device = device(); + let kernels = Kernels::new(); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let options = MTLResourceOptions::StorageModeManaged; + let left = device.new_buffer_with_data( + x.as_ptr() as *const core::ffi::c_void, + (x.len() * core::mem::size_of::()) as u64, + options, + ); + let right = device.new_buffer_with_data( + y.as_ptr() as *const core::ffi::c_void, + (y.len() * core::mem::size_of::()) as u64, + options, + ); + let mut output = device.new_buffer((x.len() * core::mem::size_of::()) as u64, options); + call_binary_contiguous( + &device, + &command_buffer, + &kernels, + name, + x.len(), + &left, + &right, + &mut output, + ) + .unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + output.read_to_vec::(x.len()) + } + fn run_strided( v: &[T], kernel: unary::strided::Kernel, @@ -421,6 +626,62 @@ mod tests { assert_eq!(approx(expected, 4), vec![0.5403; 10_000]); } + #[test] + fn binary_add_f32() { + let left = vec![1.0f32, 2.0, 3.0]; + let right = vec![2.0f32, 3.1, 4.2]; + let results = run_binary(&left, &right, binary::contiguous::add::FLOAT); + let expected: Vec<_> = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| x + y) + .collect(); + assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]); + assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]); + } + + 
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
+        let device = device();
+        let kernels = Kernels::new();
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let options = MTLResourceOptions::StorageModeManaged;
+        let input = device.new_buffer_with_data(
+            v.as_ptr() as *const core::ffi::c_void,
+            (v.len() * core::mem::size_of::<T>()) as u64,
+            options,
+        );
+        let mut output = device.new_buffer((v.len() * core::mem::size_of::<U>()) as u64, options);
+        call_cast_contiguous(
+            &device,
+            &command_buffer,
+            &kernels,
+            name,
+            v.len(),
+            &input,
+            &mut output,
+        )
+        .unwrap();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+        output.read_to_vec::<U>(v.len())
+    }
+
+    #[test]
+    fn cast_u32_f32() {
+        let v = vec![1u32, 2, 3];
+        let results = cast(&v, "cast_u32_f32");
+        let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
+        assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
+        assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
+
+        let v = vec![1.0f32; 10_000];
+        let results = run(&v, unary::contiguous::cos::FLOAT);
+        let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
+        assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
+        assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
+    }
+
     #[test]
     fn affine() {
         let device = device();
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index f30fb929..03f88779 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -1,19 +1,12 @@
 #include <metal_stdlib>
 
-struct Info{
-    device size_t &num_dims;
-    device size_t *dims;
-    device size_t *strides;
-};
-
 METAL_FUNC uint get_strided_index(
     uint idx,
     constant size_t &num_dims,
     constant size_t *dims,
-    constant size_t *strides,
-    constant size_t &offset
+    constant size_t *strides
 ) {
-    uint strided_i = offset;
+    uint strided_i = 0;
     for (uint d = 0; d < num_dims; d++) {
         uint dim_idx = num_dims - 1 - d;
         strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@@ -48,7 +41,6 @@ kernel void FN_NAME_STRIDED( \
     constant size_t &num_dims, \
     constant size_t *dims, \
     constant size_t *strides, \
-    constant size_t &offset, \
     device const TYPENAME *input, \
     device TYPENAME *output, \
     uint threadgroup_size [[threads_per_threadgroup]], \
@@ -58,7 +50,7 @@ kernel void FN_NAME_STRIDED( \
     const size_t start = thread_index * length; \
     const size_t stop = min(start + length, dim); \
     for (size_t i = start; i < stop; i++){ \
-        output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides, offset)])); \
+        output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
     } \
 }
diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index 3685d3de..8c04093a 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, IndexOp, Result, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module};
 
 pub const MAX_SEQ_LEN: usize = 4096;
@@ -196,15 +196,15 @@ fn precomput_freqs_cis(
     .collect();
     let theta = Tensor::new(theta.as_slice(), device)?;
     let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
-    let idx_theta = Tensor::new(range.as_slice(), device)?
-        .reshape((MAX_SEQ_LEN, 1))?
-        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    // let idx_theta = Tensor::new(range.as_slice(), device)?
+    //     .reshape((MAX_SEQ_LEN, 1))?
+    //     .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    // TODO This change avoids allocating on Metal and then casting since allocating directly on
+    // CPU as f32 seems just as fast
+    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+        .to_dtype(DType::F32)?
+        .reshape((MAX_SEQ_LEN, 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
     let cos = idx_theta.cos()?;
     let sin = idx_theta.sin()?;
     Ok((cos, sin))
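
Editor's note (not part of the patch): both new `.metal` files duplicate the same
`get_strided_index` helper, which decodes a linear element index into a strided
offset by peeling dimensions from innermost to outermost. Below is a minimal Rust
mirror of that arithmetic, handy for sanity-checking the strided kernels against
CPU results; the function and variable names are illustrative, not from the patch.

    /// Mirror of the Metal `get_strided_index` helper: decode linear index
    /// `idx` for a tensor with shape `dims` and element strides `strides`.
    fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
        let mut strided = 0;
        // Walk dimensions from innermost (last) to outermost (first).
        for d in (0..dims.len()).rev() {
            strided += (idx % dims[d]) * strides[d];
            idx /= dims[d];
        }
        strided
    }

    fn main() {
        // A 2x3 view over transposed storage: shape [2, 3], strides [1, 2].
        let (dims, strides) = ([2usize, 3], [1usize, 2]);
        let offsets: Vec<usize> = (0..6).map(|i| get_strided_index(i, &dims, &strides)).collect();
        assert_eq!(offsets, [0, 2, 4, 1, 3, 5]);
        println!("{offsets:?}");
    }

One detail worth re-checking in `call_binary_strided`: `set_buffer`'s offset
argument is in bytes, while `Layout::start_offset()` counts elements, so the
`left_offset as u64` / `right_offset as u64` casts likely need a
`* std::mem::size_of::<T>()` once kernels run on views with non-zero offsets.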