mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)

Compare commits: 6 commits on 0.9.0-alph... (author: ivarflakst)

SHA1:
8babfe0411
077e781f53
086b6ef6b6
2056866c25
1f4c54493e
d5902840e0
@@ -1,9 +1,11 @@
mod benchmarks;

use criterion::criterion_main;

criterion_main!(
    benchmarks::affine::benches,
    benchmarks::matmul::benches,
    benchmarks::random::benches,
    benchmarks::where_cond::benches
    //benchmarks::affine::benches,
    //benchmarks::matmul::benches,
    //benchmarks::random::benches,
    benchmarks::reduce::benches,
    //benchmarks::where_cond::benches
);
@@ -1,6 +1,7 @@
pub(crate) mod affine;
pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod reduce;
pub(crate) mod where_cond;

use candle_core::{Device, Result};
candle-core/benches/benchmarks/reduce.rs (new file, 239 lines)
@@ -0,0 +1,239 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Storage, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use half::{bf16, f16};
use std::ops::Deref;
use std::time::Instant;

fn run_sum(a: &Tensor) {
    a.sum(2).unwrap();
}
fn run_arg_min(a: &Tensor) {
    a.argmin(2).unwrap();
}

// TODO: Remove before merging. Softmax impls live in candle-nn, so this is a temporary workaround.
fn softmax(a: &Tensor) -> candle_core::Result<()> {
    use candle_core::{backend::BackendStorage, DType};
    let (storage, layout) = a.storage_and_layout();

    let device = a.device();

    if let (Device::Metal(device), Storage::Metal(storage)) = (device, storage.deref()) {
        let command_buffer = device.command_buffer()?;
        let kernels = device.kernels();
        let name = match a.dtype() {
            DType::F32 => "softmax_f32",
            DType::F16 => "softmax_f16",
            DType::BF16 => "softmax_bf16",
            dtype => candle_core::bail!("softmax-last-dim is not implemented for {dtype:?}"),
        };

        let n = layout.stride().len();
        if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
            candle_core::bail!("Non contiguous softmax-last-dim is not implemented");
        }

        let last_dim = layout.dims()[layout.shape().rank() - 1];
        let elem_count = layout.shape().elem_count();
        let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
        candle_metal_kernels::call_last_softmax(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            elem_count,
            last_dim,
            storage.buffer(),
            layout.start_offset() * storage.dtype().size_in_bytes(),
            &output,
        )
        .unwrap();
    }
    Ok(())
}
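The function above is a stop-gap that drives the Metal softmax kernel directly rather than going through candle-nn, as the TODO notes. As a point of reference for what that kernel is expected to compute, here is a minimal CPU sketch of a numerically stable softmax over the last dimension; it is illustrative only and not part of the benchmark or of candle's API.

// Reference sketch (not in the diff): numerically stable softmax over the last dim.
fn softmax_last_dim_ref(rows: &[Vec<f32>]) -> Vec<Vec<f32>> {
    rows.iter()
        .map(|row| {
            // Subtract the row max before exponentiating to avoid overflow.
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = row.iter().map(|&x| (x - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            exps.iter().map(|&e| e / sum).collect()
        })
        .collect()
}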

fn criterion_benchmark(c: &mut Criterion) {
    let handler = BenchDeviceHandler::new().unwrap();
    let (lo, up) = (-1000.0f32, 1000.0f32);
    for device in handler.devices {
        run_softmax(c, &device, (lo, up));
        run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
        run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));

        run_reduce(c, &device, (lo, up), false);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_arg_reduce(c, &device, (lo, up), false);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);

        run_reduce(c, &device, (lo, up), true);
        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);

        run_arg_reduce(c, &device, (lo, up), true);
        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
    }
}

fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
    if !device.is_metal() {
        return;
    }

    let b = 1;
    let m = 1024;
    let k = 1024;
    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => "softmax_f32",
        DType::F16 => "softmax_f16",
        DType::BF16 => "softmax_bf16",
        _ => "softmax",
    };
    softmax(&a).unwrap();

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                softmax(black_box(&a)).unwrap();
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}
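Note that the value handed to Throughput::Bytes is a byte count rather than a FLOP count: it is the number of input bytes the kernel reads per call, and Criterion divides it by the measured time to report bandwidth. A worked example for the f32 case of the shape used above (illustrative helper, not part of the benchmark):

// 1 x 1024 x 1024 f32 elements at 4 bytes each: 4_194_304 bytes (4 MiB) per call.
fn softmax_bytes_per_iteration() -> u64 {
    let (b, m, k) = (1u64, 1024, 1024);
    let dtype_size = std::mem::size_of::<f32>() as u64;
    b * m * k * dtype_size
}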

fn run_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), &device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
    };

    let flops = b * m * k * T::DTYPE.size_in_bytes();

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "reduce_f32_strided"
            } else {
                "reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "reduce_f16_strided"
            } else {
                "reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "reduce_bf16_strided"
            } else {
                "reduce_bf16"
            }
        }
        _ => "reduce",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_sum(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}
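The strided variant transposes dims 0 and 2, which leaves the data where it is but permutes the dims and strides, so the layout is no longer row-major contiguous and the _strided kernel path is exercised. A small sketch of that effect in plain Rust, without candle types (names are illustrative; candle keeps the same information in its Layout):

// Row-major strides for a shape, and the effect of swapping two dims.
fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    let mut dims = [1usize, 1024, 1024];
    let mut strides = contiguous_strides(&dims); // [1048576, 1024, 1]
    dims.swap(0, 2);
    strides.swap(0, 2); // dims [1024, 1024, 1], strides [1, 1024, 1048576]
    // The stride pattern no longer matches a fresh row-major layout for these dims,
    // so a reduction has to use get_strided_index instead of a linear walk.
    println!("{dims:?} {strides:?}");
}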

fn run_arg_reduce<T: candle_core::FloatDType>(
    c: &mut Criterion,
    device: &Device,
    (lo, up): (T, T),
    strided: bool,
) {
    let b = 1;
    let m = 1024;
    let k = 1024;

    let a = if strided {
        Tensor::rand(lo, up, (b, m, k), &device)
            .unwrap()
            .transpose(0, 2)
            .unwrap()
    } else {
        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
    };

    let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());

    let name = match T::DTYPE {
        DType::F32 => {
            if strided {
                "arg_reduce_f32_strided"
            } else {
                "arg_reduce_f32"
            }
        }
        DType::F16 => {
            if strided {
                "arg_reduce_f16_strided"
            } else {
                "arg_reduce_f16"
            }
        }
        DType::BF16 => {
            if strided {
                "arg_reduce_bf16_strided"
            } else {
                "arg_reduce_bf16"
            }
        }
        _ => "unknown",
    };

    let mut group = c.benchmark_group(device.bench_name(name));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run_arg_min(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
@@ -489,6 +489,7 @@ impl BackendStorage for MetalStorage {

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        let device = self.device.clone();

        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
        // Source dims and strides with the sum dims at the end.
@@ -502,13 +503,69 @@ impl BackendStorage for MetalStorage {
                stride.push(src_stride[dim_idx]);
            }
        }

        if layout.is_contiguous() {
            let (name, check_empty, return_index) = match (op, self.dtype) {
                (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
                (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
                (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
                (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
                (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
                (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
                (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
                (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
                (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
                (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
                (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
                (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
                (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
                (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
                (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
                (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
                (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
                (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
                (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
                (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
                (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
                (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
                (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
                (ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
                (ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
                (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
                (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
                (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
                (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
                (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
                (k, dtype) => {
                    crate::bail!("Metal contiguous reduce op {k:?} {dtype:?} not implemented")
                }
            };
            if check_empty && layout.shape().elem_count() == 0 {
                Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
            }
            let dtype = if return_index { DType::U32 } else { self.dtype };
            let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
            let command_buffer = self.device.command_buffer()?;
            candle_metal_kernels::call_reduce_contiguous(
                &device.device,
                &command_buffer,
                &device.kernels,
                name,
                layout.shape().elem_count(),
                dst_el,
                &self.buffer,
                layout.start_offset() * self.dtype.size_in_bytes(),
                &buffer,
            )
            .map_err(MetalError::from)?;
            return Ok(Self::new(buffer, device, self.dtype));
        }

        for &dim_idx in sum_dims.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }

        // The reduction loop requires the shared array to be properly initialized and for
        // this we want the number of threads to be a power of two.
        let (name, check_empty, return_index) = match (op, self.dtype) {
            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
@@ -540,7 +597,7 @@ impl BackendStorage for MetalStorage {
            (ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
            (ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
            (ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
            (k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
            (k, dtype) => crate::bail!("Metal strided reduce op {k:?} {dtype:?} not implemented"),
        };
        if check_empty && layout.shape().elem_count() == 0 {
            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
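The new contiguous table above mirrors the existing strided one, with the kernel names differing only by a _strided suffix. A hedged sketch of the mapping it encodes; the helper below is hypothetical and not part of candle, it just makes the naming scheme and the two flags explicit:

// Hypothetical helper (not in the diff): derive kernel name and flags.
// `op` is one of "sum", "min", "max", "argmin", "argmax";
// `dtype` is one of "f32", "u32", "f16", "bf16", "i64", "u8".
fn reduce_kernel_for(op: &str, dtype: &str, contiguous: bool) -> (String, bool, bool) {
    let suffix = if contiguous { "" } else { "_strided" };
    let name = format!("fast_{op}_{dtype}{suffix}");
    let check_empty = op != "sum"; // min/max/argmin/argmax are undefined on empty input
    let return_index = op.starts_with("arg"); // arg reductions write u32 indices
    (name, check_empty, return_index)
}

The explicit match in the diff trades this compactness for static strings and a compile-time-checked set of supported (op, dtype) pairs.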
@@ -623,7 +623,8 @@ pub fn call_reduce_strided(
            strides,
            elements_to_sum,
            (input, input_offset),
            output
            output,
            out_length
        )
    );
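The extra out_length argument is the number of destination elements, one per threadgroup; the callers already know it because they size the output buffer with it. A sketch of the relationship, with illustrative names (the real call sites derive it from the layout):

// Illustrative only: one threadgroup reduces `el_to_sum_per_block` elements
// down to a single output, so the destination length is the quotient.
fn reduce_out_length(src_elem_count: usize, el_to_sum_per_block: usize) -> usize {
    src_elem_count / el_to_sum_per_block
}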
@@ -1,16 +1,18 @@
#include <metal_stdlib>
#include <metal_limits>
using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
// TODO: Load multiple values per thread to improve memory bandwidth utilization
// static constant constexpr uint VALUES_PER_THREAD = 1;

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
    constant const size_t &num_dims,
    constant const size_t *dims,
    constant const size_t *strides
) {
    uint strided_i = 0;
    #pragma clang loop unroll(full)
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@@ -19,288 +21,637 @@ METAL_FUNC uint get_strided_index(
    return strided_i;
}
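get_strided_index maps a linear element index to a strided memory offset by peeling dimensions from the innermost outwards: take the coordinate along each dim with a modulo, weight it by that dim's stride, then divide the index down. A CPU reference of the same computation in Rust, handy for checking the kernel on the host (names are illustrative):

// Reference port of get_strided_index (not part of the diff).
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    for d in (0..dims.len()).rev() {
        strided += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided
}

// Example: dims [2, 3] with transposed strides [1, 2].
// Linear index 4 -> coordinates (1, 1) -> offset 1 * 1 + 1 * 2 = 3.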
constant int THREADGROUP_SIZE = 2048;
|
||||
template <typename V>
|
||||
struct Indexed {
|
||||
uint i;
|
||||
V val;
|
||||
typedef V type;
|
||||
|
||||
constexpr Indexed<V>() thread = default;
|
||||
constexpr Indexed<V>() threadgroup = default;
|
||||
constexpr Indexed<V>() device = default;
|
||||
constexpr Indexed<V>() constant = default;
|
||||
|
||||
constexpr Indexed<V>(uint _i, V _val) : i(_i), val(_val) {}
|
||||
|
||||
template <typename U, typename = typename enable_if<is_convertible_v<U, V>>::type>
|
||||
constexpr Indexed<V>(uint _i, U _val) : i(_i), val(static_cast<U>(_val)) {}
|
||||
|
||||
template <typename U>
|
||||
constexpr Indexed<V>(const thread Indexed<U> &iv): Indexed<V>(iv.i, iv.val) {}
|
||||
|
||||
template <typename U>
|
||||
constexpr Indexed<V>(const threadgroup Indexed<V> &iv): Indexed<V>(iv.i, iv.val) {}
|
||||
|
||||
Indexed<V> operator=(const thread Indexed<V> &iv) thread {
|
||||
this->i = iv.i;
|
||||
this->val = iv.val;
|
||||
return *this;
|
||||
}
|
||||
Indexed<V> operator=(const thread Indexed<V> &iv) threadgroup {
|
||||
this->i = iv.i;
|
||||
this->val = iv.val;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename V>
|
||||
constexpr METAL_FUNC bool operator<(Indexed<V> lhs, Indexed<V> rhs) {
|
||||
return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.i < rhs.i);
|
||||
}
|
||||
|
||||
template<typename V>
|
||||
constexpr METAL_FUNC bool operator>(Indexed<V> lhs, Indexed<V> rhs) {
|
||||
return lhs.val > rhs.val || (lhs.val == rhs.val && lhs.i > rhs.i);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
struct _numeric_limits_impl<Indexed<T>> {
|
||||
static constexpr Indexed<T> lowest() {
|
||||
return Indexed<T>(0, numeric_limits<T>::lowest());
|
||||
}
|
||||
|
||||
static constexpr Indexed<T> max() {
|
||||
return Indexed<T>(0, numeric_limits<T>::max());
|
||||
}
|
||||
};
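Indexed<V> pairs a value with its source index so that argmin/argmax can reuse the same reduction machinery as plain min/max; the comparison operators above break ties in favour of the smaller index. A minimal host-side analogue in Rust, illustrative only and not part of candle:

// Host-side analogue of Indexed<V>: an argmin is a min over (value, index) pairs.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Indexed<V> {
    i: usize,
    val: V,
}

fn argmin_f32(xs: &[f32]) -> Option<Indexed<f32>> {
    xs.iter()
        .copied()
        .enumerate()
        .map(|(i, val)| Indexed { i, val })
        // Ties resolve to the smaller index, matching operator< above.
        .min_by(|a, b| a.val.total_cmp(&b.val).then(a.i.cmp(&b.i)))
}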
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
// Metal does not have simd_shuffle_down for bfloat16
|
||||
// TODO: Check if volatile threadgroup memory reduction is faster than simd_shuffle_down for bfloat
|
||||
bfloat simd_shuffle_down(bfloat value, ushort delta) {
|
||||
return static_cast<bfloat>(__metal_simd_shuffle_down(static_cast<float>(value), delta));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename V>
|
||||
Indexed<V> simd_shuffle_down(Indexed<V> iv, ushort delta) {
|
||||
return Indexed<V>(
|
||||
simd_shuffle_down(iv.i, delta),
|
||||
simd_shuffle_down(iv.val, delta)
|
||||
);
|
||||
}
|
||||
|
||||
#define impl_reduction_op_helper(name, op, init_val, __result_type__) \
|
||||
template<typename T, typename R = __result_type__> \
|
||||
struct name { \
|
||||
static constexpr T init() { \
|
||||
return init_val; \
|
||||
} \
|
||||
METAL_FUNC R operator()(T a, T b) { \
|
||||
return op; \
|
||||
} \
|
||||
METAL_FUNC R operator()(thread const T& a, thread const T& b) const { \
|
||||
return op; \
|
||||
} \
|
||||
METAL_FUNC R operator()(threadgroup const T& a, threadgroup const T& b) const { \
|
||||
return op; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define impl_reduction_op(name, op, init_val) \
|
||||
impl_reduction_op_helper(name, op, init_val, T);
|
||||
|
||||
#define impl_arg_reduction_op(name, op, init_val) \
|
||||
impl_reduction_op_helper(name, op, init_val, tuple<bool, Indexed<T>>);
|
||||
|
||||
impl_reduction_op(Sum, a + b, 0);
|
||||
impl_reduction_op(Mul, a * b, 1);
|
||||
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
|
||||
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::lowest());
|
||||
#undef impl_reduction_op
|
||||
|
||||
// These are used when loading elements from global memory into shared memory.
|
||||
// They let us use the same code for both indexed and non-indexed types.
|
||||
template<typename Op, typename T, typename U>
|
||||
METAL_FUNC T apply_operator(Op op, size_t _idx, T a, U b) {
|
||||
return op(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
template<typename Op, typename T, typename U>
|
||||
METAL_FUNC Indexed<T> apply_operator(Op op, size_t idx, Indexed<T> a, U b) {
|
||||
return op(a, Indexed<T>(idx, b));
|
||||
}
|
||||
|
||||
// Load elements from global memory into shared memory.
|
||||
// Handles both indexed and non-indexed types by using apply_operator.
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED = false
|
||||
>
|
||||
METAL_FUNC R load_from_global(
|
||||
R value,
|
||||
constant size_t &num_elements,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
const ushort offset,
|
||||
threadgroup R shared[BLOCKSIZE],
|
||||
const ushort tid
|
||||
) {
|
||||
ReductionOp op;
|
||||
|
||||
size_t stop_idx = offset + el_to_sum_per_block;
|
||||
size_t idx = offset + tid;
|
||||
|
||||
while (idx < stop_idx) {
|
||||
if (STRIDED) {
|
||||
idx = get_strided_index(idx, num_dims, dims, strides);
|
||||
}
|
||||
value = apply_operator(op, idx, value, src[idx]);
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
// Convenience function for when we don't need to sum over multiple dimensions.
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE
|
||||
>
|
||||
METAL_FUNC R load_from_global(
|
||||
R value,
|
||||
constant size_t &num_elements,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
const size_t offset,
|
||||
threadgroup R shared[BLOCKSIZE],
|
||||
const ushort tid
|
||||
) {
|
||||
return load_from_global<T, R, ReductionOp, BLOCKSIZE, false>(
|
||||
value,
|
||||
num_elements,
|
||||
// Dummy values for num_dims, dims, and strides
|
||||
num_elements,
|
||||
nullptr,
|
||||
nullptr,
|
||||
// end dummy values
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
}
|
||||
|
||||
// Since we are using simd_shuffle_down with a BLOCKSIZE guard we don't need any barriers.
|
||||
template<typename ReductionOp, ushort BLOCKSIZE, typename T>
|
||||
METAL_FUNC T simdgroup_reduce(T value) {
|
||||
ReductionOp op;
|
||||
if (BLOCKSIZE >= 32) value = op(value, simd_shuffle_down(value, 16));
|
||||
if (BLOCKSIZE >= 16) value = op(value, simd_shuffle_down(value, 8));
|
||||
if (BLOCKSIZE >= 8) value = op(value, simd_shuffle_down(value, 4));
|
||||
if (BLOCKSIZE >= 4) value = op(value, simd_shuffle_down(value, 2));
|
||||
if (BLOCKSIZE >= 2) value = op(value, simd_shuffle_down(value, 1));
|
||||
return value;
|
||||
}
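simdgroup_reduce is a shuffle-based folding tree: at each step every lane combines its own value with the one held delta lanes above it, halving the active width until lane 0 holds the result, and the BLOCKSIZE guards compile out the steps a small block never needs. A CPU sketch of the same folding pattern, operating on a slice instead of SIMD lanes (illustrative only):

// Halving fold: element i absorbs element i + delta, so after log2(n) steps
// index 0 holds the reduction of the whole slice. Assumes a power-of-two length,
// mirroring the power-of-two BLOCKSIZE requirement.
fn tree_reduce(values: &mut [f32], op: impl Fn(f32, f32) -> f32) -> f32 {
    let mut delta = values.len() / 2;
    while delta > 0 {
        for i in 0..delta {
            values[i] = op(values[i], values[i + delta]);
        }
        delta /= 2;
    }
    values[0]
}

// tree_reduce(&mut [1.0, 2.0, 3.0, 4.0], f32::max) == 4.0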
|
||||
|
||||
template<
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
typename T
|
||||
>
|
||||
METAL_FUNC T threadgroup_reduce(
|
||||
threadgroup T shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]]
|
||||
) {
|
||||
ReductionOp op;
|
||||
|
||||
// Fully unrolled reduction loop from BLOCKSIZE down to 64.
|
||||
#pragma clang loop unroll(full)
|
||||
for (uint s = BLOCKSIZE / 2; s >= 64; s >>= 1) {
|
||||
if (tid < s) {
|
||||
shared[tid] = op(shared[tid], shared[tid + s]);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
|
||||
if (tid < 32) {
|
||||
// Last shared memory reduce can be done without tid < s check.
|
||||
if (BLOCKSIZE >= 64) {
|
||||
shared[tid] = op(shared[tid], shared[tid + 32]);
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
}
|
||||
// Remaining 32 threads can be reduced with simdgroup_reduce.
|
||||
shared[tid] = simdgroup_reduce<ReductionOp, BLOCKSIZE>(shared[tid]);
|
||||
}
|
||||
|
||||
return shared[tid];
|
||||
}
|
||||
|
||||
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
|
||||
template<
|
||||
typename T,
|
||||
typename R,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED = false
|
||||
>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device R *dst,
|
||||
constant size_t &num_elements,
|
||||
threadgroup T shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to correct value for reduction operation
|
||||
shared[tid] = ReductionOp::init();
|
||||
|
||||
// Calculate the offset for the current thread's threadgroup
|
||||
ushort offset = dst_id * el_to_sum_per_block;
|
||||
R initial = ReductionOp::init();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, R, ReductionOp, BLOCKSIZE, STRIDED>(
|
||||
initial,
|
||||
num_elements,
|
||||
num_dims,
|
||||
dims,
|
||||
strides,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Complete reduction
|
||||
R value = threadgroup_reduce<ReductionOp, BLOCKSIZE>(shared, tid);
|
||||
|
||||
if (tid == 0) dst[dst_id] = value;
|
||||
}
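Each threadgroup owns a contiguous slice of el_to_sum_per_block elements and produces exactly one output: group dst_id starts at dst_id * el_to_sum_per_block and its BLOCKSIZE threads stride through that slice. A small worked example of the indexing (illustrative helper, numbers chosen for readability):

// Which source indices thread `tid` of threadgroup `dst_id` visits.
fn indices_for(dst_id: usize, tid: usize, el_per_block: usize, blocksize: usize) -> Vec<usize> {
    let offset = dst_id * el_per_block;
    (offset + tid..offset + el_per_block).step_by(blocksize).collect()
}

// indices_for(2, 0, 8, 4) == [16, 20]; indices_for(2, 3, 8, 4) == [19, 23]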
|
||||
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
#define reduce_case(OP, T, R, N) \
|
||||
case N: { \
|
||||
threadgroup R shared[N]; \
|
||||
reduce<T, R, OP<R>, N, STRIDED>( \
|
||||
num_dims, \
|
||||
dims, \
|
||||
strides, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
num_elements, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
#define impl_reduce(OP, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
constant size_t *dims = {}; \
|
||||
constant size_t *strides = {}; \
|
||||
const bool STRIDED = false; \
|
||||
switch (block_dim) { \
|
||||
reduce_case(OP, T, T, 2048); \
|
||||
reduce_case(OP, T, T, 1024); \
|
||||
reduce_case(OP, T, T, 512); \
|
||||
reduce_case(OP, T, T, 256); \
|
||||
reduce_case(OP, T, T, 128); \
|
||||
reduce_case(OP, T, T, 64); \
|
||||
reduce_case(OP, T, T, 32); \
|
||||
reduce_case(OP, T, T, 16); \
|
||||
reduce_case(OP, T, T, 8); \
|
||||
reduce_case(OP, T, T, 4); \
|
||||
reduce_case(OP, T, T, 2); \
|
||||
reduce_case(OP, T, T, 1); \
|
||||
} \
|
||||
} \
|
||||
kernel void NAME##_strided( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
const bool STRIDED = true; \
|
||||
switch (block_dim) { \
|
||||
reduce_case(OP, T, T, 2048); \
|
||||
reduce_case(OP, T, T, 1024); \
|
||||
reduce_case(OP, T, T, 512); \
|
||||
reduce_case(OP, T, T, 256); \
|
||||
reduce_case(OP, T, T, 128); \
|
||||
reduce_case(OP, T, T, 64); \
|
||||
reduce_case(OP, T, T, 32); \
|
||||
reduce_case(OP, T, T, 16); \
|
||||
reduce_case(OP, T, T, 8); \
|
||||
reduce_case(OP, T, T, 4); \
|
||||
reduce_case(OP, T, T, 2); \
|
||||
reduce_case(OP, T, T, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename ReductionOp,
|
||||
ushort BLOCKSIZE,
|
||||
bool STRIDED
|
||||
>
|
||||
METAL_FUNC void reduce(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
device const T *src,
|
||||
device uint *dst,
|
||||
constant size_t &num_elements,
|
||||
threadgroup Indexed<T> shared[BLOCKSIZE],
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to correct value for reduction operation
|
||||
shared[tid] = ReductionOp::init();
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
// Calculate the offset for the current thread's threadgroup
|
||||
ushort offset = dst_id * el_to_sum_per_block;
|
||||
Indexed<T> initial = ReductionOp::init();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, Indexed<T>, ReductionOp, BLOCKSIZE, STRIDED>(
|
||||
initial,
|
||||
num_elements,
|
||||
num_dims,
|
||||
dims,
|
||||
strides,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
// Complete reduction
|
||||
Indexed<T> value = threadgroup_reduce<ReductionOp, BLOCKSIZE, Indexed<T>>(shared, tid);
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
// Return index of reduce result
|
||||
if (tid == 0) dst[dst_id] = value.i;
|
||||
}
|
||||
|
||||
#define arg_reduce_case(OP, T, N) \
|
||||
case N: { \
|
||||
threadgroup Indexed<T> shared[N]; \
|
||||
reduce<T, OP<Indexed<T>>, N, STRIDED>( \
|
||||
num_dims, \
|
||||
dims, \
|
||||
strides, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
num_elements, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define impl_arg_reduce(OP, NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
constant size_t *dims = {}; \
|
||||
constant size_t *strides = {}; \
|
||||
const bool STRIDED = false; \
|
||||
switch (block_dim) { \
|
||||
arg_reduce_case(OP, T, 2048); \
|
||||
arg_reduce_case(OP, T, 1024); \
|
||||
arg_reduce_case(OP, T, 512); \
|
||||
arg_reduce_case(OP, T, 256); \
|
||||
arg_reduce_case(OP, T, 128); \
|
||||
arg_reduce_case(OP, T, 64); \
|
||||
arg_reduce_case(OP, T, 32); \
|
||||
arg_reduce_case(OP, T, 16); \
|
||||
arg_reduce_case(OP, T, 8); \
|
||||
arg_reduce_case(OP, T, 4); \
|
||||
arg_reduce_case(OP, T, 2); \
|
||||
arg_reduce_case(OP, T, 1); \
|
||||
} \
|
||||
} \
|
||||
kernel void NAME##_strided( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
constant size_t &num_elements, \
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
const bool STRIDED = true; \
|
||||
switch (block_dim) { \
|
||||
arg_reduce_case(OP, T, 2048); \
|
||||
arg_reduce_case(OP, T, 1024); \
|
||||
arg_reduce_case(OP, T, 512); \
|
||||
arg_reduce_case(OP, T, 256); \
|
||||
arg_reduce_case(OP, T, 128); \
|
||||
arg_reduce_case(OP, T, 64); \
|
||||
arg_reduce_case(OP, T, 32); \
|
||||
arg_reduce_case(OP, T, 16); \
|
||||
arg_reduce_case(OP, T, 8); \
|
||||
arg_reduce_case(OP, T, 4); \
|
||||
arg_reduce_case(OP, T, 2); \
|
||||
arg_reduce_case(OP, T, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename ACC = float,
|
||||
ushort BLOCKSIZE
|
||||
>
|
||||
METAL_FUNC void softmax(
|
||||
constant size_t &src_numel,
|
||||
constant size_t &el_to_sum_per_block,
|
||||
const device T *src,
|
||||
device T *dst,
|
||||
threadgroup ACC shared[BLOCKSIZE],
|
||||
|
||||
ushort tid [[ thread_index_in_threadgroup ]],
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]]
|
||||
) {
|
||||
// Initialize shared memory for current thread to lowest value
|
||||
shared[tid] = numeric_limits<ACC>::lowest();
|
||||
|
||||
// Calculate the offset for the current thread's threadgroup
|
||||
size_t offset = dst_id * el_to_sum_per_block;
|
||||
ACC initial = numeric_limits<ACC>::lowest();
|
||||
// Load with reduction from global memory into shared memory
|
||||
shared[tid] = load_from_global<T, ACC, Max<ACC>, BLOCKSIZE>(
|
||||
initial,
|
||||
src_numel,
|
||||
el_to_sum_per_block,
|
||||
src,
|
||||
offset,
|
||||
shared,
|
||||
tid
|
||||
);
|
||||
// Threadgroup barrier is needed to ensure that all threads have written to shared memory
|
||||
// Memory space is not shared between threadgroups so we can use the mem_none flag for all threadgroup barriers.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
// Reduce shared memory to find max value
|
||||
threadgroup_reduce<Max<ACC>, BLOCKSIZE>(shared, tid);
|
||||
ACC max_result = shared[0];
|
||||
|
||||
// Ensure all threads have max_result = shared[0] before we set shared[0] = 0.
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
shared[tid] = 0;
|
||||
|
||||
// Calculate softmax values
|
||||
size_t stop_idx = min(offset + el_to_sum_per_block, src_numel);
|
||||
size_t idx = offset + tid;
|
||||
while (idx < stop_idx) {
|
||||
const ACC val = exp(ACC(src[idx]) - max_result);
|
||||
dst[idx] = T(val);
|
||||
shared[tid] += val;
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
threadgroup_reduce<Sum<ACC>, BLOCKSIZE>(shared, tid);
|
||||
threadgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
const T inv_acc = T(1.0/shared[0]);
|
||||
idx = offset + tid;
|
||||
while (idx < stop_idx) {
|
||||
dst[idx] *= inv_acc;
|
||||
idx += BLOCKSIZE;
|
||||
}
|
||||
}
|
||||
|
||||
#define softmax_case(T, ACC, N) \
|
||||
case N: { \
|
||||
threadgroup ACC shared[N]; \
|
||||
softmax<T, ACC, N>( \
|
||||
src_numel, \
|
||||
el_to_sum_per_block, \
|
||||
src, \
|
||||
dst, \
|
||||
shared, \
|
||||
tid, \
|
||||
dst_id); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define impl_softmax(NAME, T, ACC) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
ushort tid [[ thread_index_in_threadgroup ]], \
|
||||
ushort dst_id [[ threadgroup_position_in_grid ]], \
|
||||
ushort block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
switch (block_dim) { \
|
||||
softmax_case(T, ACC, 2048); \
|
||||
softmax_case(T, ACC, 1024); \
|
||||
softmax_case(T, ACC, 512); \
|
||||
softmax_case(T, ACC, 256); \
|
||||
softmax_case(T, ACC, 128); \
|
||||
softmax_case(T, ACC, 64); \
|
||||
softmax_case(T, ACC, 32); \
|
||||
softmax_case(T, ACC, 16); \
|
||||
softmax_case(T, ACC, 8); \
|
||||
softmax_case(T, ACC, 4); \
|
||||
softmax_case(T, ACC, 2); \
|
||||
softmax_case(T, ACC, 1); \
|
||||
} \
|
||||
}
|
||||
|
||||
impl_reduce(Sum, fast_sum_f32, float)
|
||||
impl_reduce(Sum, fast_sum_u32, uint)
|
||||
impl_reduce(Sum, fast_sum_f16, half)
|
||||
impl_reduce(Sum, fast_sum_u8, uint8_t)
|
||||
|
||||
impl_reduce(Mul, fast_mul_f32, float)
|
||||
impl_reduce(Mul, fast_mul_u32, uint)
|
||||
impl_reduce(Mul, fast_mul_f16, half)
|
||||
impl_reduce(Mul, fast_mul_u8, uint8_t)
|
||||
|
||||
impl_reduce(Max, fast_max_f32, float)
|
||||
impl_reduce(Max, fast_max_u32, uint)
|
||||
impl_reduce(Max, fast_max_f16, half)
|
||||
impl_reduce(Max, fast_max_u8, uint8_t)
|
||||
|
||||
impl_reduce(Min, fast_min_f32, float)
|
||||
impl_reduce(Min, fast_min_u32, uint)
|
||||
impl_reduce(Min, fast_min_f16, half)
|
||||
impl_reduce(Min, fast_min_u8, uint8_t)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_f32, float)
|
||||
impl_arg_reduce(Min, fast_argmin_f16, half)
|
||||
impl_arg_reduce(Min, fast_argmin_u32, uint)
|
||||
impl_arg_reduce(Min, fast_argmin_u8, uint8_t)
|
||||
|
||||
impl_arg_reduce(Max, fast_argmax_f32, float)
|
||||
impl_arg_reduce(Max, fast_argmax_f16, half)
|
||||
impl_arg_reduce(Max, fast_argmax_u32, uint)
|
||||
impl_arg_reduce(Max, fast_argmax_u8, uint8_t)
|
||||
|
||||
impl_softmax(softmax_f32, float, float)
|
||||
impl_softmax(softmax_f16, half, float)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
impl_reduce(Sum, fast_sum_i64, int64_t)
|
||||
impl_reduce(Mul, fast_mul_i64, int64_t)
|
||||
impl_reduce(Min, fast_min_i64, int64_t)
|
||||
impl_reduce(Max, fast_max_i64, int64_t)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_i64, int64_t)
|
||||
impl_arg_reduce(Max, fast_argmax_i64, int64_t)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
impl_reduce(Sum, fast_sum_bf16, bfloat)
|
||||
impl_reduce(Mul, fast_mul_bf16, bfloat)
|
||||
impl_reduce(Max, fast_max_bf16, bfloat)
|
||||
impl_reduce(Min, fast_min_bf16, bfloat)
|
||||
|
||||
impl_arg_reduce(Min, fast_argmin_bf16, bfloat)
|
||||
impl_arg_reduce(Max, fast_argmax_bf16, bfloat)
|
||||
|
||||
impl_softmax(softmax_bf16, bfloat, float)
|
||||
#endif
|
||||
|
346
candle-metal-kernels/src/reduce_old.metal
Normal file
346
candle-metal-kernels/src/reduce_old.metal
Normal file
@ -0,0 +1,346 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = shared_memory[tid + s]; \
|
||||
shared_memory[tid] = FN; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
dst[dst_id] = shared_memory[0]; \
|
||||
} \
|
||||
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
\
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||
shared_memory[tid] = -INFINITY; \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t idx = start_idx + tid; \
|
||||
\
|
||||
\
|
||||
float tmp = -INFINITY; \
|
||||
while (idx < stop_idx) { \
|
||||
tmp = MAX(tmp, float(src[idx])); \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
shared_memory[tid] = tmp; \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
/* wait for shared_memory[0] to be filled */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
\
|
||||
float _max = shared_memory[0]; \
|
||||
\
|
||||
/* prevent tid=0 from overwriting _max before other threads have written it */ \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
shared_memory[tid] = 0; \
|
||||
\
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
const float val = exp(float(src[idx]) - _max); \
|
||||
dst[idx] = T(val); \
|
||||
shared_memory[tid] += val; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s) { \
|
||||
shared_memory[tid] += shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup); \
|
||||
} \
|
||||
\
|
||||
const T inv_acc = T(1.0/shared_memory[0]); \
|
||||
idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
dst[idx] *= inv_acc; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
} \
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_f32, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16, half, 0)
|
||||
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
|
||||
REDUCE(x * y, fast_mul_f32, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
|
||||
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
|
||||
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
|
||||
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
|
||||
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32, uint, 0)
|
||||
ARGMAX(fast_argmax_u8, uint8_t, 0)
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_i64, int64_t, 0)
|
||||
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
|
||||
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
|
||||
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
|
||||
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
|
||||
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
@@ -622,7 +622,7 @@ fn cos_f16() {
    assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
}

fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
@@ -630,10 +630,10 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
    let input = new_buffer(&device, v);

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
    let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
    let dims = vec![v.len()];
    let strides = vec![1];
    call_reduce_strided(
    match call_reduce_strided(
        &device,
        command_buffer,
        &kernels,
@@ -644,8 +644,13 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
        &input,
        0,
        &output,
    )
    .unwrap();
    ) {
        Ok(_) => {}
        Err(e) => {
            println!("Error: {}", e);
            panic!();
        }
    }
    command_buffer.commit();
    command_buffer.wait_until_completed();

@@ -677,22 +682,114 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
    read_to_vec(&output, v.len())
}
|
||||
|
||||
#[test]
|
||||
fn reduce_sum() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 1;
|
||||
const fn create_array<const N: usize>() -> [f32; N] {
|
||||
let mut array: [f32; N] = [0.0; N];
|
||||
let mut i = 1;
|
||||
while i <= N {
|
||||
array[i - 1] = i as f32;
|
||||
i += 1;
|
||||
}
|
||||
array
|
||||
}
|
||||
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![21.0]);
|
||||
const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
|
||||
let mut sum = 0;
|
||||
let mut results: [f32; D] = [0.0; D];
|
||||
let mut i = 1;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
sum += i;
|
||||
i += 1;
|
||||
if i > j * N / D {
|
||||
results[j - 1] = sum as f32;
|
||||
j += 1;
|
||||
sum = 0;
|
||||
}
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
|
||||
let mut max = 0.0;
|
||||
let mut max_index: u32 = 0;
|
||||
let mut results: [u32; D] = [0; D];
|
||||
let mut i = 0;
|
||||
let mut j = 1;
|
||||
while i <= N {
|
||||
if i >= (j * N / D) {
|
||||
results[j - 1] = max_index;
|
||||
max = 0.0;
|
||||
max_index = 0;
|
||||
j += 1;
|
||||
}
|
||||
if i == N {
|
||||
break;
|
||||
}
|
||||
if arr[i] > max {
|
||||
max = arr[i];
|
||||
max_index = i as u32;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
results
|
||||
}
|
||||
|
||||
fn reduce_sum_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results = run_reduce(&v, D, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), correct_sum::<N, D>());
|
||||
}
|
||||
|
||||
fn reduce_argmax_case<const N: usize, const D: usize>() {
|
||||
let v = create_array::<N>();
|
||||
let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
|
||||
assert_eq!(results, correct_argmax::<N, D>(v));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_sum2() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 2;
|
||||
fn reduce_sum() {
|
||||
reduce_sum_case::<6, 1>();
|
||||
reduce_sum_case::<10, 1>();
|
||||
reduce_sum_case::<64, 1>();
|
||||
reduce_sum_case::<128, 1>();
|
||||
reduce_sum_case::<256, 1>();
|
||||
reduce_sum_case::<512, 1>();
|
||||
reduce_sum_case::<1024, 1>();
|
||||
reduce_sum_case::<2048, 1>();
|
||||
reduce_sum_case::<4096, 1>();
|
||||
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
||||
reduce_sum_case::<6, 2>();
|
||||
reduce_sum_case::<10, 2>();
|
||||
reduce_sum_case::<64, 2>();
|
||||
reduce_sum_case::<128, 2>();
|
||||
reduce_sum_case::<256, 2>();
|
||||
reduce_sum_case::<512, 2>();
|
||||
reduce_sum_case::<1024, 2>();
|
||||
reduce_sum_case::<2048, 2>();
|
||||
reduce_sum_case::<4096, 2>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reduce_argmax() {
|
||||
reduce_argmax_case::<6, 1>();
|
||||
reduce_argmax_case::<10, 1>();
|
||||
reduce_argmax_case::<64, 1>();
|
||||
reduce_argmax_case::<128, 1>();
|
||||
reduce_argmax_case::<256, 1>();
|
||||
reduce_argmax_case::<512, 1>();
|
||||
reduce_argmax_case::<1024, 1>();
|
||||
reduce_argmax_case::<2048, 1>();
|
||||
reduce_argmax_case::<4096, 1>();
|
||||
|
||||
reduce_argmax_case::<6, 2>();
|
||||
reduce_argmax_case::<10, 2>();
|
||||
reduce_argmax_case::<64, 2>();
|
||||
reduce_argmax_case::<128, 2>();
|
||||
reduce_argmax_case::<256, 2>();
|
||||
reduce_argmax_case::<512, 2>();
|
||||
reduce_argmax_case::<1024, 2>();
|
||||
reduce_argmax_case::<2048, 2>();
|
||||
reduce_argmax_case::<4096, 2>();
|
||||
}
|
||||
|
||||
#[test]