diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs
index 8216d9d0..1be1a0f4 100644
--- a/candle-core/benches/benchmarks/reduce.rs
+++ b/candle-core/benches/benchmarks/reduce.rs
@@ -1,6 +1,8 @@
 use crate::benchmarks::{bench_name, device, BenchDevice};
-use candle_core::{DType, Device, Tensor};
+use candle_core::{DType, Device, Storage, Tensor};
 use criterion::{black_box, criterion_group, Criterion, Throughput};
+use half::{bf16, f16};
+use std::ops::Deref;
 use std::time::Instant;
 
 fn run_sum(a: &Tensor) {
@@ -10,21 +12,114 @@ fn run_arg_min(a: &Tensor) {
     a.argmin(2).unwrap();
 }
 
+fn softmax(a: &Tensor) -> candle_core::Result<()> {
+    use candle_core::{backend::BackendStorage, DType};
+    let (storage, layout) = a.storage_and_layout();
+
+    let device = a.device();
+
+    if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match a.dtype() {
+            DType::F32 => "softmax_f32",
+            DType::F16 => "softmax_f16",
+            DType::BF16 => "softmax_bf16",
+            dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+
+        let n = layout.stride().len();
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
+            candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            &output,
+        )
+        .unwrap();
+    }
+    Ok(())
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
     let device = device().unwrap();
-    run_reduce(c, &device);
-    run_arg_reduce(c, &device);
+
+    let (lo, up) = (-1000.0f32, 1000.0f32);
+    run_softmax(c, &device, (lo, up));
+    run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+
+    run_reduce(c, &device, (lo, up));
+    run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+
+    run_arg_reduce(c, &device, (lo, up));
+    run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
 }
 
