mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Improve reduce perf and add contiguous impl
This commit is contained in:
@ -1,4 +1,4 @@
|
|||||||
mod benchmarks;
|
mod benchmarks;
|
||||||
|
|
||||||
use criterion::criterion_main;
|
use criterion::criterion_main;
|
||||||
criterion_main!(benchmarks::matmul::benches);
|
criterion_main!(benchmarks::reduce::benches);
|
@ -1,4 +1,5 @@
|
|||||||
pub(crate) mod matmul;
|
pub(crate) mod matmul;
|
||||||
|
pub(crate) mod reduce;
|
||||||
|
|
||||||
use candle_core::{Device, Result};
|
use candle_core::{Device, Result};
|
||||||
|
|
||||||
|
36
candle-core/benches/benchmarks/reduce.rs
Normal file
36
candle-core/benches/benchmarks/reduce.rs
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
use candle_core::{DType, Tensor};
|
||||||
|
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||||
|
use std::time::Instant;
|
||||||
|
use crate::benchmarks::{bench_name, device, BenchDevice};
|
||||||
|
|
||||||
|
fn run(a: &Tensor) {
|
||||||
|
a.sum(2).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn criterion_benchmark(c: &mut Criterion) {
|
||||||
|
let b = 1;
|
||||||
|
let m = 2048;
|
||||||
|
let k = 2048;
|
||||||
|
|
||||||
|
let device = device().unwrap();
|
||||||
|
|
||||||
|
let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
|
||||||
|
|
||||||
|
let flops = b * m * k * DType::F32.size_in_bytes();
|
||||||
|
|
||||||
|
let mut group = c.benchmark_group(bench_name("reduce"));
|
||||||
|
group.throughput(Throughput::Bytes(flops as u64));
|
||||||
|
group.bench_function("iter", move |b| {
|
||||||
|
b.iter_custom(|iters| {
|
||||||
|
let start = Instant::now();
|
||||||
|
for _i in 0..iters {
|
||||||
|
run(black_box(&a));
|
||||||
|
}
|
||||||
|
device.sync().unwrap();
|
||||||
|
start.elapsed()
|
||||||
|
})
|
||||||
|
});
|
||||||
|
group.finish();
|
||||||
|
}
|
||||||
|
|
||||||
|
criterion_group!(benches, criterion_benchmark);
|
@ -491,6 +491,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
let device = self.device.clone();
|
let device = self.device.clone();
|
||||||
|
|
||||||
let src_stride = layout.stride();
|
let src_stride = layout.stride();
|
||||||
let src_dims = layout.shape().dims();
|
let src_dims = layout.shape().dims();
|
||||||
// Source dims and strides with the sum dims at the end.
|
// Source dims and strides with the sum dims at the end.
|
||||||
@ -504,13 +505,72 @@ impl BackendStorage for MetalStorage {
|
|||||||
stride.push(src_stride[dim_idx]);
|
stride.push(src_stride[dim_idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if layout.is_contiguous() {
|
||||||
|
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||||
|
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||||
|
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||||
|
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||||
|
//(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||||
|
//(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||||
|
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
|
||||||
|
(ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
|
||||||
|
(ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
|
||||||
|
//(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
|
||||||
|
//(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
|
||||||
|
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
|
||||||
|
(ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
|
||||||
|
(ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
|
||||||
|
//(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
|
||||||
|
//(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
|
||||||
|
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
|
||||||
|
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
|
||||||
|
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
|
||||||
|
//(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
|
||||||
|
//(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
|
||||||
|
(ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
|
||||||
|
(ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
|
||||||
|
(ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
|
||||||
|
//(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
|
||||||
|
//(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
|
||||||
|
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
|
||||||
|
(ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
|
||||||
|
(ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
|
||||||
|
//(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
|
||||||
|
//(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
|
||||||
|
//(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||||
|
_ => ("fall back to strided impl", false, false)
|
||||||
|
};
|
||||||
|
|
||||||
|
if name != "fall back to strided impl" {
|
||||||
|
if check_empty && layout.shape().elem_count() == 0 {
|
||||||
|
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
let buffer = device.new_buffer(1, self.dtype, "reduce")?;
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
candle_metal_kernels::call_reduce_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
name,
|
||||||
|
layout.shape().elem_count(),
|
||||||
|
dst_el,
|
||||||
|
&self.buffer,
|
||||||
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
return Ok(Self::new(buffer, device, self.dtype));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for &dim_idx in sum_dims.iter() {
|
for &dim_idx in sum_dims.iter() {
|
||||||
dims.push(src_dims[dim_idx]);
|
dims.push(src_dims[dim_idx]);
|
||||||
stride.push(src_stride[dim_idx]);
|
stride.push(src_stride[dim_idx]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The reduction loop requires the shared array to be properly initialized and for
|
|
||||||
// this we want the number of threads to be a power of two.
|
|
||||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||||
|
@ -568,7 +568,6 @@ pub fn call_reduce_contiguous(
|
|||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -597,7 +596,6 @@ pub fn call_reduce_contiguous(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -619,7 +617,6 @@ pub fn call_reduce_strided(
|
|||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.wait_for_fence(&kernels.fence);
|
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
@ -630,7 +627,8 @@ pub fn call_reduce_strided(
|
|||||||
strides,
|
strides,
|
||||||
elements_to_sum,
|
elements_to_sum,
|
||||||
(input, input_offset),
|
(input, input_offset),
|
||||||
output
|
output,
|
||||||
|
out_length
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -655,7 +653,6 @@ pub fn call_reduce_strided(
|
|||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.update_fence(&kernels.fence);
|
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
#include <metal_limits>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
|
||||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant const size_t &num_dims,
|
||||||
constant size_t *dims,
|
constant const size_t *dims,
|
||||||
constant size_t *strides
|
constant const size_t *strides
|
||||||
) {
|
) {
|
||||||
uint strided_i = 0;
|
uint strided_i = 0;
|
||||||
|
#pragma clang loop unroll(full)
|
||||||
for (uint d = 0; d < num_dims; d++) {
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
uint dim_idx = num_dims - 1 - d;
|
uint dim_idx = num_dims - 1 - d;
|
||||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||||
@ -19,8 +18,270 @@ METAL_FUNC uint get_strided_index(
|
|||||||
return strided_i;
|
return strided_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
constant int THREADGROUP_SIZE = 2048;
|
#define impl_reduction_op(name, op, init_val) \
|
||||||
|
template<typename T> \
|
||||||
|
struct name { \
|
||||||
|
\
|
||||||
|
static constexpr constant T init = init_val; \
|
||||||
|
\
|
||||||
|
METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \
|
||||||
|
return op; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \
|
||||||
|
return op; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
METAL_FUNC T operator()(device const T &a, device const T &b) const { \
|
||||||
|
return op; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
METAL_FUNC T operator()(T a, T b) { \
|
||||||
|
return op; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
impl_reduction_op(Sum, a + b, 0);
|
||||||
|
impl_reduction_op(Mul, a * b, 1);
|
||||||
|
impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
|
||||||
|
impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::min());
|
||||||
|
#undef impl_reduction_op
|
||||||
|
|
||||||
|
static constant constexpr int THREADGROUP_SIZE = 2048;
|
||||||
|
|
||||||
|
template<typename T, typename ReductionOp, uint BLOCKSIZE>
|
||||||
|
METAL_FUNC void load_from_global(
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides,
|
||||||
|
constant size_t &el_to_sum_per_block,
|
||||||
|
device const T *src,
|
||||||
|
threadgroup T shared[BLOCKSIZE],
|
||||||
|
uint tid [[ thread_index_in_threadgroup ]],
|
||||||
|
uint dst_id [[ threadgroup_position_in_grid ]],
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]]
|
||||||
|
) {
|
||||||
|
ReductionOp op;
|
||||||
|
|
||||||
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||||
|
size_t idx = start_idx + tid;
|
||||||
|
while (idx < stop_idx) {
|
||||||
|
size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
|
||||||
|
shared[tid] = op(shared[tid], src[strided_i]);
|
||||||
|
idx += block_dim;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, typename ReductionOp, uint BLOCKSIZE>
|
||||||
|
METAL_FUNC void load_from_global(
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t &el_to_sum_per_block,
|
||||||
|
device const T *src,
|
||||||
|
threadgroup T shared[BLOCKSIZE],
|
||||||
|
uint tid [[ thread_index_in_threadgroup ]],
|
||||||
|
uint dst_id [[ threadgroup_position_in_grid ]],
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]]
|
||||||
|
) {
|
||||||
|
ReductionOp op;
|
||||||
|
|
||||||
|
size_t start_idx = dst_id * el_to_sum_per_block;
|
||||||
|
size_t stop_idx = start_idx + el_to_sum_per_block;
|
||||||
|
size_t idx = start_idx + tid;
|
||||||
|
while (idx < stop_idx) {
|
||||||
|
shared[tid] = op(shared[tid], src[idx]);
|
||||||
|
idx += block_dim;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, typename ReductionOp, uint BLOCKSIZE>
|
||||||
|
METAL_FUNC void threadgroup_reduce(
|
||||||
|
threadgroup T shared[BLOCKSIZE],
|
||||||
|
uint tid [[thread_index_in_threadgroup]],
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]]
|
||||||
|
) {
|
||||||
|
ReductionOp op;
|
||||||
|
if (BLOCKSIZE >= 64) {
|
||||||
|
if (block_dim >= 64) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 32]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 32) {
|
||||||
|
if (block_dim >= 32) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 16]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 16) {
|
||||||
|
if (block_dim >= 16) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 8]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 8) {
|
||||||
|
if (block_dim >= 8) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 4]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 4) {
|
||||||
|
if (block_dim >= 4) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 2]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 2) {
|
||||||
|
if (block_dim >= 2) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
|
||||||
|
template<
|
||||||
|
typename T,
|
||||||
|
typename ReductionOp,
|
||||||
|
uint BLOCKSIZE,
|
||||||
|
bool STRIDED
|
||||||
|
>
|
||||||
|
METAL_FUNC void block_reduce(
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides,
|
||||||
|
constant size_t &el_to_sum_per_block,
|
||||||
|
device const T *src,
|
||||||
|
device T *dst,
|
||||||
|
constant uint &num_elements,
|
||||||
|
threadgroup T shared[BLOCKSIZE],
|
||||||
|
uint id [[ thread_position_in_grid ]],
|
||||||
|
uint tid [[ thread_index_in_threadgroup ]],
|
||||||
|
uint dst_id [[ threadgroup_position_in_grid ]],
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]]
|
||||||
|
) {
|
||||||
|
ReductionOp op;
|
||||||
|
|
||||||
|
shared[tid] = ReductionOp::init;
|
||||||
|
|
||||||
|
if (STRIDED) {
|
||||||
|
load_from_global<T, ReductionOp, BLOCKSIZE>(
|
||||||
|
num_dims,
|
||||||
|
dims,
|
||||||
|
strides,
|
||||||
|
el_to_sum_per_block,
|
||||||
|
src,
|
||||||
|
shared,
|
||||||
|
tid,
|
||||||
|
dst_id,
|
||||||
|
block_dim
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
load_from_global<T, ReductionOp, BLOCKSIZE>(
|
||||||
|
num_dims,
|
||||||
|
dims,
|
||||||
|
el_to_sum_per_block,
|
||||||
|
src,
|
||||||
|
shared,
|
||||||
|
tid,
|
||||||
|
dst_id,
|
||||||
|
block_dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (BLOCKSIZE >= 1024) {
|
||||||
|
if (tid < 512 && block_dim >= 1024) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 512]);
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 512) {
|
||||||
|
if (tid < 256 && block_dim >= 512) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 256]);
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 256) {
|
||||||
|
if (tid < 128 && block_dim >= 256) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 128]);
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (BLOCKSIZE >= 128) {
|
||||||
|
if (tid < 64 && block_dim >= 128) {
|
||||||
|
shared[tid] = op(shared[tid], shared[tid + 64]);
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (tid < 32) {
|
||||||
|
threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
|
||||||
|
threadgroup_barrier(mem_flags::mem_none);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
dst[dst_id] = shared[tid];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
|
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||||
|
|
||||||
|
static constant constexpr int BLOCKSIZE = 2048;
|
||||||
|
|
||||||
|
#define REDUCE(OP, NAME, T) \
|
||||||
|
kernel void NAME( \
|
||||||
|
constant size_t &num_dims, \
|
||||||
|
constant size_t *dims, \
|
||||||
|
constant size_t *strides, \
|
||||||
|
constant size_t &el_to_sum_per_block, \
|
||||||
|
device const T *src, \
|
||||||
|
device T *dst, \
|
||||||
|
constant uint &num_elements, \
|
||||||
|
uint id [[ thread_position_in_grid ]], \
|
||||||
|
uint tid [[ thread_index_in_threadgroup ]], \
|
||||||
|
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]] \
|
||||||
|
) { \
|
||||||
|
threadgroup T shared[BLOCKSIZE]; \
|
||||||
|
block_reduce<T, OP<T>, BLOCKSIZE, false>( \
|
||||||
|
num_dims, \
|
||||||
|
dims, \
|
||||||
|
strides, \
|
||||||
|
el_to_sum_per_block, \
|
||||||
|
src, \
|
||||||
|
dst, \
|
||||||
|
num_elements, \
|
||||||
|
shared, \
|
||||||
|
id, \
|
||||||
|
tid, \
|
||||||
|
dst_id, \
|
||||||
|
block_dim); \
|
||||||
|
} \
|
||||||
|
kernel void NAME##_strided( \
|
||||||
|
constant size_t &num_dims, \
|
||||||
|
constant size_t *dims, \
|
||||||
|
constant size_t *strides, \
|
||||||
|
constant size_t &el_to_sum_per_block, \
|
||||||
|
device const T *src, \
|
||||||
|
device T *dst, \
|
||||||
|
constant uint &num_elements, \
|
||||||
|
uint id [[ thread_position_in_grid ]], \
|
||||||
|
uint tid [[ thread_index_in_threadgroup ]], \
|
||||||
|
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]] \
|
||||||
|
) { \
|
||||||
|
threadgroup T shared[BLOCKSIZE]; \
|
||||||
|
block_reduce<T, OP<T>, BLOCKSIZE, false>( \
|
||||||
|
num_dims, \
|
||||||
|
dims, \
|
||||||
|
strides, \
|
||||||
|
el_to_sum_per_block, \
|
||||||
|
src, \
|
||||||
|
dst, \
|
||||||
|
num_elements, \
|
||||||
|
shared, \
|
||||||
|
id, \
|
||||||
|
tid, \
|
||||||
|
dst_id, \
|
||||||
|
block_dim); \
|
||||||
|
} \
|
||||||
|
|
||||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
@ -140,59 +401,6 @@ kernel void NAME( \
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
#define REDUCE(FN, NAME, T, START) \
|
|
||||||
kernel void NAME( \
|
|
||||||
constant size_t &num_dims, \
|
|
||||||
constant size_t *dims, \
|
|
||||||
constant size_t *strides, \
|
|
||||||
constant size_t &el_to_sum_per_block, \
|
|
||||||
device const T *src, \
|
|
||||||
device T *dst, \
|
|
||||||
uint id [[ thread_position_in_grid ]], \
|
|
||||||
uint tid [[ thread_index_in_threadgroup ]], \
|
|
||||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
|
||||||
uint block_dim [[ threads_per_threadgroup ]] \
|
|
||||||
) { \
|
|
||||||
\
|
|
||||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
|
||||||
\
|
|
||||||
shared_memory[tid] = START; \
|
|
||||||
/* \
|
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
|
||||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
|
||||||
*/ \
|
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
|
||||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
|
||||||
size_t idx = start_idx + tid; \
|
|
||||||
while (idx < stop_idx) { \
|
|
||||||
/* \
|
|
||||||
// TODO: Fast version for the contiguous case. \
|
|
||||||
*/ \
|
|
||||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
|
||||||
T x = shared_memory[tid]; \
|
|
||||||
T y = src[strided_i]; \
|
|
||||||
shared_memory[tid] = FN; \
|
|
||||||
idx += block_dim; \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
threadgroup_barrier(mem_flags::mem_none); \
|
|
||||||
\
|
|
||||||
/* \
|
|
||||||
// reduction in shared memory \
|
|
||||||
*/ \
|
|
||||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
|
||||||
if (tid < s) { \
|
|
||||||
T x = shared_memory[tid]; \
|
|
||||||
T y = shared_memory[tid + s]; \
|
|
||||||
shared_memory[tid] = FN; \
|
|
||||||
} \
|
|
||||||
threadgroup_barrier(mem_flags::mem_none); \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
dst[dst_id] = shared_memory[0]; \
|
|
||||||
} \
|
|
||||||
|
|
||||||
|
|
||||||
#define SOFTMAX(NAME, T) \
|
#define SOFTMAX(NAME, T) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
constant size_t &src_numel, \
|
constant size_t &src_numel, \
|
||||||
@ -258,23 +466,24 @@ kernel void NAME(
|
|||||||
dst[idx] *= inv_acc; \
|
dst[idx] *= inv_acc; \
|
||||||
idx += block_dim; \
|
idx += block_dim; \
|
||||||
} \
|
} \
|
||||||
} \
|
}
|
||||||
|
|
||||||
|
REDUCE(Sum, fast_sum_f32, float)
|
||||||
|
REDUCE(Sum, fast_sum_u32, uint)
|
||||||
|
REDUCE(Sum, fast_sum_f16, half)
|
||||||
|
REDUCE(Sum, fast_sum_u8, uint8_t)
|
||||||
|
REDUCE(Mul, fast_mul_f32, float)
|
||||||
|
REDUCE(Mul, fast_mul_u32, uint)
|
||||||
|
REDUCE(Mul, fast_mul_f16, half)
|
||||||
|
REDUCE(Max, fast_max_f32, float)
|
||||||
|
REDUCE(Max, fast_max_u32, uint)
|
||||||
|
REDUCE(Max, fast_max_f16, half)
|
||||||
|
REDUCE(Max, fast_max_u8, uint8_t)
|
||||||
|
REDUCE(Min, fast_min_f32, float)
|
||||||
|
REDUCE(Min, fast_min_u32, uint)
|
||||||
|
REDUCE(Min, fast_min_f16, half)
|
||||||
|
REDUCE(Min, fast_min_u8, uint8_t)
|
||||||
|
|
||||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
|
||||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
|
||||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
|
||||||
REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
|
|
||||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
|
||||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
|
||||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
|
||||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
|
||||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
|
||||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
|
||||||
REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
|
|
||||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
|
||||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
|
||||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
|
||||||
REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
|
|
||||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||||
@ -288,18 +497,21 @@ SOFTMAX(softmax_f32, float)
|
|||||||
SOFTMAX(softmax_f16, half)
|
SOFTMAX(softmax_f16, half)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 220
|
#if __METAL_VERSION__ >= 220
|
||||||
REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
|
REDUCE(Sum, fast_sum_i64, int64_t)
|
||||||
REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
|
REDUCE(Mul, fast_mul_i64, int64_t)
|
||||||
REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
|
REDUCE(Min, fast_min_i64, int64_t)
|
||||||
|
REDUCE(Max, fast_max_i64, int64_t)
|
||||||
|
|
||||||
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
|
||||||
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
REDUCE(Sum, fast_sum_bf16, bfloat)
|
||||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
REDUCE(Mul, fast_mul_bf16, bfloat)
|
||||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
REDUCE(Max, fast_max_bf16, bfloat)
|
||||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
REDUCE(Min, fast_min_bf16, bfloat)
|
||||||
|
|
||||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||||
SOFTMAX(softmax_bf16, bfloat)
|
SOFTMAX(softmax_bf16, bfloat)
|
||||||
|
@ -509,7 +509,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
|||||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||||
let dims = vec![v.len()];
|
let dims = vec![v.len()];
|
||||||
let strides = vec![1];
|
let strides = vec![1];
|
||||||
call_reduce_strided(
|
let result = call_reduce_strided(
|
||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
@ -520,10 +520,17 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
|||||||
&input,
|
&input,
|
||||||
0,
|
0,
|
||||||
&output,
|
&output,
|
||||||
)
|
);
|
||||||
.unwrap();
|
match result {
|
||||||
command_buffer.commit();
|
Ok(_) => {
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
println!("Error: {}", e);
|
||||||
|
panic!("damn!");
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
read_to_vec(&output, out_length)
|
read_to_vec(&output, out_length)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user