Finish reduce kernels.

Nicolas Patry
2023-12-17 19:07:00 +01:00
parent 6bc92e63cb
commit 972903021c
6 changed files with 258 additions and 39 deletions

View File

@@ -482,20 +482,9 @@ impl BackendStorage for MetalStorage {
    }

    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
-        if sum_dims.len() != 1 {
-            crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
-        }
-        if sum_dims[0] != layout.shape().rank() - 1 {
-            crate::bail!("Non last dim reduce op {op:?} not implemented yet");
-        }
-        if layout.stride()[sum_dims[0]] != 1 {
-            crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
-        }
        let device = self.device.clone();
        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
-        let src_el: usize = src_dims.iter().product();
        // Source dims and strides with the sum dims at the end.
        let mut dims = vec![];
        let mut stride = vec![];
@@ -515,28 +504,41 @@ impl BackendStorage for MetalStorage {
        // The reduction loop requires the shared array to be properly initialized and for
        // this we want the number of threads to be a power of two.
        let (name, check_empty, return_index) = match (op, self.dtype) {
-            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
-            (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
-            (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
-            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
-            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
-            _ => crate::bail!("Reduce op for non float"),
+            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
+            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
+            (ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
+            (ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
+            (ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
+            (ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
+            (ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
+            (ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
+            (ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
+            (ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
+            (ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
+            (ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
+            (ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
+            (ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
+            (ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
+            (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
+            (ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
+            (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
+            (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
+            (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
+            (k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
        };
        if check_empty && layout.shape().elem_count() == 0 {
            Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
        }
        let dtype = if return_index { DType::U32 } else { self.dtype };
-        if dtype == DType::U32 {
-            crate::bail!("reduce op {name} is not implemented yet.");
-        }
        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
        let command_buffer = self.device.command_buffer()?;
-        candle_metal_kernels::call_reduce_contiguous(
+        candle_metal_kernels::call_reduce_strided(
            &device.device,
            &command_buffer,
            &device.kernels,
            name,
-            src_el,
+            &dims,
+            &stride,
            dst_el,
            &self.buffer,
            layout.start_offset() * self.dtype.size_in_bytes(),
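The hunk above relies on dims/stride having been reordered so that the reduced dimensions come last. A minimal standalone Rust sketch of that host-side bookkeeping (the plan_reduce name and free-standing signature are hypothetical, not the crate's API):

// Hypothetical sketch: put the kept dimensions first and the reduced ones last,
// so each threadgroup reduces one run of `elements_to_sum` strided elements.
fn plan_reduce(
    src_dims: &[usize],
    src_stride: &[usize],
    sum_dims: &[usize],
) -> (Vec<usize>, Vec<usize>, usize) {
    let mut dims = vec![];
    let mut stride = vec![];
    for (d, (&size, &st)) in src_dims.iter().zip(src_stride.iter()).enumerate() {
        if !sum_dims.contains(&d) {
            dims.push(size);
            stride.push(st);
        }
    }
    // One output element (and one threadgroup) per kept-dimension index.
    let dst_el: usize = dims.iter().product();
    for &d in sum_dims {
        dims.push(src_dims[d]);
        stride.push(src_stride[d]);
    }
    (dims, stride, dst_el)
}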
@@ -730,7 +732,7 @@ impl BackendStorage for MetalStorage {
            ("sub", DType::F16) => contiguous::sub::HALF,
            ("mul", DType::F16) => contiguous::mul::HALF,
            ("div", DType::F16) => contiguous::div::HALF,
-            (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
        };
        candle_metal_kernels::call_binary_contiguous(
            &device.device,
@@ -751,11 +753,15 @@ impl BackendStorage for MetalStorage {
            ("bsub", DType::F32) => strided::sub::FLOAT,
            ("bmul", DType::F32) => strided::mul::FLOAT,
            ("bdiv", DType::F32) => strided::div::FLOAT,
+            ("bminimum", DType::F32) => strided::min::FLOAT,
+            ("bmaximum", DType::F32) => strided::max::FLOAT,
            ("badd", DType::F16) => strided::add::HALF,
            ("bsub", DType::F16) => strided::sub::HALF,
            ("bmul", DType::F16) => strided::mul::HALF,
            ("bdiv", DType::F16) => strided::div::HALF,
-            (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
+            ("bminimum", DType::F16) => strided::min::HALF,
+            ("bmaximum", DType::F16) => strided::max::HALF,
+            (name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
        };
        candle_metal_kernels::call_binary_strided(
            &device.device,

View File

@@ -543,6 +543,7 @@ fn argmax(device: &Device) -> Result<()> {
    let t1 = tensor.reshape((190, 5, 4))?;
    let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
    for tensor in [t1, t2] {
+        println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?);
        assert_eq!(
            tensor
                .argmax_keepdim(0)?

View File

@@ -1,5 +1,8 @@
#include <metal_stdlib>
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
@@ -63,10 +66,14 @@ BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
+BINARY_OP(MIN(x, y), min)
+BINARY_OP(MAX(x, y), max)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
+BFLOAT_BINARY_OP(MIN(x, y), min)
+BFLOAT_BINARY_OP(MAX(x, y), max)
#endif
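The strided kernels lean on get_strided_index, shown above only by its signature, to turn a flat element index into a storage offset. A minimal Rust sketch of that mapping, assuming the usual dims/strides convention (the Rust port itself is an illustration, not the crate's code):

// Sketch of the index math: peel coordinates off from the innermost dimension
// outward and weight each one by that dimension's stride.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    for (&dim, &stride) in dims.iter().rev().zip(strides.iter().rev()) {
        strided += (idx % dim) * stride;
        idx /= dim;
    }
    strided
}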

View File

@@ -166,7 +166,7 @@ pub mod unary {
    ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
-    ops!(add, sub, mul, div);
+    ops!(add, sub, mul, div, min, max);
}

#[derive(thiserror::Error, Debug)]
@@ -588,6 +588,64 @@ pub fn call_reduce_contiguous(
    Ok(())
}

+pub fn call_reduce_strided(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    shape: &[usize],
+    strides: &[usize],
+    out_length: usize,
+    input: &Buffer,
+    input_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let length: usize = shape.iter().product();
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let elements_to_sum = length / out_length;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.wait_for_fence(&kernels.fence);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (
+            shape.len(),
+            shape,
+            strides,
+            elements_to_sum,
+            (input, input_offset),
+            output
+        )
+    );
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.update_fence(&kernels.fence);
+    encoder.end_encoding();
+    Ok(())
+}
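A small Rust sketch of the dispatch geometry picked in call_reduce_strided, with a hypothetical launch_dims helper and max_threads standing in for pipeline.max_total_threads_per_threadgroup(): one threadgroup per output element, and a power-of-two thread count so the shared-memory tree reduction pairs lanes cleanly.

// Sketch only: mirrors the MTLSize computation above.
fn launch_dims(out_length: u64, elements_to_sum: u64, max_threads: u64) -> (u64, u64) {
    let thread_groups = out_length; // one group per output element
    let threads_per_group = std::cmp::min(max_threads, elements_to_sum).next_power_of_two();
    (thread_groups, threads_per_group)
}
// e.g. reducing 6 elements into 2 outputs with max_threads = 1024:
// launch_dims(2, 3, 1024) == (2, 4)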
#[allow(clippy::too_many_arguments)]
pub fn call_last_softmax(
    device: &Device,

View File

@@ -2,6 +2,7 @@
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
    uint idx,
@@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(
constant int THREADGROUP_SIZE = 2048;

-#define REDUCE(FN, NAME, T) \
+#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
-    constant size_t &src_numel, \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device uint *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
+    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+    \
+    shared_memory[tid] = MAXVALUE; \
+    shared_indices[tid] = 0xFFFFFFFF; \
+    bool notset = true; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = start_idx + el_to_sum_per_block; \
+    size_t idx = start_idx + tid; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        */ \
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        if (notset || src[strided_i] < shared_memory[tid]) { \
+            shared_memory[tid] = src[strided_i]; \
+            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
+            shared_indices[tid] = idx % dims[num_dims - 1]; \
+            notset = false; \
+        } \
+        idx += block_dim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_none); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
+            shared_indices[tid] = shared_indices[tid + s]; \
+            shared_memory[tid] = shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+    \
+    if (tid == 0){ \
+        dst[dst_id] = shared_indices[0]; \
+    } \
+} \
+#define ARGMAX(NAME, T, MINVALUE) \
+kernel void NAME( \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device uint *dst, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    \
+    threadgroup T shared_memory[THREADGROUP_SIZE]; \
+    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
+    \
+    shared_memory[tid] = MINVALUE; \
+    shared_indices[tid] = 0xFFFFFFFF; \
+    /* \
+    // Elements summed in this block range from dst_id * el_to_sum_per_block \
+    // to (dst_id + 1) * el_to_sum_per_block. \
+    */ \
+    size_t start_idx = dst_id * el_to_sum_per_block; \
+    size_t stop_idx = start_idx + el_to_sum_per_block; \
+    size_t idx = start_idx + tid; \
+    bool notset = true; \
+    while (idx < stop_idx) { \
+        /* \
+        // TODO: Fast version for the contiguous case. \
+        */ \
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
+        if (notset || shared_memory[tid] < src[strided_i]) { \
+            shared_memory[tid] = src[strided_i]; \
+            shared_indices[tid] = idx % dims[num_dims - 1]; \
+            notset = false; \
+        } \
+        idx += block_dim; \
+    } \
+    \
+    threadgroup_barrier(mem_flags::mem_none); \
+    \
+    /* \
+    // reduction in shared memory \
+    */ \
+    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
+        if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
+            shared_indices[tid] = shared_indices[tid + s]; \
+            shared_memory[tid] = shared_memory[tid + s]; \
+        } \
+        threadgroup_barrier(mem_flags::mem_none); \
+    } \
+    \
+    if (tid == 0){ \
+        dst[dst_id] = shared_indices[0]; \
+    } \
+} \
+#define REDUCE(FN, NAME, T, START) \
+kernel void NAME( \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device T *dst, \
@@ -34,21 +156,21 @@ kernel void NAME( \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    \
-    shared_memory[tid] = 0; \
+    shared_memory[tid] = START; \
    /* \
    // Elements summed in this block range from dst_id * el_to_sum_per_block \
    // to (dst_id + 1) * el_to_sum_per_block. \
    */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
+    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* \
        // TODO: Fast version for the contiguous case. \
-        // size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        */ \
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        T x = shared_memory[tid]; \
-        T y = src[idx]; \
+        T y = src[strided_i]; \
        shared_memory[tid] = FN; \
        idx += block_dim; \
    } \
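The new START argument seeds shared_memory with the identity element of FN, so lanes that read no element cannot perturb the result. A short Rust sketch of the values used by the instantiations further down (the start_value helper is hypothetical):

// Sketch: identity element per reduction, matching the REDUCE(..., START) calls below.
fn start_value(op: &str) -> f32 {
    match op {
        "sum" => 0.0,               // x + 0 == x
        "mul" => 1.0,               // x * 1 == x
        "max" => f32::NEG_INFINITY, // MAX(x, -inf) == x
        "min" => f32::INFINITY,     // MIN(x, +inf) == x
        _ => unreachable!("unknown reduction"),
    }
}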
@@ -71,10 +193,6 @@ kernel void NAME( \
} \

-REDUCE(x + y, fast_sum_f32, float)
-REDUCE(x * y, fast_mul_f32, float)
-REDUCE(max(x, y), fast_max_f32, float)

#define SOFTMAX(NAME, T) \
kernel void NAME( \
    constant size_t &src_numel, \
@@ -142,8 +260,33 @@ kernel void NAME(
    } \
} \

+REDUCE(x + y, fast_sum_f32_strided, float, 0)
+REDUCE(x + y, fast_sum_u32_strided, uint, 0)
+REDUCE(x + y, fast_sum_f16_strided, half, 0)
+REDUCE(x * y, fast_mul_f32_strided, float, 1)
+REDUCE(x * y, fast_mul_u32_strided, uint, 1)
+REDUCE(x * y, fast_mul_f16_strided, half, 1)
+REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
+REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
+REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
+REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
+REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
+REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
+ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
+ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
+ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
+ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
+ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
+ARGMAX(fast_argmax_u32_strided, uint, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
#if __METAL_VERSION__ >= 310
+REDUCE(x + y, fast_sum_bf16, bfloat, 0)
+REDUCE(x * y, fast_mul_bf16, bfloat, 1)
+REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
+REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
+ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
+ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
#endif
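For reference, a self-contained Rust sketch of what one ARGMIN threadgroup above computes, assuming block_dim is a power of two; the argmin_block name and free-standing signature are hypothetical. The two phases mirror the strided per-lane scan and the shared-memory tree reduction in the macro.

// Sketch only: a CPU model of one ARGMIN threadgroup.
fn argmin_block(
    src: &[f32],
    dims: &[usize],
    strides: &[usize],
    el_to_sum_per_block: usize,
    dst_id: usize,
    block_dim: usize,
) -> u32 {
    // Same index math as get_strided_index in the kernel.
    let strided_index = |mut idx: usize| -> usize {
        let mut off = 0;
        for (&d, &st) in dims.iter().rev().zip(strides.iter().rev()) {
            off += (idx % d) * st;
            idx /= d;
        }
        off
    };
    let mut shared_memory = vec![f32::INFINITY; block_dim]; // MAXVALUE seed
    let mut shared_indices = vec![u32::MAX; block_dim];
    let start_idx = dst_id * el_to_sum_per_block;
    // Phase 1: each lane scans every block_dim-th element of its block.
    for tid in 0..block_dim {
        let mut idx = start_idx + tid;
        while idx < start_idx + el_to_sum_per_block {
            let strided_i = strided_index(idx);
            if src[strided_i] < shared_memory[tid] {
                shared_memory[tid] = src[strided_i];
                // Index within the last (reduced) dimension, as in the kernel.
                shared_indices[tid] = (idx % dims[dims.len() - 1]) as u32;
            }
            idx += block_dim;
        }
    }
    // Phase 2: pairwise tree reduction over the per-lane results.
    let mut s = block_dim / 2;
    while s > 0 {
        for tid in 0..s {
            if shared_memory[tid + s] < shared_memory[tid] {
                shared_memory[tid] = shared_memory[tid + s];
                shared_indices[tid] = shared_indices[tid + s];
            }
        }
        s /= 2;
    }
    shared_indices[0]
}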

View File

@@ -574,12 +574,16 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T>
    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
-    call_reduce_contiguous(
+    let num_dims = 1;
+    let dims = vec![v.len()];
+    let strides = vec![1];
+    call_reduce_strided(
        &device,
        command_buffer,
        &kernels,
        name,
-        v.len(),
+        &dims,
+        &strides,
        out_length,
        &input,
        0,
@@ -623,7 +627,7 @@ fn reduce_sum() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 1;
-    let results = run_reduce(&v, out_length, "fast_sum_f32");
+    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
    assert_eq!(approx(results, 4), vec![21.0]);
}
@@ -632,7 +636,7 @@ fn reduce_sum2() {
    let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let out_length = 2;
-    let results = run_reduce(&v, out_length, "fast_sum_f32");
+    let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
    assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
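A CPU reference for the expectations above, assuming each of the out_length threadgroups sums a contiguous chunk of v.len() / out_length elements (the reference_sum helper is hypothetical, for illustration only):

// Sketch: reduce_sum expects [21.0]; reduce_sum2 expects [6.0, 15.0].
fn reference_sum(v: &[f32], out_length: usize) -> Vec<f32> {
    let per_block = v.len() / out_length;
    v.chunks(per_block).map(|c| c.iter().sum()).collect()
}
// reference_sum(&[1., 2., 3., 4., 5., 6.], 2) == vec![6.0, 15.0]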