-fn run_reduce(c: &mut Criterion, device: &Device) {
+
+fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
+    if !device.is_metal() {
+        return;
+    }
+
+    let b = 1;
+    let m = 2048;
+    let k = 2048;
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
+
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
+
+    let name = match T::DTYPE {
+        DType::F32 => "softmax_f32",
+        DType::F16 => "softmax_f16",
+        DType::BF16 => "softmax_bf16",
+        _ => "softmax",
+    };
+
+    let mut group = c.benchmark_group(bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                softmax(black_box(&a)).unwrap();
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
     let b = 1;
     let m = 2048;
    let k = 2048;
-    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
 
-    let flops = b * m * k * DType::F32.size_in_bytes();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
 
-    let mut group = c.benchmark_group(bench_name("reduce"));
+    let name = match T::DTYPE {
+        DType::F32 => "reduce_f32",
+        DType::F16 => "reduce_f16",
+        DType::BF16 => "reduce_bf16",
+        _ => "reduce",
+    };
+
+    let mut group = c.benchmark_group(bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
@@ -39,16 +134,27 @@ fn run_reduce(c: &mut Criterion, device: &Device) {
     group.finish();
 }
 
-fn run_arg_reduce(c: &mut Criterion, device: &Device) {
+fn run_arg_reduce<T: candle_core::FloatDType>(
+    c: &mut Criterion,
+    device: &Device,
+    (lo, up): (T, T),
+) {
     let b = 1;
     let m = 2048;
     let k = 2048;
-    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
 
-    let flops = b * m * k * DType::F32.size_in_bytes();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
 
-    let mut group = c.benchmark_group(bench_name("arg_reduce"));
+    let name = match T::DTYPE {
+        DType::F32 => "arg_reduce_f32",
+        DType::F16 => "arg_reduce_f16",
+        DType::BF16 => "arg_reduce_bf16",
+        _ => "reduce",
+    };
+
+    let mut group = c.benchmark_group(bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 1eeb53c0..f1284aae 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -545,8 +545,8 @@ impl BackendStorage for MetalStorage {
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
-
-        let buffer = device.new_buffer(1, self.dtype, "reduce")?;
+        let dtype = if return_index { DType::U32 } else { self.dtype };
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 5c8963cd..bb9a1b91 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -74,7 +74,7 @@ METAL_FUNC void load_from_global(
         shared[tid] = op(shared[tid], src[strided_i]);
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 // Load strided elements from global memory into shared memory with indices.
@@ -107,7 +107,7 @@ METAL_FUNC void load_from_global(
         }
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 // Load contiguous elements from global memory into shared memory.
@@ -131,7 +131,7 @@ METAL_FUNC void load_from_global(
         shared[tid] = op(shared[tid], src[idx]);
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 // Load contiguous elements from global memory into shared memory with indices.
@@ -162,15 +162,15 @@ METAL_FUNC void load_from_global(
         }
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 #define reduce_threadgroup(SIZE) \
 if (BLOCKSIZE >= SIZE) { \
     if (block_dim >= SIZE) { \
         shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 }
 
 template
@@ -196,8 +196,8 @@ if (BLOCKSIZE >= SIZE) { \
     ) { \
         shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
         shared[tid] = shared[tid + SIZE / 2]; \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 }
 
 template
@@ -221,8 +221,8 @@ METAL_FUNC void threadgroup_reduce(
 if (BLOCKSIZE >= SIZE) { \
     if (tid < SIZE / 2 && block_dim >= SIZE) { \
         shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 } \
 
 // Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
@@ -282,8 +282,8 @@ METAL_FUNC void block_reduce(
 
     if (tid < 32) {
         threadgroup_reduce(shared, tid, block_dim);
-        threadgroup_barrier(mem_flags::mem_none);
     }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     if (tid == 0) {
         dst[dst_id] = shared[tid];
     }
@@ -358,8 +358,8 @@ if (BLOCKSIZE >= SIZE) { \
     ) { \
         shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
         shared[tid] = shared[tid + SIZE / 2]; \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 } \
 
 template<
@@ -420,7 +420,7 @@ METAL_FUNC void arg_block_reduce(
 
     if (tid < 32) {
         threadgroup_reduce(shared, shared_indices, tid, block_dim);
-        threadgroup_barrier(mem_flags::mem_none);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
     if (tid == 0) {
@@ -491,71 +491,121 @@ kernel void NAME##_strided( \
 
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
 #define MIN(x, y) ((x) < (y) ? (x) : (y))
-#define SOFTMAX(NAME, T) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    shared_memory[tid] = -INFINITY; \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    \
-    \
-    float tmp = -INFINITY; \
-    while (idx < stop_idx) { \
-        tmp = MAX(tmp, float(src[idx])); \
-        idx += block_dim; \
-    } \
-    shared_memory[tid] = tmp; \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    /* wait for shared_memory[0] to be filled */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    float _max = shared_memory[0]; \
-    \
-    /* prevent tid=0 from overwriting _max before other threads have written it */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    shared_memory[tid] = 0; \
-    \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        const float val = exp(float(src[idx]) - _max); \
-        dst[idx] = T(val); \
-        shared_memory[tid] += val; \
-        idx += block_dim; \
-    } \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] += shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    const T inv_acc = T(1.0/shared_memory[0]); \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        dst[idx] *= inv_acc; \
-        idx += block_dim; \
-    } \
+
+#define softmax_max_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] = max_op(shared[tid], shared[tid + SIZE / 2]); \
+    } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+}
+
+#define softmax_acc_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] += shared[tid + SIZE / 2]; \
+    } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+}
+
+template<
+    typename T,
+    typename ACC,
+    uint BLOCKSIZE
+>
+METAL_FUNC void softmax(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    threadgroup ACC shared[BLOCKSIZE],
+
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    Max<ACC> max_op;
+
+    shared[tid] = numeric_limits<ACC>::min();
+    ACC tmp = numeric_limits<ACC>::min();
+
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    while (idx < stop_idx) {
+        tmp = max_op(tmp, static_cast<ACC>(src[idx]));
+        idx += block_dim;
+    }
+    shared[tid] = tmp;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    softmax_max_block(1024);
+    softmax_max_block(512);
+    softmax_max_block(256);
+    softmax_max_block(128);
+    if (tid < 32) {
+        threadgroup_reduce<ACC, Max<ACC>, BLOCKSIZE>(shared, tid, block_dim);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    ACC _max = shared[0];
+
+    // prevent tid 0 from overwriting _max before other threads have written
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    shared[tid] = 0;
+
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        const ACC val = exp(static_cast<ACC>(src[idx]) - _max);
+        dst[idx] = static_cast<T>(val);
+        shared[tid] += val;
+
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    softmax_acc_block(1024);
+    softmax_acc_block(512);
+    softmax_acc_block(256);
+    softmax_acc_block(128);
+    if (tid < 32) {
+        threadgroup_reduce<ACC, Sum<ACC>, BLOCKSIZE>(shared, tid, block_dim);
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    const T inv_acc = T(1.0/shared[0]);
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        dst[idx] *= inv_acc;
+        idx += block_dim;
+    }
+}
+
+
+#define SOFTMAX(NAME, T, ACC) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup ACC shared_memory[BLOCKSIZE]; \
+    softmax<T, ACC, BLOCKSIZE>( \
+        src_numel, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        shared_memory, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
 }
 
 REDUCE(Sum, fast_sum_f32, float)
@@ -588,8 +638,8 @@ ARG_REDUCE(ArgMax, fast_argmax_f16, half)
 ARG_REDUCE(ArgMax, fast_argmax_u32, uint)
 ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t)
 
-SOFTMAX(softmax_f32, float)
-SOFTMAX(softmax_f16, half)
+SOFTMAX(softmax_f32, float, float)
+SOFTMAX(softmax_f16, half, float)
 
 #if __METAL_VERSION__ >= 220
 REDUCE(Sum, fast_sum_i64, int64_t)
@@ -611,5 +661,5 @@ REDUCE(Min, fast_min_bf16, bfloat)
 ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat)
 ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat)
 
-SOFTMAX(softmax_bf16, bfloat)
+SOFTMAX(softmax_bf16, bfloat, float)
 #endif
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 5d41977f..c16552f7 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -529,7 +529,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
         Err(e) => {
             println!("Error: {}", e);
             panic!("damn!");
-        },
+        }
     }
 
     read_to_vec(&output, out_length)
@@ -597,7 +597,6 @@ fn softmax() {
     }
     let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
-    println!("{results:?}");
     assert_eq!(
         results.iter().map(|&s| s.round() as usize).sum::<usize>(),
         n