Improve softmax kernel. 33%-39% higher throughput.

This commit rewrites the Metal SOFTMAX macro into a templated softmax function that accumulates in a separate ACC type (float, even for the half and bfloat variants), replaces the loop-based threadgroup reductions with unrolled softmax_max_block/softmax_acc_block steps plus a warp-level threadgroup_reduce, fixes barriers that used mem_flags::mem_none where threadgroup memory requires mem_flags::mem_threadgroup, sizes the reduce output buffer correctly, and extends the criterion benchmarks to cover softmax, reduce, and arg-reduce in f32, f16, and bf16.
@@ -1,6 +1,8 @@
 use crate::benchmarks::{bench_name, device, BenchDevice};
-use candle_core::{DType, Device, Tensor};
+use candle_core::{DType, Device, Storage, Tensor};
 use criterion::{black_box, criterion_group, Criterion, Throughput};
+use half::{bf16, f16};
+use std::ops::Deref;
 use std::time::Instant;
 
 fn run_sum(a: &Tensor) {
@@ -10,21 +12,114 @@ fn run_arg_min(a: &Tensor) {
     a.argmin(2).unwrap();
 }
 
+fn softmax(a: &Tensor) -> candle_core::Result<()> {
+    use candle_core::{backend::BackendStorage, DType};
+    let (storage, layout) = a.storage_and_layout();
+
+    let device = a.device();
+
+    if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
+        let command_buffer = device.command_buffer()?;
+        let kernels = device.kernels();
+        let name = match a.dtype() {
+            DType::F32 => "softmax_f32",
+            DType::F16 => "softmax_f16",
+            DType::BF16 => "softmax_bf16",
+            dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+
+        let n = layout.stride().len();
+        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
+            candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            &output,
+        )
+        .unwrap();
+    }
+    Ok(())
+}
+
 fn criterion_benchmark(c: &mut Criterion) {
     let device = device().unwrap();
-    run_reduce(c, &device);
-    run_arg_reduce(c, &device);
+    let (lo, up) = (-1000.0f32, 1000.0f32);
+    run_softmax(c, &device, (lo, up));
+    run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+
+    run_reduce(c, &device, (lo, up));
+    run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+
+    run_arg_reduce(c, &device, (lo, up));
+    run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
+    run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
 }
 
-fn run_reduce(c: &mut Criterion, device: &Device) {
+fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
+    if !device.is_metal() {
+        return;
+    }
+
+    let b = 1;
+    let m = 2048;
+    let k = 2048;
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
+
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
+
+    let name = match T::DTYPE {
+        DType::F32 => "softmax_f32",
+        DType::F16 => "softmax_f16",
+        DType::BF16 => "softmax_bf16",
+        _ => "softmax",
+    };
+
+    let mut group = c.benchmark_group(bench_name(name));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                softmax(black_box(&a)).unwrap();
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
     let b = 1;
     let m = 2048;
     let k = 2048;
 
-    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
 
-    let flops = b * m * k * DType::F32.size_in_bytes();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
 
-    let mut group = c.benchmark_group(bench_name("reduce"));
+    let name = match T::DTYPE {
+        DType::F32 => "reduce_f32",
+        DType::F16 => "reduce_f16",
+        DType::BF16 => "reduce_bf16",
+        _ => "reduce",
+    };
+
+    let mut group = c.benchmark_group(bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
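A note on the numbers in the commit title: Throughput::Bytes is fed the input byte volume per call (the variable is named flops, but it holds bytes), so criterion reports bytes per second and the 33%-39% gain is that figure improving. A quick sanity check of the volume, plain arithmetic mirroring the benchmark (an illustrative snippet, not part of the commit):

    fn main() {
        let (b, m, k) = (1u64, 2048, 2048);
        // f32 input volume per softmax call, as in run_softmax with T = f32.
        let bytes_per_iter = b * m * k * 4;
        assert_eq!(bytes_per_iter, 16_777_216); // 16 MiB read per iteration
    }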
@@ -39,16 +134,27 @@ fn run_reduce(c: &mut Criterion, device: &Device) {
     group.finish();
 }
 
-fn run_arg_reduce(c: &mut Criterion, device: &Device) {
+fn run_arg_reduce<T: candle_core::FloatDType>(
+    c: &mut Criterion,
+    device: &Device,
+    (lo, up): (T, T),
+) {
     let b = 1;
     let m = 2048;
     let k = 2048;
 
-    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
 
-    let flops = b * m * k * DType::F32.size_in_bytes();
+    let flops = b * m * k * T::DTYPE.size_in_bytes();
 
-    let mut group = c.benchmark_group(bench_name("arg_reduce"));
+    let name = match T::DTYPE {
+        DType::F32 => "arg_reduce_f32",
+        DType::F16 => "arg_reduce_f16",
+        DType::BF16 => "arg_reduce_bf16",
+        _ => "reduce",
+    };
+
+    let mut group = c.benchmark_group(bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
     group.bench_function("iter", move |b| {
         b.iter_custom(|iters| {
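The new fn softmax helper drives candle_metal_kernels::call_last_softmax directly, so the benchmark measures the kernel alone rather than the whole op dispatch. To compare the Metal path against the CPU path through candle's public API, a minimal sketch (assumes a Metal-capable machine, a candle_nn dependency, and that candle_nn::ops::softmax_last_dim dispatches to this kernel on Metal devices):

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let cpu = Device::Cpu;
        let metal = Device::new_metal(0)?;

        let a = Tensor::rand(-1000.0f32, 1000.0, (1, 32, 128), &cpu)?;
        let want = candle_nn::ops::softmax_last_dim(&a)?;
        let got = candle_nn::ops::softmax_last_dim(&a.to_device(&metal)?)?.to_device(&cpu)?;

        // Both backends should agree to within float tolerance.
        let diff = (want - got)?.abs()?.flatten_all()?.max(0)?.to_scalar::<f32>()?;
        assert!(diff < 1e-5);
        Ok(())
    }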
@@ -545,8 +545,8 @@ impl BackendStorage for MetalStorage {
         if check_empty && layout.shape().elem_count() == 0 {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
-
-        let buffer = device.new_buffer(1, self.dtype, "reduce")?;
+        let dtype = if return_index { DType::U32 } else { self.dtype };
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
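This fixes the output allocation in the Metal reduce path: the old code allocated a single element of the input dtype, but a reduce over the last dimension of a (b, m, k) tensor needs b * m output elements, and argmin/argmax write u32 indices rather than the input dtype. A hedged sketch of what the fixed path should produce through the public API (assumes a Metal device is available):

    use candle_core::{DType, Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let device = Device::new_metal(0)?;
        let a = Tensor::rand(-1.0f32, 1.0, (1, 3, 4), &device)?;

        // One u32 index per reduced row: shape (1, 3), not a single element.
        let idx = a.argmin(2)?;
        assert_eq!(idx.dims(), &[1, 3]);
        assert_eq!(idx.dtype(), DType::U32);
        Ok(())
    }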
@@ -74,7 +74,7 @@ METAL_FUNC void load_from_global(
         shared[tid] = op(shared[tid], src[strided_i]);
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 // Load strided elements from global memory into shared memory with indices.
@@ -107,7 +107,7 @@ METAL_FUNC void load_from_global(
         }
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 // Load contiguous elements from global memory into shared memory.
@@ -131,7 +131,7 @@ METAL_FUNC void load_from_global(
         shared[tid] = op(shared[tid], src[idx]);
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 // Load contiguous elements from global memory into shared memory with indices.
@@ -162,15 +162,15 @@ METAL_FUNC void load_from_global(
         }
         idx += block_dim;
     }
-    threadgroup_barrier(mem_flags::mem_none);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 }
 
 #define reduce_threadgroup(SIZE) \
 if (BLOCKSIZE >= SIZE) { \
     if (block_dim >= SIZE) { \
         shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 }
 
 template<typename T, typename ReductionOp, uint BLOCKSIZE>
@@ -196,8 +196,8 @@ if (BLOCKSIZE >= SIZE) { \
     ) { \
         shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
         shared[tid] = shared[tid + SIZE / 2]; \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 }
 
 template<typename T, typename ArgReductionOp, uint BLOCKSIZE>
@@ -221,8 +221,8 @@ METAL_FUNC void threadgroup_reduce(
 if (BLOCKSIZE >= SIZE) { \
     if (tid < SIZE / 2 && block_dim >= SIZE) { \
         shared[tid] = op(shared[tid], shared[tid + SIZE / 2]); \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 } \
 
 // Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
@@ -282,8 +282,8 @@ METAL_FUNC void block_reduce(
 
     if (tid < 32) {
         threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
-        threadgroup_barrier(mem_flags::mem_none);
     }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     if (tid == 0) {
         dst[dst_id] = shared[tid];
     }
@@ -358,8 +358,8 @@ if (BLOCKSIZE >= SIZE) { \
     ) { \
         shared_indices[tid] = shared_indices[tid + SIZE / 2]; \
         shared[tid] = shared[tid + SIZE / 2]; \
-        threadgroup_barrier(mem_flags::mem_none); \
     } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
 } \
 
 template<
@@ -420,7 +420,7 @@ METAL_FUNC void arg_block_reduce(
 
     if (tid < 32) {
         threadgroup_reduce<T, ArgReductionOp, BLOCKSIZE>(shared, shared_indices, tid, block_dim);
-        threadgroup_barrier(mem_flags::mem_none);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
     if (tid == 0) {
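A note on the barrier changes above: in the Metal Shading Language, threadgroup_barrier(mem_flags::mem_none) only synchronizes execution; it does not make writes to threadgroup (shared) memory visible to the other threads in the group. Since all of these reductions communicate through threadgroup arrays, mem_flags::mem_threadgroup is the flag these barriers need. The tree-reduction macros also move the barrier outside the conditional, so that every thread in the group reaches it rather than only the threads that took the branch, which is what barrier semantics require.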
@@ -491,71 +491,121 @@ kernel void NAME##_strided( \
 #define MAX(x, y) ((x) > (y) ? (x) : (y))
 #define MIN(x, y) ((x) < (y) ? (x) : (y))
 
-#define SOFTMAX(NAME, T) \
-kernel void NAME( \
-    constant size_t &src_numel, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    threadgroup float shared_memory[THREADGROUP_SIZE]; \
-    shared_memory[tid] = -INFINITY; \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
-    size_t idx = start_idx + tid; \
-    \
-    \
-    float tmp = -INFINITY; \
-    while (idx < stop_idx) { \
-        tmp = MAX(tmp, float(src[idx])); \
-        idx += block_dim; \
-    } \
-    shared_memory[tid] = tmp; \
-    \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    /* wait for shared_memory[0] to be filled */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    \
-    float _max = shared_memory[0]; \
-    \
-    /* prevent tid=0 from overwriting _max before other threads have written it */ \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    shared_memory[tid] = 0; \
-    \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        const float val = exp(float(src[idx]) - _max); \
-        dst[idx] = T(val); \
-        shared_memory[tid] += val; \
-        idx += block_dim; \
-    } \
-    threadgroup_barrier(mem_flags::mem_threadgroup); \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            shared_memory[tid] += shared_memory[tid + s]; \
-        } \
-        threadgroup_barrier(mem_flags::mem_threadgroup); \
-    } \
-    \
-    const T inv_acc = T(1.0/shared_memory[0]); \
-    idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        dst[idx] *= inv_acc; \
-        idx += block_dim; \
-    } \
+#define softmax_max_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] = max_op(shared[tid], shared[tid + SIZE / 2]); \
+    } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+}
+
+#define softmax_acc_block(SIZE) \
+if (BLOCKSIZE >= SIZE) { \
+    if (tid < SIZE / 2 && block_dim >= SIZE) { \
+        shared[tid] += shared[tid + SIZE / 2]; \
+    } \
+    threadgroup_barrier(mem_flags::mem_threadgroup); \
+}
+
+template<
+    typename T,
+    typename ACC,
+    uint BLOCKSIZE
+>
+METAL_FUNC void softmax(
+    constant size_t &src_numel,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    threadgroup ACC shared[BLOCKSIZE],
+
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    Max<ACC> max_op;
+
+    shared[tid] = numeric_limits<ACC>::min();
+    ACC tmp = numeric_limits<ACC>::min();
+
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+    size_t idx = start_idx + tid;
+
+    while (idx < stop_idx) {
+        tmp = max_op(tmp, static_cast<ACC>(src[idx]));
+        idx += block_dim;
+    }
+    shared[tid] = tmp;
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    softmax_max_block(1024);
+    softmax_max_block(512);
+    softmax_max_block(256);
+    softmax_max_block(128);
+    if (tid < 32) {
+        threadgroup_reduce<ACC, Max<ACC>, BLOCKSIZE>(shared, tid, block_dim);
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    ACC _max = shared[0];
+
+    // prevent tid 0 from overwriting _max before other threads have written
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    shared[tid] = 0;
+
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        const ACC val = exp(static_cast<ACC>(src[idx]) - _max);
+        dst[idx] = static_cast<T>(val);
+        shared[tid] += val;
+
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    softmax_acc_block(1024);
+    softmax_acc_block(512);
+    softmax_acc_block(256);
+    softmax_acc_block(128);
+    if (tid < 32) {
+        threadgroup_reduce<ACC, Sum<ACC>, BLOCKSIZE>(shared, tid, block_dim);
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    const T inv_acc = T(1.0/shared[0]);
+    idx = start_idx + tid;
+    while (idx < stop_idx) {
+        dst[idx] *= inv_acc;
+        idx += block_dim;
+    }
+}
+
+
+#define SOFTMAX(NAME, T, ACC) \
+kernel void NAME( \
+    constant size_t &src_numel, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup ACC shared_memory[BLOCKSIZE]; \
+    softmax<T, ACC, BLOCKSIZE>( \
+        src_numel, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        shared_memory, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
 }
 
 REDUCE(Sum, fast_sum_f32, float)
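The new kernel is still the standard three-pass, numerically stable softmax, just with unrolled block reductions and an explicit accumulation type. A plain-Rust reference of the same per-row computation may make the passes easier to follow (a sketch for illustration; softmax_last_dim_ref is a hypothetical helper, not part of the commit):

    /// Numerically stable softmax over the last dimension of a row-major
    /// [rows, last_dim] buffer, mirroring the kernel's three passes.
    fn softmax_last_dim_ref(src: &[f32], last_dim: usize) -> Vec<f32> {
        let mut dst = vec![0f32; src.len()];
        for (row, out) in src.chunks(last_dim).zip(dst.chunks_mut(last_dim)) {
            // Pass 1: row maximum (the kernel's softmax_max_block reduction).
            let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            // Pass 2: exponentiate shifted values and accumulate the sum
            // (softmax_acc_block), writing the exponentials to dst.
            let mut acc = 0f32;
            for (o, &x) in out.iter_mut().zip(row) {
                *o = (x - max).exp();
                acc += *o;
            }
            // Pass 3: scale by 1/sum, as the kernel does with inv_acc.
            let inv_acc = 1.0 / acc;
            out.iter_mut().for_each(|o| *o *= inv_acc);
        }
        dst
    }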
@@ -588,8 +638,8 @@ ARG_REDUCE(ArgMax, fast_argmax_f16, half)
 ARG_REDUCE(ArgMax, fast_argmax_u32, uint)
 ARG_REDUCE(ArgMax, fast_argmax_u8, uint8_t)
 
-SOFTMAX(softmax_f32, float)
-SOFTMAX(softmax_f16, half)
+SOFTMAX(softmax_f32, float, float)
+SOFTMAX(softmax_f16, half, float)
 
 #if __METAL_VERSION__ >= 220
 REDUCE(Sum, fast_sum_i64, int64_t)
@@ -611,5 +661,5 @@ REDUCE(Min, fast_min_bf16, bfloat)
 ARG_REDUCE(ArgMin, fast_argmin_bf16, bfloat)
 ARG_REDUCE(ArgMax, fast_argmax_bf16, bfloat)
 
-SOFTMAX(softmax_bf16, bfloat)
+SOFTMAX(softmax_bf16, bfloat, float)
 #endif
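The new third macro argument is the accumulation type: every variant, including the half and bfloat kernels, now accumulates its max and sum in float. A small Rust illustration of why that matters, comparing accumulation in f16 precision against f32 (uses the half crate already pulled in by the benchmarks; illustrative values, not from the commit):

    use half::f16;

    fn main() {
        // 4096 values of ~0.01; the true sum is ~40.96.
        let xs = vec![f16::from_f32(0.01); 4096];

        // Accumulate in f16 precision: round after every add.
        let mut acc16 = f16::from_f32(0.0);
        for x in &xs {
            acc16 = f16::from_f32(acc16.to_f32() + x.to_f32());
        }
        // Accumulate in f32.
        let acc32: f32 = xs.iter().map(|x| x.to_f32()).sum();

        // The f16 accumulator stalls once 0.01 falls below half an ulp of
        // the running sum (around 32.0); the f32 one stays accurate.
        println!("f16 accumulation: {}", acc16.to_f32());
        println!("f32 accumulation: {acc32}");
    }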
@@ -529,7 +529,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
         Err(e) => {
             println!("Error: {}", e);
             panic!("damn!");
-        },
+        }
     }
 
     read_to_vec(&output, out_length)
@@ -597,7 +597,6 @@ fn softmax() {
     }
     let results = run_softmax(&v, last_dim, "softmax_f32");
     let results = approx(results, 4);
-    println!("{results:?}");
     assert_eq!(
         results.iter().map(|&s| s.round() as usize).sum::<usize>(),
         n