Improve reduce perf and add contiguous impl

Ivar Flakstad
2024-01-21 17:32:21 +01:00
parent 88945f2c22
commit d5902840e0
7 changed files with 409 additions and 96 deletions

View File

@ -1,4 +1,4 @@
mod benchmarks;
use criterion::criterion_main;
-criterion_main!(benchmarks::matmul::benches);
+criterion_main!(benchmarks::reduce::benches);

View File

@ -1,4 +1,5 @@
pub(crate) mod matmul;
+pub(crate) mod reduce;
use candle_core::{Device, Result};

View File

@ -0,0 +1,36 @@
use candle_core::{DType, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;

use crate::benchmarks::{bench_name, device, BenchDevice};

fn run(a: &Tensor) {
    a.sum(2).unwrap();
}

fn criterion_benchmark(c: &mut Criterion) {
    let b = 1;
    let m = 2048;
    let k = 2048;

    let device = device().unwrap();
    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();

    let flops = b * m * k * DType::F32.size_in_bytes();

    let mut group = c.benchmark_group(bench_name("reduce"));
    group.throughput(Throughput::Bytes(flops as u64));
    group.bench_function("iter", move |b| {
        b.iter_custom(|iters| {
            let start = Instant::now();
            for _i in 0..iters {
                run(black_box(&a));
            }
            device.sync().unwrap();
            start.elapsed()
        })
    });
    group.finish();
}

criterion_group!(benches, criterion_benchmark);
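Aside (not part of the commit): Criterion's `Throughput::Bytes` turns the per-iteration byte count above into a bandwidth figure, so the reported number can be read as GB/s over the 1×2048×2048 f32 input. A minimal sketch of that conversion, with illustrative names (`bytes_per_iter`, `gbps`) that are not from the benchmark itself:

```rust
// Illustrative only: how a bytes-per-iteration figure maps to GB/s.
fn gbps(bytes_per_iter: usize, iters: u64, elapsed: std::time::Duration) -> f64 {
    (bytes_per_iter as f64 * iters as f64) / elapsed.as_secs_f64() / 1e9
}

fn main() {
    // 2048 * 2048 f32 elements at 4 bytes each (b = 1), matching the benchmark above.
    let bytes_per_iter = 2048 * 2048 * 4;
    println!("~{:.1} GB/s", gbps(bytes_per_iter, 100, std::time::Duration::from_millis(50)));
}
```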

View File

@ -491,6 +491,7 @@ impl BackendStorage for MetalStorage {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
let device = self.device.clone();
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
// Source dims and strides with the sum dims at the end.
@ -504,13 +505,72 @@ impl BackendStorage for MetalStorage {
stride.push(src_stride[dim_idx]);
}
}
if layout.is_contiguous() {
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
//(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
//(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
//(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
//(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
//(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
//(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
//(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
//(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
//(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
//(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
//(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
//(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
//(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
_ => ("fall back to strided impl", false, false)
};
if name != "fall back to strided impl" {
if check_empty && layout.shape().elem_count() == 0 {
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
}
let buffer = device.new_buffer(1, self.dtype, "reduce")?;
let command_buffer = self.device.command_buffer()?;
candle_metal_kernels::call_reduce_contiguous(
&device.device,
&command_buffer,
&device.kernels,
name,
layout.shape().elem_count(),
dst_el,
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
return Ok(Self::new(buffer, device, self.dtype));
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// The reduction loop requires the shared array to be properly initialized and for
// this we want the number of threads to be a power of two.
let (name, check_empty, return_index) = match (op, self.dtype) {
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
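Aside (not part of the commit): the new branch above only takes the fast path when the layout is contiguous and a `fast_*` kernel exists for the (op, dtype) pair; the arg-reductions are still commented out, so they and every unsupported combination fall through to the existing strided kernels. A simplified sketch of that selection logic, with illustrative enums standing in for candle's `ReduceOp`/`DType`:

```rust
// Illustrative sketch of the contiguous-vs-strided kernel selection.
#[derive(Clone, Copy)]
enum Op { Sum, Min, Max }
#[derive(Clone, Copy)]
enum DType { F32, F16 }

fn contiguous_kernel(op: Op, dtype: DType) -> Option<&'static str> {
    match (op, dtype) {
        (Op::Sum, DType::F32) => Some("fast_sum_f32"),
        (Op::Min, DType::F32) => Some("fast_min_f32"),
        (Op::Max, DType::F32) => Some("fast_max_f32"),
        (Op::Sum, DType::F16) => Some("fast_sum_f16"),
        _ => None, // e.g. argmin/argmax: no contiguous fast path yet
    }
}

fn kernel_name(op: Op, dtype: DType, is_contiguous: bool) -> &'static str {
    if is_contiguous {
        if let Some(name) = contiguous_kernel(op, dtype) {
            return name;
        }
    }
    // Everything else goes through the strided kernels (only f32 sum shown here).
    "fast_sum_f32_strided"
}

fn main() {
    assert_eq!(kernel_name(Op::Sum, DType::F32, true), "fast_sum_f32");
    assert_eq!(kernel_name(Op::Min, DType::F32, true), "fast_min_f32");
    assert_eq!(kernel_name(Op::Max, DType::F32, true), "fast_max_f32");
    assert_eq!(kernel_name(Op::Sum, DType::F16, true), "fast_sum_f16");
    assert_eq!(kernel_name(Op::Sum, DType::F32, false), "fast_sum_f32_strided");
}
```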

View File

@ -568,7 +568,6 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
-encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -597,7 +596,6 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@ -619,7 +617,6 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
-encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@ -630,7 +627,8 @@ pub fn call_reduce_strided(
strides,
elements_to_sum,
(input, input_offset),
-output
+output,
+out_length
)
);
@ -655,7 +653,6 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
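Aside (not part of the commit): both `call_reduce_contiguous` and `call_reduce_strided` assign one block per output element (`elements_to_sum = length / out_length`) and, as the backend comment about power-of-two thread counts suggests, launch a threadgroup whose width is a power of two bounded by the device limit. A small illustrative helper for such a sizing rule; the function name and exact rule are assumptions, not the library's:

```rust
// Illustrative: pick a power-of-two threadgroup width no larger than either
// the work assigned to the block or the pipeline's thread limit.
fn threadgroup_width(elements_to_sum: usize, max_threads: usize) -> usize {
    let cap = max_threads.min(elements_to_sum).max(1);
    if cap.is_power_of_two() {
        cap
    } else {
        cap.next_power_of_two() / 2
    }
}

fn main() {
    assert_eq!(threadgroup_width(4096, 1024), 1024);
    assert_eq!(threadgroup_width(300, 1024), 256);
    assert_eq!(threadgroup_width(8, 1024), 8);
}
```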

View File

@ -1,16 +1,15 @@
#include <metal_stdlib>
+#include <metal_limits>
using namespace metal;
-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
-constant size_t &num_dims,
-constant size_t *dims,
-constant size_t *strides
+constant const size_t &num_dims,
+constant const size_t *dims,
+constant const size_t *strides
) {
uint strided_i = 0;
+#pragma clang loop unroll(full)
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@ -19,8 +18,270 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
-constant int THREADGROUP_SIZE = 2048;
+#define impl_reduction_op(name, op, init_val) \
template<typename T> \
struct name { \
\
static constexpr constant T init = init_val; \
\
METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(device const T &a, device const T &b) const { \
return op; \
} \
\
METAL_FUNC T operator()(T a, T b) { \
return op; \
} \
} \
impl_reduction_op(Sum, a + b, 0);
impl_reduction_op(Mul, a * b, 1);
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::min());
#undef impl_reduction_op
static constant constexpr int THREADGROUP_SIZE = 2048;
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
threadgroup T shared[BLOCKSIZE],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ReductionOp op;
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
shared[tid] = op(shared[tid], src[strided_i]);
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
}
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void load_from_global(
constant size_t &num_dims,
constant size_t *dims,
constant size_t &el_to_sum_per_block,
device const T *src,
threadgroup T shared[BLOCKSIZE],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ReductionOp op;
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = start_idx + el_to_sum_per_block;
size_t idx = start_idx + tid;
while (idx < stop_idx) {
shared[tid] = op(shared[tid], src[idx]);
idx += block_dim;
}
threadgroup_barrier(mem_flags::mem_none);
}
template<typename T, typename ReductionOp, uint BLOCKSIZE>
METAL_FUNC void threadgroup_reduce(
threadgroup T shared[BLOCKSIZE],
uint tid [[thread_index_in_threadgroup]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ReductionOp op;
if (BLOCKSIZE >= 64) {
if (block_dim >= 64) {
shared[tid] = op(shared[tid], shared[tid + 32]);
}
}
if (BLOCKSIZE >= 32) {
if (block_dim >= 32) {
shared[tid] = op(shared[tid], shared[tid + 16]);
}
}
if (BLOCKSIZE >= 16) {
if (block_dim >= 16) {
shared[tid] = op(shared[tid], shared[tid + 8]);
}
}
if (BLOCKSIZE >= 8) {
if (block_dim >= 8) {
shared[tid] = op(shared[tid], shared[tid + 4]);
}
}
if (BLOCKSIZE >= 4) {
if (block_dim >= 4) {
shared[tid] = op(shared[tid], shared[tid + 2]);
}
}
if (BLOCKSIZE >= 2) {
if (block_dim >= 2) {
shared[tid] = op(shared[tid], shared[tid + 1]);
}
}
}
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
template<
typename T,
typename ReductionOp,
uint BLOCKSIZE,
bool STRIDED
>
METAL_FUNC void block_reduce(
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides,
constant size_t &el_to_sum_per_block,
device const T *src,
device T *dst,
constant uint &num_elements,
threadgroup T shared[BLOCKSIZE],
uint id [[ thread_position_in_grid ]],
uint tid [[ thread_index_in_threadgroup ]],
uint dst_id [[ threadgroup_position_in_grid ]],
uint block_dim [[ threads_per_threadgroup ]]
) {
ReductionOp op;
shared[tid] = ReductionOp::init;
if (STRIDED) {
load_from_global<T, ReductionOp, BLOCKSIZE>(
num_dims,
dims,
strides,
el_to_sum_per_block,
src,
shared,
tid,
dst_id,
block_dim
);
} else {
load_from_global<T, ReductionOp, BLOCKSIZE>(
num_dims,
dims,
el_to_sum_per_block,
src,
shared,
tid,
dst_id,
block_dim
);
}
if (BLOCKSIZE >= 1024) {
if (tid < 512 && block_dim >= 1024) {
shared[tid] = op(shared[tid], shared[tid + 512]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 512) {
if (tid < 256 && block_dim >= 512) {
shared[tid] = op(shared[tid], shared[tid + 256]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 256) {
if (tid < 128 && block_dim >= 256) {
shared[tid] = op(shared[tid], shared[tid + 128]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (BLOCKSIZE >= 128) {
if (tid < 64 && block_dim >= 128) {
shared[tid] = op(shared[tid], shared[tid + 64]);
threadgroup_barrier(mem_flags::mem_none);
}
}
if (tid < 32) {
threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
threadgroup_barrier(mem_flags::mem_none);
}
if (tid == 0) {
dst[dst_id] = shared[tid];
}
}
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
static constant constexpr int BLOCKSIZE = 2048;
#define REDUCE(OP, NAME, T) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
constant uint &num_elements, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared[BLOCKSIZE]; \
block_reduce<T, OP<T>, BLOCKSIZE, false>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
num_elements, \
shared, \
id, \
tid, \
dst_id, \
block_dim); \
} \
kernel void NAME##_strided( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
constant uint &num_elements, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup T shared[BLOCKSIZE]; \
block_reduce<T, OP<T>, BLOCKSIZE, false>( \
num_dims, \
dims, \
strides, \
el_to_sum_per_block, \
src, \
dst, \
num_elements, \
shared, \
id, \
tid, \
dst_id, \
block_dim); \
} \
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
@ -140,59 +401,6 @@ kernel void NAME( \
} \
} \
-#define REDUCE(FN, NAME, T, START) \
-kernel void NAME( \
-constant size_t &num_dims, \
-constant size_t *dims, \
-constant size_t *strides, \
-constant size_t &el_to_sum_per_block, \
-device const T *src, \
-device T *dst, \
-uint id [[ thread_position_in_grid ]], \
-uint tid [[ thread_index_in_threadgroup ]], \
-uint dst_id [[ threadgroup_position_in_grid ]], \
-uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-\
-threadgroup T shared_memory[THREADGROUP_SIZE]; \
-\
-shared_memory[tid] = START; \
-/* \
-// Elements summed in this block range from dst_id * el_to_sum_per_block \
-// to (dst_id + 1) * el_to_sum_per_block. \
-*/ \
-size_t start_idx = dst_id * el_to_sum_per_block; \
-size_t stop_idx = start_idx + el_to_sum_per_block; \
-size_t idx = start_idx + tid; \
-while (idx < stop_idx) { \
-/* \
-// TODO: Fast version for the contiguous case. \
-*/ \
-size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-T x = shared_memory[tid]; \
-T y = src[strided_i]; \
-shared_memory[tid] = FN; \
-idx += block_dim; \
-} \
-\
-threadgroup_barrier(mem_flags::mem_none); \
-\
-/* \
-// reduction in shared memory \
-*/ \
-for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-if (tid < s) { \
-T x = shared_memory[tid]; \
-T y = shared_memory[tid + s]; \
-shared_memory[tid] = FN; \
-} \
-threadgroup_barrier(mem_flags::mem_none); \
-} \
-\
-dst[dst_id] = shared_memory[0]; \
-} \
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@ -258,23 +466,24 @@ kernel void NAME(
dst[idx] *= inv_acc; \
idx += block_dim; \
} \
-} \
+}
REDUCE(Sum, fast_sum_f32, float)
REDUCE(Sum, fast_sum_u32, uint)
REDUCE(Sum, fast_sum_f16, half)
REDUCE(Sum, fast_sum_u8, uint8_t)
REDUCE(Mul, fast_mul_f32, float)
REDUCE(Mul, fast_mul_u32, uint)
REDUCE(Mul, fast_mul_f16, half)
REDUCE(Max, fast_max_f32, float)
REDUCE(Max, fast_max_u32, uint)
REDUCE(Max, fast_max_f16, half)
REDUCE(Max, fast_max_u8, uint8_t)
REDUCE(Min, fast_min_f32, float)
REDUCE(Min, fast_min_u32, uint)
REDUCE(Min, fast_min_f16, half)
REDUCE(Min, fast_min_u8, uint8_t)
-REDUCE(x + y, fast_sum_f32_strided, float, 0)
-REDUCE(x + y, fast_sum_u32_strided, uint, 0)
-REDUCE(x + y, fast_sum_f16_strided, half, 0)
-REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
-REDUCE(x * y, fast_mul_f32_strided, float, 1)
-REDUCE(x * y, fast_mul_u32_strided, uint, 1)
-REDUCE(x * y, fast_mul_f16_strided, half, 1)
-REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
-REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
-REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
-REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
-REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
-REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
-REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
-REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
@ -288,18 +497,21 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 220
-REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
-REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
-REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
+REDUCE(Sum, fast_sum_i64, int64_t)
+REDUCE(Mul, fast_mul_i64, int64_t)
+REDUCE(Min, fast_min_i64, int64_t)
+REDUCE(Max, fast_max_i64, int64_t)
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
#endif
#if __METAL_VERSION__ >= 310
-REDUCE(x + y, fast_sum_bf16, bfloat, 0)
-REDUCE(x * y, fast_mul_bf16, bfloat, 1)
-REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
-REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
+REDUCE(Sum, fast_sum_bf16, bfloat)
+REDUCE(Mul, fast_mul_bf16, bfloat)
+REDUCE(Max, fast_max_bf16, bfloat)
+REDUCE(Min, fast_min_bf16, bfloat)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
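Aside (not part of the kernel source): the `block_reduce` path above follows the classic blocked tree reduction from the cited Mark Harris write-up — each threadgroup accumulates a strided slice of the input into shared memory, then repeatedly halves the number of active lanes until one value per block remains. A sequential Rust reference model of that algorithm, illustrative only (function and variable names are assumptions):

```rust
// CPU reference model of the block-wise tree reduction done by the kernels
// above; sequential, one simulated "thread" at a time, sums only.
fn block_reduce_sum(src: &[f32], el_per_block: usize, block_size: usize) -> Vec<f32> {
    assert!(block_size.is_power_of_two());
    let num_blocks = (src.len() + el_per_block - 1) / el_per_block;
    let mut dst = vec![0.0f32; num_blocks];
    for (dst_id, out) in dst.iter_mut().enumerate() {
        // "Shared memory": one accumulator slot per thread in the group.
        let mut shared = vec![0.0f32; block_size];
        let start = dst_id * el_per_block;
        let stop = (start + el_per_block).min(src.len());
        for tid in 0..block_size {
            let mut idx = start + tid;
            while idx < stop {
                shared[tid] += src[idx];
                idx += block_size;
            }
        }
        // Tree reduction: halve the active width until one value remains.
        let mut s = block_size / 2;
        while s > 0 {
            for tid in 0..s {
                shared[tid] += shared[tid + s];
            }
            s /= 2;
        }
        *out = shared[0];
    }
    dst
}

fn main() {
    let v: Vec<f32> = (0..1024).map(|x| x as f32).collect();
    let out = block_reduce_sum(&v, 256, 32);
    assert_eq!(out.len(), 4);
    assert_eq!(out.iter().sum::<f32>(), v.iter().sum::<f32>());
}
```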

View File

@ -509,7 +509,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let dims = vec![v.len()];
let strides = vec![1];
-call_reduce_strided(
+let result = call_reduce_strided(
&device,
command_buffer,
&kernels,
@ -520,10 +520,17 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
&input,
0,
&output,
-)
-.unwrap();
-command_buffer.commit();
-command_buffer.wait_until_completed();
+);
+match result {
+    Ok(_) => {
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+    }
+    Err(e) => {
+        println!("Error: {}", e);
+        panic!("damn!");
+    },
+}
read_to_vec(&output, out_length)
}