Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Commit: Fixed all bugs. Improved code quality. Added tests.
@@ -3,9 +3,9 @@ mod benchmarks;
 use criterion::criterion_main;
 
 criterion_main!(
-    benchmarks::affine::benches,
-    benchmarks::matmul::benches,
-    benchmarks::random::benches,
+    //benchmarks::affine::benches,
+    //benchmarks::matmul::benches,
+    //benchmarks::random::benches,
     benchmarks::reduce::benches,
-    benchmarks::where_cond::benches
+    //benchmarks::where_cond::benches
 );
@@ -61,13 +61,21 @@ fn criterion_benchmark(c: &mut Criterion) {
         run_softmax(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
         run_softmax(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
 
-        run_reduce(c, &device, (lo, up));
-        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
-        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+        run_reduce(c, &device, (lo, up), false);
+        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
+        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
 
-        run_arg_reduce(c, &device, (lo, up));
-        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)));
-        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)));
+        run_arg_reduce(c, &device, (lo, up), false);
+        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), false);
+        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), false);
 
+        run_reduce(c, &device, (lo, up), true);
+        run_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
+        run_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
+
+        run_arg_reduce(c, &device, (lo, up), true);
+        run_arg_reduce(c, &device, (f16::from_f32(lo), f16::from_f32(up)), true);
+        run_arg_reduce(c, &device, (bf16::from_f32(lo), bf16::from_f32(up)), true);
     }
 }
@@ -89,6 +97,7 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
         DType::BF16 => "softmax_bf16",
         _ => "softmax",
     };
+    softmax(&a).unwrap();
 
     let mut group = c.benchmark_group(device.bench_name(name));
     group.throughput(Throughput::Bytes(flops as u64));
@@ -105,19 +114,49 @@ fn run_softmax<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (
     group.finish();
 }
 
-fn run_reduce<T: candle_core::FloatDType>(c: &mut Criterion, device: &Device, (lo, up): (T, T)) {
+fn run_reduce<T: candle_core::FloatDType>(
+    c: &mut Criterion,
+    device: &Device,
+    (lo, up): (T, T),
+    strided: bool,
+) {
     let b = 1;
     let m = 1024;
     let k = 1024;
 
-    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
+    let a = if strided {
+        Tensor::rand(lo, up, (b, m, k), &device)
+            .unwrap()
+            .transpose(0, 2)
+            .unwrap()
+    } else {
+        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
+    };
 
     let flops = b * m * k * T::DTYPE.size_in_bytes();
 
     let name = match T::DTYPE {
-        DType::F32 => "reduce_f32",
-        DType::F16 => "reduce_f16",
-        DType::BF16 => "reduce_bf16",
+        DType::F32 => {
+            if strided {
+                "reduce_f32_strided"
+            } else {
+                "reduce_f32"
+            }
+        }
+        DType::F16 => {
+            if strided {
+                "reduce_f16_strided"
+            } else {
+                "reduce_f16"
+            }
+        }
+        DType::BF16 => {
+            if strided {
+                "reduce_bf16_strided"
+            } else {
+                "reduce_bf16"
+            }
+        }
         _ => "reduce",
     };
 
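The new strided flag reruns the same reduction benchmark over a non-contiguous input: transposing dimensions 0 and 2 leaves the data in place but permutes the layout, which pushes the Metal backend onto the *_strided kernels. A minimal stand-alone sketch of that distinction, assuming candle_core's public Tensor::rand, transpose and is_contiguous APIs (the shape and dtype simply mirror the benchmark; this is not part of the commit):

// Sketch only: shows why the `strided` branch yields a non-contiguous input.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let (b, m, k) = (1usize, 1024usize, 1024usize);

    // Contiguous input, as used when `strided == false`.
    let contiguous = Tensor::rand(-1.0f32, 1.0, (b, m, k), &device)?;
    // Same data, transposed as in the `strided == true` branch: the layout is permuted.
    let strided = Tensor::rand(-1.0f32, 1.0, (b, m, k), &device)?.transpose(0, 2)?;

    assert!(contiguous.is_contiguous());
    assert!(!strided.is_contiguous());
    Ok(())
}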
@@ -140,20 +179,46 @@ fn run_arg_reduce<T: candle_core::FloatDType>(
     c: &mut Criterion,
     device: &Device,
     (lo, up): (T, T),
+    strided: bool,
 ) {
     let b = 1;
     let m = 1024;
     let k = 1024;
 
-    let a = Tensor::rand(lo, up, (b, m, k), &device).unwrap();
+    let a = if strided {
+        Tensor::rand(lo, up, (b, m, k), &device)
+            .unwrap()
+            .transpose(0, 2)
+            .unwrap()
+    } else {
+        Tensor::rand(lo, up, (b, m, k), &device).unwrap()
+    };
 
-    let flops = b * m * k * T::DTYPE.size_in_bytes();
+    let flops = b * m * k * (DType::U32.size_in_bytes() + T::DTYPE.size_in_bytes());
 
     let name = match T::DTYPE {
-        DType::F32 => "arg_reduce_f32",
-        DType::F16 => "arg_reduce_f16",
-        DType::BF16 => "arg_reduce_bf16",
-        _ => "reduce",
+        DType::F32 => {
+            if strided {
+                "arg_reduce_f32_strided"
+            } else {
+                "arg_reduce_f32"
+            }
+        }
+        DType::F16 => {
+            if strided {
+                "arg_reduce_f16_strided"
+            } else {
+                "arg_reduce_f16"
+            }
+        }
+        DType::BF16 => {
+            if strided {
+                "arg_reduce_bf16_strided"
+            } else {
+                "arg_reduce_bf16"
+            }
+        }
+        _ => "unknown",
     };
 
     let mut group = c.benchmark_group(device.bench_name(name));
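The arg-reduce throughput figure now charges each element with the size of a u32 on top of the input dtype rather than counting input bytes alone; reading the extra term as index traffic is an assumption, it is not stated in the diff. A quick arithmetic sketch of the resulting byte count for the f32 case, using the benchmark's own b, m, k:

// Illustrative arithmetic only; the constants mirror the benchmark above.
fn main() {
    let (b, m, k) = (1usize, 1024usize, 1024usize);
    let f32_bytes = 4usize; // T::DTYPE.size_in_bytes() for f32
    let u32_bytes = 4usize; // DType::U32.size_in_bytes()
    let bytes = b * m * k * (u32_bytes + f32_bytes);
    assert_eq!(bytes, 8 * 1024 * 1024); // 8 MiB reported to Criterion per iteration
}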
File diff suppressed because it is too large
candle-metal-kernels/src/reduce_old.metal (new file, 346 lines)
@@ -0,0 +1,346 @@
#include <metal_stdlib>
using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

constant int THREADGROUP_SIZE = 2048;


#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device uint *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = MAXVALUE; \
    shared_indices[tid] = 0xFFFFFFFF; \
    bool notset = true; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        if (notset || src[strided_i] < shared_memory[tid]) { \
            shared_memory[tid] = src[strided_i]; \
            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
            shared_indices[tid] = idx % dims[num_dims - 1]; \
            notset = false; \
        } \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
            shared_indices[tid] = shared_indices[tid + s]; \
            shared_memory[tid] = shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    if (tid == 0){ \
        dst[dst_id] = shared_indices[0]; \
    } \
} \


#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device uint *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = MINVALUE; \
    shared_indices[tid] = 0xFFFFFFFF; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    bool notset = true; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        if (notset || shared_memory[tid] < src[strided_i]) { \
            shared_memory[tid] = src[strided_i]; \
            shared_indices[tid] = idx % dims[num_dims - 1]; \
            notset = false; \
        } \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
            shared_indices[tid] = shared_indices[tid + s]; \
            shared_memory[tid] = shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    if (tid == 0){ \
        dst[dst_id] = shared_indices[0]; \
    } \
} \

#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = START; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
        */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        T x = shared_memory[tid]; \
        T y = src[strided_i]; \
        shared_memory[tid] = FN; \
        idx += block_dim; \
    } \
    \
    threadgroup_barrier(mem_flags::mem_none); \
    \
    /* \
    // reduction in shared memory \
    */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            T x = shared_memory[tid]; \
            T y = shared_memory[tid + s]; \
            shared_memory[tid] = FN; \
        } \
        threadgroup_barrier(mem_flags::mem_none); \
    } \
    \
    dst[dst_id] = shared_memory[0]; \
} \


#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
    \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    threadgroup float shared_memory[THREADGROUP_SIZE]; \
    shared_memory[tid] = -INFINITY; \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
    size_t idx = start_idx + tid; \
    \
    \
    float tmp = -INFINITY; \
    while (idx < stop_idx) { \
        tmp = MAX(tmp, float(src[idx])); \
        idx += block_dim; \
    } \
    shared_memory[tid] = tmp; \
    \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    /* wait for shared_memory[0] to be filled */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    float _max = shared_memory[0]; \
    \
    /* prevent tid=0 from overwriting _max before other threads have written it */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    shared_memory[tid] = 0; \
    \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        const float val = exp(float(src[idx]) - _max); \
        dst[idx] = T(val); \
        shared_memory[tid] += val; \
        idx += block_dim; \
    } \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            shared_memory[tid] += shared_memory[tid + s]; \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    const T inv_acc = T(1.0/shared_memory[0]); \
    idx = start_idx + tid; \
    while (idx < stop_idx) { \
        dst[idx] *= inv_acc; \
        idx += block_dim; \
    } \
} \

REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8_strided, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
ARGMAX(fast_argmax_u8_strided, uint8_t, 0)


REDUCE(x + y, fast_sum_f32, float, 0)
REDUCE(x + y, fast_sum_u32, uint, 0)
REDUCE(x + y, fast_sum_f16, half, 0)
REDUCE(x + y, fast_sum_u8, uint8_t, 0)
REDUCE(x * y, fast_mul_f32, float, 1)
REDUCE(x * y, fast_mul_u32, uint, 1)
REDUCE(x * y, fast_mul_f16, half, 1)
REDUCE(MAX(x, y), fast_max_f32, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32, uint, 0)
REDUCE(MAX(x, y), fast_max_f16, half, -HUGE_VALH)
REDUCE(MAX(x, y), fast_max_u8, uint8_t, 0)
REDUCE(MIN(x, y), fast_min_f32, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16, half, HUGE_VALH)
REDUCE(MIN(x, y), fast_min_u8, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32, float, HUGE_VALF)
ARGMIN(fast_argmin_f16, half, HUGE_VALH)
ARGMIN(fast_argmin_u32, uint, 0xFFFFFFFF)
ARGMIN(fast_argmin_u8, uint8_t, 0xFF)
ARGMAX(fast_argmax_f32, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32, uint, 0)
ARGMAX(fast_argmax_u8, uint8_t, 0)

SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)

#if __METAL_VERSION__ >= 220
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)


REDUCE(x + y, fast_sum_i64, int64_t, 0)
REDUCE(MIN(x, y), fast_min_i64, int64_t, INT_MAX)
REDUCE(MAX(x, y), fast_max_i64, int64_t, INT_MIN)
ARGMIN(fast_argmin_i64, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64, int64_t, INT_MIN)
#endif

#if defined(__HAVE_BFLOAT__)
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)

REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)

SOFTMAX(softmax_bf16, bfloat)
#endif
@@ -622,7 +622,7 @@ fn cos_f16() {
     assert_eq!(approx_f16(expected, 2), vec![0.54, -0.42, -0.99]);
 }
 
-fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
+fn run_reduce<T, U: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<U> {
     let device = device();
     let kernels = Kernels::new();
     let command_queue = device.new_command_queue();
@@ -630,10 +630,10 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
     let input = new_buffer(&device, v);
 
     let options = MTLResourceOptions::StorageModeManaged;
-    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
+    let output = device.new_buffer((out_length * core::mem::size_of::<U>()) as u64, options);
     let dims = vec![v.len()];
     let strides = vec![1];
-    call_reduce_strided(
+    match call_reduce_strided(
         &device,
         command_buffer,
         &kernels,
@@ -644,8 +644,13 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
         &input,
         0,
         &output,
-    )
-    .unwrap();
+    ) {
+        Ok(_) => {}
+        Err(e) => {
+            println!("Error: {}", e);
+            panic!();
+        }
+    }
     command_buffer.commit();
     command_buffer.wait_until_completed();
 
@@ -677,22 +682,114 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
     read_to_vec(&output, v.len())
 }
 
-#[test]
-fn reduce_sum() {
-    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
-    let out_length = 1;
+const fn create_array<const N: usize>() -> [f32; N] {
+    let mut array: [f32; N] = [0.0; N];
+    let mut i = 1;
+    while i <= N {
+        array[i - 1] = i as f32;
+        i += 1;
+    }
+    array
+}
 
-    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
-    assert_eq!(approx(results, 4), vec![21.0]);
+const fn correct_sum<const N: usize, const D: usize>() -> [f32; D] {
+    let mut sum = 0;
+    let mut results: [f32; D] = [0.0; D];
+    let mut i = 1;
+    let mut j = 1;
+    while i <= N {
+        sum += i;
+        i += 1;
+        if i > j * N / D {
+            results[j - 1] = sum as f32;
+            j += 1;
+            sum = 0;
+        }
+    }
+    results
+}
+
+fn correct_argmax<const N: usize, const D: usize>(arr: [f32; N]) -> [u32; D] {
+    let mut max = 0.0;
+    let mut max_index: u32 = 0;
+    let mut results: [u32; D] = [0; D];
+    let mut i = 0;
+    let mut j = 1;
+    while i <= N {
+        if i >= (j * N / D) {
+            results[j - 1] = max_index;
+            max = 0.0;
+            max_index = 0;
+            j += 1;
+        }
+        if i == N {
+            break;
+        }
+        if arr[i] > max {
+            max = arr[i];
+            max_index = i as u32;
+        }
+        i += 1;
+    }
+    results
+}
+
+fn reduce_sum_case<const N: usize, const D: usize>() {
+    let v = create_array::<N>();
+    let results = run_reduce(&v, D, "fast_sum_f32_strided");
+    assert_eq!(approx(results, 4), correct_sum::<N, D>());
+}
+
+fn reduce_argmax_case<const N: usize, const D: usize>() {
+    let v = create_array::<N>();
+    let results: Vec<u32> = run_reduce(&v, D, "fast_argmax_f32_strided");
+    assert_eq!(results, correct_argmax::<N, D>(v));
 }
 
 #[test]
-fn reduce_sum2() {
-    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
-    let out_length = 2;
+fn reduce_sum() {
+    reduce_sum_case::<6, 1>();
+    reduce_sum_case::<10, 1>();
+    reduce_sum_case::<64, 1>();
+    reduce_sum_case::<128, 1>();
+    reduce_sum_case::<256, 1>();
+    reduce_sum_case::<512, 1>();
+    reduce_sum_case::<1024, 1>();
+    reduce_sum_case::<2048, 1>();
+    reduce_sum_case::<4096, 1>();
 
-    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
-    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
+    reduce_sum_case::<6, 2>();
+    reduce_sum_case::<10, 2>();
+    reduce_sum_case::<64, 2>();
+    reduce_sum_case::<128, 2>();
+    reduce_sum_case::<256, 2>();
+    reduce_sum_case::<512, 2>();
+    reduce_sum_case::<1024, 2>();
+    reduce_sum_case::<2048, 2>();
+    reduce_sum_case::<4096, 2>();
+}
+
+#[test]
+fn reduce_argmax() {
+    reduce_argmax_case::<6, 1>();
+    reduce_argmax_case::<10, 1>();
+    reduce_argmax_case::<64, 1>();
+    reduce_argmax_case::<128, 1>();
+    reduce_argmax_case::<256, 1>();
+    reduce_argmax_case::<512, 1>();
+    reduce_argmax_case::<1024, 1>();
+    reduce_argmax_case::<2048, 1>();
+    reduce_argmax_case::<4096, 1>();
+
+    reduce_argmax_case::<6, 2>();
+    reduce_argmax_case::<10, 2>();
+    reduce_argmax_case::<64, 2>();
+    reduce_argmax_case::<128, 2>();
+    reduce_argmax_case::<256, 2>();
+    reduce_argmax_case::<512, 2>();
+    reduce_argmax_case::<1024, 2>();
+    reduce_argmax_case::<2048, 2>();
+    reduce_argmax_case::<4096, 2>();
 }
 
 #[test]
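The new const helpers replace the old hard-coded expectations with a formula: the input 1..=N is split into D equal blocks, and each block is summed (or arg-maxed) on its own. A quick cross-check of the expected sums, written independently of the helpers so it does not just restate them (chunked_sums is illustrative and not part of the commit):

fn chunked_sums(n: usize, d: usize) -> Vec<f32> {
    // Sum each of the d equal-sized blocks of 1..=n.
    let block = n / d;
    (0..d)
        .map(|j| ((j * block + 1)..=((j + 1) * block)).sum::<usize>() as f32)
        .collect()
}

fn main() {
    // Matches the old reduce_sum / reduce_sum2 expectations and the new helpers.
    assert_eq!(chunked_sums(6, 1), vec![21.0]);
    assert_eq!(chunked_sums(6, 2), vec![6.0, 15.0]);
    assert_eq!(chunked_sums(10, 2), vec![15.0, 40.0]);
}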