From dedc8c36565bbd803028bd35399fdd38d5e17ff8 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 6 Nov 2023 15:36:48 +0100
Subject: [PATCH] Writing unary as macro instead, protecting bfloat type with proper metal version.

---
Reviewer notes (this section is ignored by `git am`):
 * unary.metal: the `__METAL_VERSION__ >= 310` block instantiated `cos_half`
   a second time, which redefines the kernel on Metal 3.1+ and never emits a
   bfloat kernel. It now instantiates `cos_bfloat` for `bfloat`, following the
   `cos_float` / `cos_half` naming used just above it.
 * Line breaks, indentation and angle-bracket spans (generics, template
   parameters) were reconstructed from a whitespace-collapsed copy of this
   patch. TODO confirm the context lines against the target tree before
   applying; the type argument of `read_to_vec` in the last lib.rs hunk is
   presumed to be `f32`.
 * The blob ids on the `index` lines are the original ones; the post-image id
   of unary.metal no longer matches after the fix above.

 candle-metal-kernels/Cargo.toml         |  3 ++
 candle-metal-kernels/src/indexing.metal |  6 ++-
 candle-metal-kernels/src/lib.rs         | 57 ++++++++++++++-----------
 candle-metal-kernels/src/unary.metal    | 22 +++++++---
 4 files changed, 56 insertions(+), 32 deletions(-)

diff --git a/candle-metal-kernels/Cargo.toml b/candle-metal-kernels/Cargo.toml
index bf505624..6b0939e5 100644
--- a/candle-metal-kernels/Cargo.toml
+++ b/candle-metal-kernels/Cargo.toml
@@ -12,3 +12,6 @@ license.workspace = true
 metal = { workspace = true }
 once_cell = "1.18.0"
 thiserror = { workspace = true }
+
+[dev-dependencies]
+half = { workspace = true }
diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal
index 2c80e556..528c109d 100644
--- a/candle-metal-kernels/src/indexing.metal
+++ b/candle-metal-kernels/src/indexing.metal
@@ -48,9 +48,13 @@ kernel void FN_NAME( \
     uint thread_index [[thread_index_in_threadgroup]] \
 ) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
 
+
+
+#if __METAL_VERSION__ >= 310
 IA_OP(bfloat, int64_t, ia_i64_bf16)
 IA_OP(bfloat, uint32_t, ia_u32_bf16)
 IA_OP(bfloat, uint8_t, ia_u8_bf16)
+#endif
 
 IA_OP(half, uint32_t, ia_u32_f16)
 IA_OP(half, uint8_t, ia_u8_f16)
@@ -68,4 +72,4 @@ IA_OP(uint32_t, uint32_t, ia_u32_u32)
 IA_OP(float, uint8_t, ia_u8_f32)
 IA_OP(uint8_t, uint8_t, ia_u8_u8)
 IA_OP(uint32_t, uint8_t, ia_u8_u32)
-IA_OP(int64_t, uint8_t, ia_u8_i64)
\ No newline at end of file
+IA_OP(int64_t, uint8_t, ia_u8_i64)
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index f12edc04..db25e6f3 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -131,11 +131,10 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use half::f16;
     use metal::{
-        CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLResourceUsage,
-        MTLSize, NSUInteger,
+        CompileOptions, ComputePipelineDescriptor, Device, MTLResourceOptions, MTLSize, NSUInteger,
     };
-    use std::ffi::c_void;
     use std::mem;
 
     fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
@@ -143,32 +142,26 @@ mod tests {
         v.iter().map(|t| f32::round(t * b) / b).collect()
     }
 
