diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 424b29d9..047313d1 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -482,20 +482,9 @@ impl BackendStorage for MetalStorage {
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
-        if sum_dims.len() != 1 {
-            crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
-        }
-        if sum_dims[0] != layout.shape().rank() - 1 {
-            crate::bail!("Non last dim reduce op {op:?} not implemented yet");
-        }
-        if layout.stride()[sum_dims[0]] != 1 {
-            crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
-        }
-
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
-        let src_el: usize = src_dims.iter().product();
         // Source dims and strides with the sum dims at the end.
         let mut dims = vec![];
         let mut stride = vec![];
@@ -515,28 +504,41 @@ impl BackendStorage for MetalStorage {
         // The reduction loop requires the shared array to be properly initialized and for
         // this we want the number of threads to be a power of two.
         let (name, check_empty, return_index) = match (op, self.dtype) {
-            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
-            (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
-            (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
-            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
-            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
-            _ => crate::bail!("Reduce op for non float"),
+            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
+            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
+            (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
+            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
+            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
+            (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
+            (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
+            (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
+            (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
+            (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
+            (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
+            (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
+            (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
+            (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
+            (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
+            (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
+            (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
+            (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
+            (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
+            (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
+            (k, dtype) => crate::bail!("Reduce op {k:?} {dtype:?} not implemented"),
         };
         if check_empty && layout.shape().elem_count() == 0 {
            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
-        if dtype == DType::U32 {
-            crate::bail!("reduce op {name} is not implemented yet.");
-        }
         let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
-        candle_metal_kernels::call_reduce_contiguous(
+        candle_metal_kernels::call_reduce_strided(
             &device.device,
             &command_buffer,
             &device.kernels,
             name,
-            src_el,
+            &dims,
+            &stride,
             dst_el,
             &self.buffer,
             layout.start_offset() * self.dtype.size_in_bytes(),
@@ -730,7 +732,7 @@ impl BackendStorage for MetalStorage {
             ("sub", DType::F16) => contiguous::sub::HALF,
             ("mul", DType::F16) => contiguous::mul::HALF,
             ("div", DType::F16) => contiguous::div::HALF,
-            (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_binary_contiguous(
             &device.device,
@@ -751,11 +753,15 @@ impl BackendStorage for MetalStorage {
             ("bsub", DType::F32) => strided::sub::FLOAT,
             ("bmul", DType::F32) => strided::mul::FLOAT,
             ("bdiv", DType::F32) => strided::div::FLOAT,
+            ("bminimum", DType::F32) => strided::min::FLOAT,
+            ("bmaximum", DType::F32) => strided::max::FLOAT,
             ("badd", DType::F16) => strided::add::HALF,
             ("bsub", DType::F16) => strided::sub::HALF,
             ("bmul", DType::F16) => strided::mul::HALF,
             ("bdiv", DType::F16) => strided::div::HALF,
-            (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            ("bminimum", DType::F16) => strided::min::HALF,
+            ("bmaximum", DType::F16) => strided::max::HALF,
+            (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
         };
         candle_metal_kernels::call_binary_strided(
             &device.device,
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index c871dc96..06891748 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -543,6 +543,7 @@ fn argmax(device: &Device) -> Result<()> {
     let t1 = tensor.reshape((190, 5, 4))?;
     let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
     for tensor in [t1, t2] {
+        println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?);
         assert_eq!(
             tensor
                 .argmax_keepdim(0)?
diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal
index ea21bb34..f13589c1 100644
--- a/candle-metal-kernels/src/binary.metal
+++ b/candle-metal-kernels/src/binary.metal
@@ -1,5 +1,8 @@
 #include <metal_stdlib>
 
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
 METAL_FUNC uint get_strided_index(
     uint idx,
     constant size_t &num_dims,
@@ -63,10 +66,14 @@
 BINARY_OP(x + y, add)
 BINARY_OP(x - y, sub)
 BINARY_OP(x * y, mul)
 BINARY_OP(x / y, div)
+BINARY_OP(MIN(x, y), min)
+BINARY_OP(MAX(x, y), max)
 #if __METAL_VERSION__ >= 310
 BFLOAT_BINARY_OP(x + y, add)
 BFLOAT_BINARY_OP(x - y, sub)
 BFLOAT_BINARY_OP(x * y, mul)
 BFLOAT_BINARY_OP(x / y, div)
+BFLOAT_BINARY_OP(MIN(x, y), min)
+BFLOAT_BINARY_OP(MAX(x, y), max)
 #endif
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index f2db171e..c34e34fe 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -166,7 +166,7 @@ pub mod unary {
     ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
 }
 pub mod binary {
-    ops!(add, sub, mul, div);
+    ops!(add, sub, mul, div, min, max);
 }
 
 #[derive(thiserror::Error, Debug)]
@@ -588,6 +588,64 @@ pub fn call_reduce_contiguous(
     Ok(())
 }
 
+pub fn call_reduce_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    out_length: usize,
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let length: usize = shape.iter().product();
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let elements_to_sum = length / out_length;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            shape.len(),
+            shape,
+            strides,
+            elements_to_sum,
+            (input, input_offset),
+            output
+        )
+    );
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_last_softmax(
     device: &Device,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 62443660..2d584917 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -2,6 +2,7 @@
 using namespace metal;
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
 
 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(
 
 constant int THREADGROUP_SIZE = 2048;
 
-# define REDUCE(FN, NAME, T) \
+
+#define ARGMIN(NAME, T, MAXVALUE) \
 kernel void NAME( \
-    constant size_t &src_numel, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device uint *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
+    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+    \
+    shared_memory[tid] = MAXVALUE; \
+    shared_indices[tid] = 0xFFFFFFFF; \
+    bool notset = true; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = start_idx + el_to_sum_per_block; \
+    size_t idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        */ \
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        if (notset || src[strided_i] < shared_memory[tid]) { \
+            shared_memory[tid] = src[strided_i]; \
+            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
+            shared_indices[tid] = idx % dims[num_dims - 1]; \
+            notset = false; \
+        } \
+        idx += block_dim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_none); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
+            shared_indices[tid] = shared_indices[tid + s]; \
+            shared_memory[tid] = shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+    \
+    if (tid == 0){ \
+        dst[dst_id] = shared_indices[0]; \
+    } \
+} \
+
+
+#define ARGMAX(NAME, T, MINVALUE) \
+kernel void NAME( \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device uint *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
+    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+    \
+    shared_memory[tid] = MINVALUE; \
+    shared_indices[tid] = 0xFFFFFFFF; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = start_idx + el_to_sum_per_block; \
+    size_t idx = start_idx + tid; \
+    bool notset = true; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        */ \
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        if (notset || shared_memory[tid] < src[strided_i]) { \
+            shared_memory[tid] = src[strided_i]; \
+            shared_indices[tid] = idx % dims[num_dims - 1]; \
+            notset = false; \
+        } \
+        idx += block_dim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_none); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
+            shared_indices[tid] = shared_indices[tid + s]; \
+            shared_memory[tid] = shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+    \
+    if (tid == 0){ \
+        dst[dst_id] = shared_indices[0]; \
+    } \
+} \
+
+#define REDUCE(FN, NAME, T, START) \
+kernel void NAME( \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
     constant size_t &el_to_sum_per_block, \
     device const T *src, \
     device T *dst, \
@@ -34,21 +156,21 @@ kernel void NAME( \
     \
     threadgroup T shared_memory[THREADGROUP_SIZE]; \
     \
-    shared_memory[tid] = 0; \
+    shared_memory[tid] = START; \
     /* \
     // Elements summed in this block range from dst_id * el_to_sum_per_block \
     // to (dst_id + 1) * el_to_sum_per_block. \
     */ \
     size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t stop_idx = start_idx + el_to_sum_per_block; \
     size_t idx = start_idx + tid; \
     while (idx < stop_idx) { \
         /* \
         // TODO: Fast version for the contiguous case. \
-        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
         */ \
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
         T x = shared_memory[tid]; \
-        T y = src[idx]; \
+        T y = src[strided_i]; \
         shared_memory[tid] = FN; \
         idx += block_dim; \
     } \
@@ -71,10 +193,6 @@ kernel void NAME( \
 } \
 
-REDUCE(x + y, fast_sum_f32, float)
-REDUCE(x * y, fast_mul_f32, float)
-REDUCE(max(x, y), fast_max_f32, float)
-
 
 #define SOFTMAX(NAME, T) \
 kernel void NAME( \
     constant size_t &src_numel, \
@@ -142,8 +260,33 @@ kernel void NAME( \
     } \
 } \
+REDUCE(x + y, fast_sum_f32_strided, float, 0)
+REDUCE(x + y, fast_sum_u32_strided, uint, 0)
+REDUCE(x + y, fast_sum_f16_strided, half, 0)
+REDUCE(x * y, fast_mul_f32_strided, float, 1)
+REDUCE(x * y, fast_mul_u32_strided, uint, 1)
+REDUCE(x * y, fast_mul_f16_strided, half, 1)
+REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
+REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
+REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
+REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
+REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
+REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
+ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
+ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
+ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
+ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
+ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
+ARGMAX(fast_argmax_u32_strided, uint, 0)
+
 SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
 
 #if __METAL_VERSION__ >= 310
+REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
+REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
+REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
+REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
+ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
+ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)
 #endif
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 9c9475a2..8d5a2624 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -574,12 +574,16 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
     let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
-    call_reduce_contiguous(
+    let num_dims = 1;
+    let dims = vec![v.len()];
+    let strides = vec![1];
+    call_reduce_strided(
         &device,
         command_buffer,
         &kernels,
         name,
-        v.len(),
+        &dims,
+        &strides,
         out_length,
         &input,
         0,
@@ -623,7 +627,7 @@ fn reduce_sum() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let out_length = 1;
 
-    let results = run_reduce(&v, out_length, "fast_sum_f32");
+    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
     assert_eq!(approx(results, 4), vec![21.0]);
 }
 
@@ -632,7 +636,7 @@ fn reduce_sum2() {
     let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
     let out_length = 2;
 
-    let results = run_reduce(&v, out_length, "fast_sum_f32");
+    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
     assert_eq!(approx(results, 4), vec![6.0, 15.0]);
 }
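
Note for reviewers (not part of the patch): the hunks above assume the backend has already reordered `dims`/`stride` so that the reduced dimensions sit at the end before calling `call_reduce_strided`, which then derives `elements_to_sum = shape.product() / out_length` and launches one threadgroup per output element with a power-of-two thread count. The stand-alone Rust sketch below illustrates that reordering and the resulting launch parameters; `plan_strided_reduce` is a hypothetical helper written only for illustration and does not appear in this patch.

// Hypothetical sketch of the dim/stride planning implied by the
// "Source dims and strides with the sum dims at the end." comment above.
fn plan_strided_reduce(
    src_dims: &[usize],
    src_stride: &[usize],
    reduce_dims: &[usize],
) -> (Vec<usize>, Vec<usize>, usize, usize) {
    let mut dims = vec![];
    let mut stride = vec![];
    // Kept (non-reduced) dimensions first...
    for (d, (&sz, &st)) in src_dims.iter().zip(src_stride.iter()).enumerate() {
        if !reduce_dims.contains(&d) {
            dims.push(sz);
            stride.push(st);
        }
    }
    // ...reduced dimensions last, so each output element owns a contiguous
    // range of `elements_to_sum` (strided) source indices.
    for &d in reduce_dims {
        dims.push(src_dims[d]);
        stride.push(src_stride[d]);
    }
    let src_el: usize = dims.iter().product();
    let reduced: usize = reduce_dims.iter().map(|&d| src_dims[d]).product();
    let dst_el = src_el / reduced;
    // call_reduce_strided recomputes elements_to_sum as src_el / dst_el and
    // rounds the threadgroup width to a power of two for the shared-memory
    // tree reduction.
    (dims, stride, dst_el, src_el / dst_el)
}

fn main() {
    // Reducing dim 0 of a contiguous (2, 3) tensor (strides [3, 1]).
    let (dims, stride, dst_el, elements_to_sum) =
        plan_strided_reduce(&[2, 3], &[3, 1], &[0]);
    assert_eq!((dims, stride), (vec![3, 2], vec![1, 3]));
    assert_eq!((dst_el, elements_to_sum), (3, 2));
}

With the reduce dims last, the kernel's get_strided_index walks each output element's slice in source order, which is why the argmin/argmax kernels can recover the index with `idx % dims[num_dims - 1]`.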