Improve softmax kernel. 33%-39% higher throughput

Ivar Flakstad
2024-01-22 18:25:52 +01:00
parent 1f4c54493e
commit 2056866c25
4 changed files with 248 additions and 93 deletions


@@ -1,6 +1,8 @@
 use crate::benchmarks::{bench_name, device, BenchDevice};
-use candle_core::{DType, Device, Tensor};
+use candle_core::{DType, Device, Storage, Tensor};
 use criterion::{black_box, criterion_group, Criterion, Throughput};
+use half::{bf16, f16};
+use std::ops::Deref;
 use std::time::Instant;
 fn run_sum(a: &Tensor) {
@@ -10,21 +12,114 @@ fn run_arg_min(a: &Tensor) {
     a.argmin(2).unwrap();
 }
+fn softmax(a: &Tensor) -> candle_core::Result<()> {
+    use candle_core::{backend::BackendStorage, DType};
+    let (storage, layout) = a.storage_and_layout();
+    let device = a.device();
+    if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match a.dtype() {
+            DType::F32 => "softmax_f32",
+            DType::F16 => "softmax_f16",
+            DType::BF16 => "softmax_bf16",
+            dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+        let n = layout.stride().len();
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
+            candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            &output,
+        )
+        .unwrap();
+    }
+    Ok(())
+}
 fn criterion_benchmark(c: &mut Criterion) {
     let device = device().unwrap();
-    run_reduce(c, &device);
-    run_arg_reduce(c, &device);
+    let (lo, up) = (-1000.0f32, 1000.0f32);
+    run_softmax(c, &device, (lo, up));
+    run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+    run_reduce(c, &device, (lo, up));
+    run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+    run_arg_reduce(c, &device, (lo, up));
+    run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
 }
-fn run_reduce(c: &mut Criterion, device: &Device) {
+fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
+    if !device.is_metal() {
+        return;
+    }
+    let b = 1;
+    let m = 2048;
+    let k = 2048;
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
+    let name = match T::DTYPE {
+        DType::F32 => "softmax_f32",
+        DType::F16 => "softmax_f16",
+        DType::BF16 => "softmax_bf16",
+        _ => "softmax",
+    };
+    let mut group = c.benchmark_group(bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                softmax(black_box(&a)).unwrap();
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
     let b = 1;
     let m = 2048;
     let k = 2048;
-    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
-    let flops = b * m * k * DType::F32.size_in_bytes();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
-    let mut group = c.benchmark_group(bench_name("reduce"));
+    let name = match T::DTYPE {
+        DType::F32 => "reduce_f32",
+        DType::F16 => "reduce_f16",
+        DType::BF16 => "reduce_bf16",
+        _ => "reduce",
+    };
+    let mut group = c.benchmark_group(bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
@@ -39,16 +134,27 @@ fn run_reduce(c: &mut Criterion, device: &Device) {
     group.finish();
 }
-fn run_arg_reduce(c: &mut Criterion, device: &Device) {
+fn run_arg_reduce<T: candle_core::FloatDType>(
+    c: &mut Criterion,
+    device: &Device,
+    (lo, up): (T, T),
+) {
     let b = 1;
     let m = 2048;
     let k = 2048;
-    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
-    let flops = b * m * k * DType::F32.size_in_bytes();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
-    let mut group = c.benchmark_group(bench_name("arg_reduce"));
+    let name = match T::DTYPE {
+        DType::F32 => "arg_reduce_f32",
+        DType::F16 => "arg_reduce_f16",
+        DType::BF16 => "arg_reduce_bf16",
+        _ => "reduce",
+    };
+    let mut group = c.benchmark_group(bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {

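Note on the throughput figures: Throughput::Bytes above reports bytes touched per iteration, i.e. b * m * k elements times the element width, so the f16/bf16 runs move half the bytes of the f32 run. A standalone sketch of that accounting, using the benchmark's shapes (not part of the commit):

fn main() {
    let (b, m, k) = (1usize, 2048, 2048);
    // DType::F32.size_in_bytes() == 4; f16 and bf16 are two bytes wide.
    let f32_bytes = b * m * k * 4;
    let half_bytes = b * m * k * 2;
    assert_eq!(f32_bytes, 16 << 20); // 16 MiB read per softmax call on f32
    assert_eq!(half_bytes, 8 << 20); // 8 MiB for the half-precision runs
}

The elapsed time inside iter_custom is read only after device.sync(), so the queued Metal work has actually drained before the clock stops.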

@@ -545,8 +545,8 @@ impl BackendStorage for MetalStorage {
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
-        let buffer = device.new_buffer(1, self.dtype, "reduce")?;
+        let dtype = if return_index { DType::U32 } else { self.dtype };
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,

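Note: the hunk above fixes the reduction's output allocation. The destination holds one element per reduced block (dst_el of them, not a single element), and arg-reductions store u32 positions rather than values of the source dtype. A small illustration of the sizing, with shapes borrowed from the benchmarks (hypothetical numbers, not part of the commit):

fn main() {
    // Reduce the last dim of a (1, 2048, 2048) tensor: one output per row.
    let elem_count = 1 * 2048 * 2048;
    let el_to_sum_per_block = 2048;
    let dst_el = elem_count / el_to_sum_per_block;
    assert_eq!(dst_el, 2048);
    // For argmin/argmax the buffer dtype is u32 (indices); for
    // min/max/sum it keeps the source dtype, as in the return_index branch.
}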

@@ -74,7 +74,7 @@ METAL_FUNC void load_from_global(
         shared[tid] = op(shared[tid], src[strided_i]);
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 // Load strided elements from global memory into shared memory with indices.
@@ -107,7 +107,7 @@ METAL_FUNC void load_from_global(
         }
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 // Load contiguous elements from global memory into shared memory.
@@ -131,7 +131,7 @@ METAL_FUNC void load_from_global(
         shared[tid] = op(shared[tid], src[idx]);
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 // Load contiguous elements from global memory into shared memory with indices.
@@ -162,15 +162,15 @@ METAL_FUNC void load_from_global(
         }
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 #define reduce_threadgroup(SIZE) \
 if (BLOCKSIZE >= SIZE) { \
     if (block_dim >= SIZE) { \
         shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 }
 template<typename T, typename ReductionOp, uint BLOCKSIZE>
@@ -196,8 +196,8 @@ if (BLOCKSIZE >= SIZE) { \
 ) { \
         shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
         shared[tid] = shared[tid + SIZE / 2]; \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 }
 template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
@@ -221,8 +221,8 @@ METAL_FUNC void threadgroup_reduce(
 if (BLOCKSIZE >= SIZE) { \
     if (tid < SIZE / 2 && block_dim >= SIZE) { \
         shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 } \
 // Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
@@ -282,8 +282,8 @@ METAL_FUNC void block_reduce(
     if (tid < 32) {
         threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
-        threadgroup_barrier(mem_flags::mem_none);
     }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     if (tid == 0) {
         dst[dst_id] = shared[tid];
     }
@@ -358,8 +358,8 @@ if (BLOCKSIZE >= SIZE) { \
 ) { \
         shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
         shared[tid] = shared[tid + SIZE / 2]; \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 } \
 template<
@@ -420,7 +420,7 @@ METAL_FUNC void arg_block_reduce(
     if (tid < 32) {
         threadgroup_reduce<T, ArgReductionOp, BLOCKSIZE>(shared, shared_indices, tid, block_dim);
-        threadgroup_barrier(mem_flags::mem_none);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
     if (tid == 0) {
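Note: the recurring change in the hunks above is the substance of the fix. mem_flags::mem_none is an execution-only barrier with no memory fence, so writes to threadgroup memory were not guaranteed visible across reduction steps; and a barrier placed inside a divergent branch (if (tid < ...)) is not reached by every thread, which Metal leaves undefined. The commit therefore switches to mem_flags::mem_threadgroup and hoists the barrier out of the conditional. A CPU-side Rust sketch of the tree reduction these macros unroll, with the barrier point marked (illustrative only; the per-thread work is replaced by a loop):

fn main() {
    // Scratch buffer standing in for `threadgroup ACC shared[BLOCKSIZE]`.
    let mut shared = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
    let mut s = shared.len() / 2;
    while s > 0 {
        // On the GPU, each thread with tid < s does one fold in parallel...
        for tid in 0..s {
            shared[tid] = shared[tid].max(shared[tid + s]);
        }
        // ...then *all* threads execute threadgroup_barrier(mem_threadgroup)
        // here, so the folded values are visible to the next step.
        s /= 2;
    }
    assert_eq!(shared[0], 9.0);
}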
@@ -491,71 +491,121 @@ kernel void NAME##_strided( \
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
 #define MIN(x, y) ((x) < (y) ? (x) : (y))
-#define SOFTMAX(NAME, T) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    shared_memory[tid] = -INFINITY; \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    \
-    \
-    float tmp = -INFINITY; \
-    while (idx < stop_idx) { \
-        tmp = MAX(tmp, float(src[idx])); \
-        idx += block_dim; \
-    } \
-    shared_memory[tid] = tmp; \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    /* wait for shared_memory[0] to be filled */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    float _max = shared_memory[0]; \
-    \
-    /* prevent tid=0 from overwriting _max before other threads have written it */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    shared_memory[tid] = 0; \
-    \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        const float val = exp(float(src[idx]) - _max); \
-        dst[idx] = T(val); \
-        shared_memory[tid] += val; \
-        idx += block_dim; \
-    } \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] += shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    const T inv_acc = T(1.0/shared_memory[0]); \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        dst[idx] *= inv_acc; \
-        idx += block_dim; \
-    } \
-}
+#define softmax_max_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] = max_op(shared[tid], shared[tid + SIZE / 2]); \
+    } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+}
+#define softmax_acc_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] += shared[tid + SIZE / 2]; \
+    } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+}
+template<
+    typename T,
+    typename ACC,
+    uint BLOCKSIZE
+>
+METAL_FUNC void softmax(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    threadgroup ACC shared[BLOCKSIZE],
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    Max<ACC> max_op;
+    shared[tid] = numeric_limits<ACC>::min();
+    ACC tmp = numeric_limits<ACC>::min();
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        tmp = max_op(tmp, static_cast<ACC>(src[idx]));
+        idx += block_dim;
+    }
+    shared[tid] = tmp;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    softmax_max_block(1024);
+    softmax_max_block(512);
+    softmax_max_block(256);
+    softmax_max_block(128);
+    if (tid < 32) {
+        threadgroup_reduce<ACC, Max<ACC>, BLOCKSIZE>(shared, tid, block_dim);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    ACC _max = shared[0];
+    // prevent tid 0 from overwriting _max before other threads have written
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    shared[tid] = 0;
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        const ACC val = exp(static_cast<ACC>(src[idx]) - _max);
+        dst[idx] = static_cast<T>(val);
+        shared[tid] += val;
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    softmax_acc_block(1024);
+    softmax_acc_block(512);
+    softmax_acc_block(256);
+    softmax_acc_block(128);
+    if (tid < 32) {
+        threadgroup_reduce<ACC, Sum<ACC>, BLOCKSIZE>(shared, tid, block_dim);
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+    const T inv_acc = T(1.0/shared[0]);
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        dst[idx] *= inv_acc;
+        idx += block_dim;
+    }
+}
+#define SOFTMAX(NAME, T, ACC) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup ACC shared_memory[BLOCKSIZE]; \
+    softmax<T, ACC, BLOCKSIZE>( \
+        src_numel, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        shared_memory, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
+}
 REDUCE(Sum, fast_sum_f32, float)
@@ -588,8 +638,8 @@ ARG_REDUCE(ArgMax, fast_argmax_f16, half)
 ARG_REDUCE(ArgMax, fast_argmax_u32, uint)
 ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t)
-SOFTMAX(softmax_f32, float)
-SOFTMAX(softmax_f16, half)
+SOFTMAX(softmax_f32, float, float)
+SOFTMAX(softmax_f16, half, float)
 #if __METAL_VERSION__ >= 220
 REDUCE(Sum, fast_sum_i64, int64_t)
@@ -611,5 +661,5 @@ REDUCE(Min, fast_min_bf16, bfloat)
 ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat)
 ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat)
-SOFTMAX(softmax_bf16, bfloat)
+SOFTMAX(softmax_bf16, bfloat, float)
 #endif

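Note: the rewritten kernel still computes softmax over the last dimension in three passes per row (block max, exponentiate-and-accumulate, scale by the reciprocal of the sum), now accumulating in ACC = float even for f16/bf16 inputs. A scalar Rust sketch of the same three passes (a reference for reading the kernel, not part of the commit):

fn softmax_last_dim(data: &mut [f32], last_dim: usize) {
    for row in data.chunks_mut(last_dim) {
        // Pass 1: row maximum, subtracted for numerical stability.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        // Pass 2: exponentiate the shifted values and accumulate the sum.
        let mut acc = 0.0;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            acc += *v;
        }
        // Pass 3: scale by the reciprocal, so each row sums to 1.
        let inv = 1.0 / acc;
        for v in row.iter_mut() {
            *v *= inv;
        }
    }
}

fn main() {
    let mut data = vec![1.0f32, 2.0, 3.0, 4.0, -1.0, 0.0, 1.0, 2.0];
    softmax_last_dim(&mut data, 4);
    let row_sum: f32 = data[..4].iter().sum();
    assert!((row_sum - 1.0).abs() < 1e-6);
}

Subtracting the row max first keeps exp from overflowing; the kernel's float accumulator serves the same purpose for half-precision inputs.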

@@ -529,7 +529,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
         Err(e) => {
             println!("Error: {}", e);
             panic!("damn!");
-        },
+        }
     }
     read_to_vec(&output, out_length)
@@ -597,7 +597,6 @@ fn softmax() {
     }
     let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
-    println!("{results:?}");
     assert_eq!(
         results.iter().map(|&s| s.round() as usize).sum::<usize>(),
         n