From 54355ff99795f8aa1de371681e1552212cfae932 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Sat, 11 Nov 2023 13:30:21 +0100
Subject: [PATCH] Adding some half kernels.

Dispatch F16 alongside F32 in the Metal backend: affine, cast
(f32 <-> f16), unary, binary, index-select and matmul now pick a half
kernel based on the storage dtype. Non-contiguous casts and unary ops,
which used to hit todo!(), now go through call_cast_strided (new) and
call_unary_strided.

call_affine and call_affine_strided take the kernel name from the caller
instead of hard-coding it. The strided caller therefore has to pass the
`*_strided` entry point: call_affine_strided used to hard-code
"affine_float_strided", and passing the contiguous name would load the
wrong kernel for the strided argument list.

---
 candle-core/src/metal_backend.rs        | 128 ++++++++++++++++++++++--
 candle-metal-kernels/src/cast.metal     |   2 +
 candle-metal-kernels/src/indexing.metal |   1 +
 candle-metal-kernels/src/lib.rs         |  46 ++++++++-
 4 files changed, 165 insertions(+), 12 deletions(-)

diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index f363a84b..03e6d810 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -153,11 +153,16 @@ impl BackendStorage for MetalStorage {
         let mut buffer = device.new_buffer(el, self.dtype);
         let command_buffer = self.device.command_buffer();
         if layout.is_contiguous() && layout.start_offset() == 0 {
-            assert_eq!(dtype, DType::F32);
+            let name = match self.dtype {
+                DType::F32 => "affine_float",
+                DType::F16 => "affine_half",
+                dtype => todo!("Affine {dtype:?}"),
+            };
             candle_metal_kernels::call_affine(
                 &device.device,
                 &command_buffer,
                 &device.kernels,
+                name,
                 el,
                 &self.buffer,
                 &mut buffer,
@@ -166,11 +171,16 @@ impl BackendStorage for MetalStorage {
             )
             .unwrap();
         } else {
-            assert_eq!(dtype, DType::F32);
+            let name = match self.dtype {
+                DType::F32 => "affine_float_strided",
+                DType::F16 => "affine_half_strided",
+                dtype => todo!("Affine {dtype:?}"),
+            };
             candle_metal_kernels::call_affine_strided(
                 &device.device,
                 &command_buffer,
                 &device.kernels,
+                name,
                 layout.dims(),
                 &self.buffer,
                 layout.stride(),
@@ -273,6 +283,8 @@ impl BackendStorage for MetalStorage {
         if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
+                (DType::F32, DType::F16) => "cast_f32_f16",
+                (DType::F16, DType::F32) => "cast_f16_f32",
                 (left, right) => todo!("to dtype {left:?} - {right:?}"),
             };
             candle_metal_kernels::call_cast_contiguous(
@@ -286,11 +298,24 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         } else {
-            todo!(
-                "TODO Implement the kernel calling cast {:?}-{:?}",
-                self.dtype,
-                dtype
-            );
+            let kernel_name = match (self.dtype, dtype) {
+                (DType::U32, DType::F32) => "cast_u32_f32_strided",
+                (DType::F32, DType::F16) => "cast_f32_f16_strided",
+                (DType::F16, DType::F32) => "cast_f16_f32_strided",
+                (left, right) => todo!("to dtype {left:?} - {right:?}"),
+            };
+            candle_metal_kernels::call_cast_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                layout.dims(),
+                &self.buffer,
+                layout.stride(),
+                layout.start_offset() * self.dtype.size_in_bytes(),
+                &mut buffer,
+            )
+            .map_err(MetalError::from)?;
         }
 
         // command_buffer.commit();
@@ -327,6 +352,20 @@ impl BackendStorage for MetalStorage {
                 ("uceil", DType::F32) => contiguous::ceil::FLOAT,
                 ("ufloor", DType::F32) => contiguous::floor::FLOAT,
                 ("uround", DType::F32) => contiguous::round::FLOAT,
+                ("ucos", DType::F16) => contiguous::cos::HALF,
+                ("usin", DType::F16) => contiguous::sin::HALF,
+                ("usqr", DType::F16) => contiguous::sqr::HALF,
+                ("usqrt", DType::F16) => contiguous::sqrt::HALF,
+                ("uneg", DType::F16) => contiguous::neg::HALF,
+                ("uexp", DType::F16) => contiguous::exp::HALF,
+                ("ulog", DType::F16) => contiguous::log::HALF,
+                ("ugelu", DType::F16) => contiguous::gelu::HALF,
+                // TODO erf does not exist in metal
+                ("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
+                ("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
+                ("uceil", DType::F16) => contiguous::ceil::HALF,
+                ("ufloor", DType::F16) => contiguous::floor::HALF,
+                ("uround", DType::F16) => contiguous::round::HALF,
                 (name, dtype) => todo!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_unary_contiguous(
@@ -340,7 +379,51 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         } else {
-            todo!("TODO Implement the kernel calling {}", B::KERNEL);
+            use candle_metal_kernels::unary::strided;
+            let kernel_name = match (B::KERNEL, dtype) {
+                ("ucos", DType::F32) => strided::cos::FLOAT,
+                ("usin", DType::F32) => strided::sin::FLOAT,
+                ("usqr", DType::F32) => strided::sqr::FLOAT,
+                ("usqrt", DType::F32) => strided::sqrt::FLOAT,
+                ("uneg", DType::F32) => strided::neg::FLOAT,
+                ("uexp", DType::F32) => strided::exp::FLOAT,
+                ("ulog", DType::F32) => strided::log::FLOAT,
+                ("ugelu", DType::F32) => strided::gelu::FLOAT,
+                // TODO erf does not exist in metal
+                ("ugelu_erf", DType::F32) => crate::bail!("erf is not implemented in metal"),
+                ("uerf", DType::F32) => crate::bail!("erf is not implemented in metal"),
+                ("uceil", DType::F32) => strided::ceil::FLOAT,
+                ("ufloor", DType::F32) => strided::floor::FLOAT,
+                ("uround", DType::F32) => strided::round::FLOAT,
+                ("ucos", DType::F16) => strided::cos::HALF,
+                ("usin", DType::F16) => strided::sin::HALF,
+                ("usqr", DType::F16) => strided::sqr::HALF,
+                ("usqrt", DType::F16) => strided::sqrt::HALF,
+                ("uneg", DType::F16) => strided::neg::HALF,
+                ("uexp", DType::F16) => strided::exp::HALF,
+                ("ulog", DType::F16) => strided::log::HALF,
+                ("ugelu", DType::F16) => strided::gelu::HALF,
+                // TODO erf does not exist in metal
+                ("ugelu_erf", DType::F16) => crate::bail!("erf is not implemented in metal"),
+                ("uerf", DType::F16) => crate::bail!("erf is not implemented in metal"),
+                ("uceil", DType::F16) => strided::ceil::HALF,
+                ("ufloor", DType::F16) => strided::floor::HALF,
+                ("uround", DType::F16) => strided::round::HALF,
+                (name, dtype) => todo!("Match {name} - {dtype:?}"),
+            };
+            candle_metal_kernels::call_unary_strided(
+                &device.device,
+                &command_buffer,
+                &device.kernels,
+                kernel_name,
+                layout.dims(),
+                &self.buffer,
+                layout.stride(),
+                layout.start_offset() * self.dtype.size_in_bytes(),
+                &mut buffer,
+                0,
+            )
+            .map_err(MetalError::from)?;
         }
         // command_buffer.commit();
         // command_buffer.wait_until_scheduled();
@@ -378,6 +461,14 @@ impl BackendStorage for MetalStorage {
                 ("bmul", DType::F32) => contiguous::mul::FLOAT,
                 ("div", DType::F32) => contiguous::div::FLOAT,
                 ("bdiv", DType::F32) => contiguous::div::FLOAT,
+                ("add", DType::F16) => contiguous::add::HALF,
+                ("badd", DType::F16) => contiguous::add::HALF,
+                ("sub", DType::F16) => contiguous::sub::HALF,
+                ("bsub", DType::F16) => contiguous::sub::HALF,
+                ("mul", DType::F16) => contiguous::mul::HALF,
+                ("bmul", DType::F16) => contiguous::mul::HALF,
+                ("div", DType::F16) => contiguous::div::HALF,
+                ("bdiv", DType::F16) => contiguous::div::HALF,
                 (name, dtype) => todo!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_binary_contiguous(
@@ -399,6 +490,10 @@ impl BackendStorage for MetalStorage {
                 ("bsub", DType::F32) => strided::sub::FLOAT,
                 ("bmul", DType::F32) => strided::mul::FLOAT,
                 ("bdiv", DType::F32) => strided::div::FLOAT,
+                ("badd", DType::F16) => strided::add::HALF,
+                ("bsub", DType::F16) => strided::sub::HALF,
+                ("bmul", DType::F16) => strided::mul::HALF,
+                ("bdiv", DType::F16) => strided::div::HALF,
                 (name, dtype) => todo!("Match {name} - {dtype:?}"),
             };
             candle_metal_kernels::call_binary_strided(
@@ -555,6 +650,7 @@ impl BackendStorage for MetalStorage {
         let mut buffer = device.new_buffer(dst_el, dtype);
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "is_u32_f32",
+            (DType::U32, DType::F16) => "is_u32_f16",
             (left, right) => todo!("index select metal {left:?} {right:?}"),
         };
         let command_buffer = self.device.command_buffer();
@@ -601,8 +697,20 @@ impl BackendStorage for MetalStorage {
     ) -> Result<Self> {
         // Create descriptors
         use metal::mps::matrix::*;
-        let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
-        let size = core::mem::size_of::<f32>() as NSUInteger;
+
+        assert_eq!(self.dtype, rhs.dtype);
+
+        let (type_id, size) = match self.dtype {
+            DType::F32 => (
+                metal::mps::MPS_FLOATBIT_ENCODING | 32,
+                core::mem::size_of::<f32>() as NSUInteger,
+            ),
+            DType::F16 => (
+                metal::mps::MPS_FLOATBIT_ENCODING | 16,
+                core::mem::size_of::<f16>() as NSUInteger,
+            ),
+            dtype => todo!("Implement matmul {dtype:?}"),
+        };
 
         let elem_count = b * m * n;
 
diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal
index d1788253..bd49bdcc 100644
--- a/candle-metal-kernels/src/cast.metal
+++ b/candle-metal-kernels/src/cast.metal
@@ -46,6 +46,8 @@ kernel void FN_NAME_STRIDED( \
 } \
 
 CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float)
+CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
+CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
 
 #if __METAL_VERSION__ >= 310
 #endif
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 444fa322..e0129ca9 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -75,6 +75,7 @@ kernel void FN_NAME( \
 
 
 INDEX_OP(is_u32_f32, uint, float)
+INDEX_OP(is_u32_f16, uint, half)
 
 #if __METAL_VERSION__ >= 310
 
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 6cdb313d..c4a0ca97 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -460,6 +460,46 @@ pub fn call_cast_contiguous(
     Ok(())
 }
 
+pub fn call_cast_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    shape: &[usize],
+    input: &Buffer,
+    input_strides: &[usize],
+    input_offset: usize,
+    output: &mut Buffer,
+) -> Result<(), MetalKernelError> {
+    // println!("Kernel {:?}", kernel_name.0);
+    // assert_eq!(input.length(), output.length());
+    let func = kernels.load_function(device, Source::Cast, kernel_name)?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    let length: usize = shape.iter().product();
+
+    set_params!(
+        encoder,
+        (length, shape, input_strides, (input, input_offset), output)
+    );
+
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 pub fn call_reduce_contiguous(
     device: &Device,
     command_buffer: &CommandBufferRef,
@@ -565,13 +605,14 @@ pub fn call_affine(
     device: &Device,
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
+    name: &'static str,
     size: usize,
     input: &Buffer,
     output: &mut Buffer,
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Affine, "affine_float")?;
+    let func = kernels.load_function(device, Source::Affine, name)?;
     let pipeline_state_descriptor = ComputePipelineDescriptor::new();
     pipeline_state_descriptor.set_compute_function(Some(&func));
 
@@ -596,6 +637,7 @@ pub fn call_affine_strided(
     device: &Device,
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
+    name: &'static str,
     shape: &[usize],
     input: &Buffer,
     input_stride: &[usize],
@@ -604,7 +646,7 @@ pub fn call_affine_strided(
     mul: f32,
     add: f32,
 ) -> Result<(), MetalKernelError> {
-    let func = kernels.load_function(device, Source::Affine, "affine_float_strided")?;
+    let func = kernels.load_function(device, Source::Affine, name)?;
     let pipeline_state_descriptor = ComputePipelineDescriptor::new();
     pipeline_state_descriptor.set_compute_function(Some(&func));
 