diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index c388c04e..e5f0a841 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -24,17 +24,14 @@ kernel void FN_NAME( \ constant float &add, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint id [[ thread_position_in_grid ]] \ ) { \ + if (id >= dim) { \ + return; \ + } \ const TYPENAME m = TYPENAME(mul); \ const TYPENAME a = TYPENAME(add); \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = input[i] * m + a; \ - } \ + output[id] = input[id] * m + a; \ } \ AFFINE(affine_float, float) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index cfd34416..37bc0bae 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -23,17 +23,14 @@ kernel void FN_NAME( \ device const TYPENAME *left, \ device const TYPENAME *right, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - TYPENAME x = left[i]; \ - TYPENAME y = right[i]; \ - output[i] = OUT_TYPENAME(FN); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + TYPENAME x = left[thread_position_in_grid]; \ + TYPENAME y = right[thread_position_in_grid]; \ + output[thread_position_in_grid] = OUT_TYPENAME(FN); \ }\ kernel void FN_NAME_STRIDED( \ 
constant size_t &dim, \ @@ -44,17 +41,14 @@ kernel void FN_NAME_STRIDED( \ device const TYPENAME *left, \ device const TYPENAME *right, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \ - TYPENAME y = left[get_strided_index(i, num_dims, dims, right_strides)]; \ - output[i] = OUT_TYPENAME(FN); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \ + TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; /* right operand is addressed through its own strides */ \ + output[thread_position_in_grid] = OUT_TYPENAME(FN); \ } #define BINARY_OP(FN, NAME) \ diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 52e63662..d1788253 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -23,15 +23,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = RIGHT_TYPENAME(input[i]); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + output[thread_position_in_grid] = 
RIGHT_TYPENAME(input[thread_position_in_grid]); \ } \ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -40,17 +37,13 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const LEFT_TYPENAME *input, \ device RIGHT_TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint i [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ + if (i >= dim) { \ + return; \ } \ -} - + output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \ +} \ CAST(cast_u32_f32, cast_u32_f32_strided, int32_t, float) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index c077cc48..444fa322 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -2,7 +2,7 @@ using namespace metal; # define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \ -kernel void NAME( \ +kernel void NAME( \ constant size_t &dst_size, \ constant size_t &left_size, \ constant size_t &src_dim_size, \ @@ -42,12 +42,9 @@ void index_add( constant uint &dst_dim_size, constant uint &right_size, - uint threadgroup_size [[threads_per_threadgroup]], - uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], - uint thread_index [[thread_index_in_threadgroup]] + uint gid [[ thread_position_in_grid ]] ) { - const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size); if (gid >= left_size * right_size) { return; } @@ -73,14 +70,13 @@ kernel void FN_NAME( \ constant uint &left_size, \ constant uint &dst_dim_size, \ constant uint &right_size, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint 
threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ - uint thread_index [[thread_index_in_threadgroup]] \ -) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \ + uint gid [[ thread_position_in_grid ]] \ +) { index_add(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \ INDEX_OP(is_u32_f32, uint, float) + #if __METAL_VERSION__ >= 310 IA_OP(bfloat, int64_t, ia_i64_bf16) IA_OP(bfloat, uint32_t, ia_u32_bf16) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 6a01107c..83fbe833 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1,7 +1,7 @@ #![allow(clippy::too_many_arguments)] use metal::{ Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor, - Device, Function, Library, MTLSize, + ComputePipelineState, Device, Function, Library, MTLSize, }; use std::collections::HashMap; use std::ffi::c_void; @@ -15,6 +15,24 @@ const TERNARY: &str = include_str!("ternary.metal"); const CAST: &str = include_str!("cast.metal"); const REDUCE: &str = include_str!("reduce.metal"); +fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) { + let size = length as u64; + let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size).max(1); /* never 0: avoids underflow and division by zero below when length == 0 */ + let count = (size + width - 1) / width; + let thread_group_count = MTLSize { + width: count, + height: 1, + depth: 1, + }; + + let thread_group_size = MTLSize { + width, + height: 1, + depth: 1, + }; + (thread_group_count, thread_group_size) +} + fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {

::set_param(encoder, position, data) } @@ -257,19 +275,7 @@ pub fn call_unary_contiguous( set_params!(encoder, (length, input, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) @@ -314,17 +320,7 @@ pub fn call_unary_strided( ); let width: usize = shape.iter().product(); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -358,18 +354,7 @@ pub fn call_binary_contiguous( set_params!(encoder, (length, left, right, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -421,17 +406,7 @@ pub fn call_binary_strided( ) ); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let thread_group_size = MTLSize { - width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width as u64), - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, width); 
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -464,18 +439,7 @@ pub fn call_cast_contiguous( set_params!(encoder, (length, input, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -608,19 +572,7 @@ pub fn call_affine( set_params!(encoder, (size, mul, add, input, output)); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; - + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) @@ -672,18 +624,7 @@ pub fn call_where_cond_strided( ) ); - let thread_group_count = MTLSize { - width: 1, - height: 1, - depth: 1, - }; - - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64); - let thread_group_size = MTLSize { - width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); @@ -730,19 +671,9 @@ pub fn call_index_select( ) ); - let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dst_el as u64); - let grid_size = MTLSize { - width: (dst_el as u64 + width - 1) / width, - height: 1, - depth: 1, - }; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let thread_group_size = MTLSize { - width, - height: 1, - 
depth: 1, - }; - encoder.dispatch_thread_groups(grid_size, thread_group_size); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.end_encoding(); Ok(()) } diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index 77de214e..dd137599 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -27,15 +27,12 @@ kernel void FN_NAME( \ constant size_t &dim, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = TYPENAME(FN(input[i])); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \ }\ kernel void FN_NAME_STRIDED( \ constant size_t &dim, \ @@ -44,15 +41,12 @@ kernel void FN_NAME_STRIDED( \ constant size_t *strides, \ device const TYPENAME *input, \ device TYPENAME *output, \ - uint threadgroup_size [[threads_per_threadgroup]], \ - uint thread_index [[thread_index_in_threadgroup]] \ + uint thread_position_in_grid [[ thread_position_in_grid ]] \ ) { \ - const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \ - const size_t start = thread_index * length; \ - const size_t stop = min(start + length, dim); \ - for (size_t i = start; i < stop; i++){ \ - output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \ + if (thread_position_in_grid >= dim) { \ + return; \ } \ + output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \ } #define UNARY_OP(NAME) \ @@ -79,4 +73,6 @@ 
BFLOAT_UNARY_OP(sqr) BFLOAT_UNARY_OP(sqrt) BFLOAT_UNARY_OP(neg) BFLOAT_UNARY_OP(exp) + +UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided) #endif