-    #[test]
-    fn cos() {
-        let v = vec![1.0f32, 2.0, 3.0];
+    fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
+        let b = 10f32.powi(digits);
+        v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
+    }
+
+    fn run_cos<T: Clone>(v: &[T], name: &str) -> Vec<T> {
         let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
         let device = Device::system_default().unwrap();
         let command_queue = device.new_command_queue();
         let command_buffer = command_queue.new_command_buffer();
-        let encoder = command_buffer.new_compute_command_encoder();
         let input = device.new_buffer_with_data(
-            v.as_ptr() as *const c_void,
-            (v.len() * mem::size_of::<f32>()) as u64,
+            v.as_ptr() as *const core::ffi::c_void,
+            (v.len() * core::mem::size_of::<T>()) as u64,
             option,
         );
-        let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
+        let output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, option);
         let library = device
             .new_library_with_source(UNARY, &CompileOptions::new())
             .expect("Failed to load unary library");
-        let func = library.get_function("cos", None).unwrap();
-        let argument_encoder = func.new_argument_encoder(0);
-        let arg_buffer = device.new_buffer(
-            argument_encoder.encoded_length(),
-            MTLResourceOptions::StorageModeShared,
-        );
-        argument_encoder.set_argument_buffer(&arg_buffer, 0);
-        argument_encoder.set_buffer(0, &input, 0);
-        argument_encoder.set_buffer(1, &output, 0);
+        let func = library.get_function(&format!("cos_{name}"), None).unwrap();
         let pipeline_state_descriptor = ComputePipelineDescriptor::new();
         pipeline_state_descriptor.set_compute_function(Some(&func));
 
@@ -178,11 +171,10 @@ mod tests {
             )
             .unwrap();
 
+        let encoder = command_buffer.new_compute_command_encoder();
         encoder.set_compute_pipeline_state(&pipeline_state);
-        encoder.set_buffer(0, Some(&arg_buffer), 0);
-
-        encoder.use_resource(&input, MTLResourceUsage::Read);
-        encoder.use_resource(&output, MTLResourceUsage::Write);
+        encoder.set_buffer(0, Some(&input), 0);
+        encoder.set_buffer(1, Some(&output), 0);
 
         let width = 16;
 
@@ -202,9 +194,14 @@ mod tests {
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();
+        output.read_to_vec::<T>(v.len())
+    }
 
+    #[test]
+    fn cos_f32() {
+        let v = vec![1.0f32, 2.0, 3.0];
+        let results = run_cos(&v, "float");
         let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
-        let results = output.read_to_vec::<f32>(v.len());
         assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
         assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
     }
@@ -343,4 +340,16 @@ mod tests {
         let result = outputs_buffer.read_to_vec::<f32>(right.len());
         assert_eq!(result, expected);
     }
+
+    #[test]
+    fn cos_f16() {
+        let v: Vec<f16> = [1.0f32, 2.0, 3.0]
+            .iter()
+            .map(|v| f16::from_f32(*v))
+            .collect();
+        let results = run_cos(&v, "half");
+        let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
+        assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
+        assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
+    }
 }
diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal
index 3861b2f0..a2635b68 100644
--- a/candle-metal-kernels/src/unary.metal
+++ b/candle-metal-kernels/src/unary.metal
@@ -2,13 +2,21 @@
 
 using namespace metal;
 
-struct Input {
-    device float *input;
-    device float *output;
-};
-
-kernel void cos(device Input& args [[ buffer(0) ]], uint index [[thread_position_in_grid]])
+template <typename T>
+kernel void unary_cos(device const T *input, device T *output, uint index [[thread_position_in_grid]])
 {
-    args.output[index] = cos(args.input[index]);
+    output[index] = cos(input[index]);
 }
 
+#define UNARY(FN, TYPENAME, FN_NAME) \
+kernel void FN_NAME(device const TYPENAME *input, device TYPENAME *output, uint index [[thread_position_in_grid]]) \
+{ \
+    output[index] = FN(input[index]);\
+}
+
+UNARY(cos, float, cos_float);
+UNARY(cos, half, cos_half);
+
+#if __METAL_VERSION__ >= 310
+UNARY(cos, bfloat, cos_bfloat);
+#endif