mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Finish reduce kernels.
This commit is contained in:
@ -482,20 +482,9 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
if sum_dims.len() != 1 {
|
||||
crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
|
||||
}
|
||||
if sum_dims[0] != layout.shape().rank() - 1 {
|
||||
crate::bail!("Non last dim reduce op {op:?} not implemented yet");
|
||||
}
|
||||
if layout.stride()[sum_dims[0]] != 1 {
|
||||
crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
|
||||
}
|
||||
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
let src_el: usize = src_dims.iter().product();
|
||||
// Source dims and strides with the sum dims at the end.
|
||||
let mut dims = vec![];
|
||||
let mut stride = vec![];
|
||||
@ -515,28 +504,41 @@ impl BackendStorage for MetalStorage {
|
||||
// The reduction loop requires the shared array to be properly initialized and for
|
||||
// this we want the number of threads to be a power of two.
|
||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
||||
_ => crate::bail!("Reduce op for non float"),
|
||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
|
||||
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
|
||||
(ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
|
||||
(ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
|
||||
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
|
||||
(ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
|
||||
(ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
|
||||
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
|
||||
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
|
||||
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
|
||||
(k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
if dtype == DType::U32 {
|
||||
crate::bail!("reduce op {name} is not implemented yet.");
|
||||
}
|
||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
candle_metal_kernels::call_reduce_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
name,
|
||||
src_el,
|
||||
&dims,
|
||||
&stride,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
@ -730,7 +732,7 @@ impl BackendStorage for MetalStorage {
|
||||
("sub", DType::F16) => contiguous::sub::HALF,
|
||||
("mul", DType::F16) => contiguous::mul::HALF,
|
||||
("div", DType::F16) => contiguous::div::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
&device.device,
|
||||
@ -751,11 +753,15 @@ impl BackendStorage for MetalStorage {
|
||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||
("bminimum", DType::F32) => strided::min::FLOAT,
|
||||
("bmaximum", DType::F32) => strided::max::FLOAT,
|
||||
("badd", DType::F16) => strided::add::HALF,
|
||||
("bsub", DType::F16) => strided::sub::HALF,
|
||||
("bmul", DType::F16) => strided::mul::HALF,
|
||||
("bdiv", DType::F16) => strided::div::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
("bminimum", DType::F16) => strided::min::HALF,
|
||||
("bmaximum", DType::F16) => strided::max::HALF,
|
||||
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_strided(
|
||||
&device.device,
|
||||
|
@ -543,6 +543,7 @@ fn argmax(device: &Device) -> Result<()> {
|
||||
let t1 = tensor.reshape((190, 5, 4))?;
|
||||
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
|
||||
for tensor in [t1, t2] {
|
||||
println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?);
|
||||
assert_eq!(
|
||||
tensor
|
||||
.argmax_keepdim(0)?
|
||||
|
@ -1,5 +1,8 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
@ -63,10 +66,14 @@ BINARY_OP(x + y, add)
|
||||
BINARY_OP(x - y, sub)
|
||||
BINARY_OP(x * y, mul)
|
||||
BINARY_OP(x / y, div)
|
||||
BINARY_OP(MIN(x, y), min)
|
||||
BINARY_OP(MAX(x, y), max)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
BFLOAT_BINARY_OP(x + y, add)
|
||||
BFLOAT_BINARY_OP(x - y, sub)
|
||||
BFLOAT_BINARY_OP(x * y, mul)
|
||||
BFLOAT_BINARY_OP(x / y, div)
|
||||
BFLOAT_BINARY_OP(MIN(x, y), min)
|
||||
BFLOAT_BINARY_OP(MAX(x, y), max)
|
||||
#endif
|
||||
|
@ -166,7 +166,7 @@ pub mod unary {
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
||||
}
|
||||
pub mod binary {
|
||||
ops!(add, sub, mul, div);
|
||||
ops!(add, sub, mul, div, min, max);
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -588,6 +588,64 @@ pub fn call_reduce_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_reduce_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: &'static str,
|
||||
shape: &[usize],
|
||||
strides: &[usize],
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let length: usize = shape.iter().product();
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.wait_for_fence(&kernels.fence);
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
shape.len(),
|
||||
shape,
|
||||
strides,
|
||||
elements_to_sum,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
elements_to_sum as u64,
|
||||
)
|
||||
.next_power_of_two();
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.update_fence(&kernels.fence);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_last_softmax(
|
||||
device: &Device,
|
||||
|
@ -2,6 +2,7 @@
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(
|
||||
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
# define REDUCE(FN, NAME, T) \
|
||||
|
||||
#define ARGMIN(NAME, T, MAXVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MAXVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
bool notset = true; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || src[strided_i] < shared_memory[tid]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
|
||||
#define ARGMAX(NAME, T, MINVALUE) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device uint *dst, \
|
||||
uint id [[ thread_position_in_grid ]], \
|
||||
uint tid [[ thread_index_in_threadgroup ]], \
|
||||
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||
uint block_dim [[ threads_per_threadgroup ]] \
|
||||
) { \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = MINVALUE; \
|
||||
shared_indices[tid] = 0xFFFFFFFF; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
bool notset = true; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
if (notset || shared_memory[tid] < src[strided_i]) { \
|
||||
shared_memory[tid] = src[strided_i]; \
|
||||
shared_indices[tid] = idx % dims[num_dims - 1]; \
|
||||
notset = false; \
|
||||
} \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
\
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
\
|
||||
/* \
|
||||
// reduction in shared memory \
|
||||
*/ \
|
||||
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
|
||||
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
|
||||
shared_indices[tid] = shared_indices[tid + s]; \
|
||||
shared_memory[tid] = shared_memory[tid + s]; \
|
||||
} \
|
||||
threadgroup_barrier(mem_flags::mem_none); \
|
||||
} \
|
||||
\
|
||||
if (tid == 0){ \
|
||||
dst[dst_id] = shared_indices[0]; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define REDUCE(FN, NAME, T, START) \
|
||||
kernel void NAME( \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant size_t &el_to_sum_per_block, \
|
||||
device const T *src, \
|
||||
device T *dst, \
|
||||
@ -34,21 +156,21 @@ kernel void NAME( \
|
||||
\
|
||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||
\
|
||||
shared_memory[tid] = 0; \
|
||||
shared_memory[tid] = START; \
|
||||
/* \
|
||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||
*/ \
|
||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
||||
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||
size_t idx = start_idx + tid; \
|
||||
while (idx < stop_idx) { \
|
||||
/* \
|
||||
// TODO: Fast version for the contiguous case. \
|
||||
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
*/ \
|
||||
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||
T x = shared_memory[tid]; \
|
||||
T y = src[idx]; \
|
||||
T y = src[strided_i]; \
|
||||
shared_memory[tid] = FN; \
|
||||
idx += block_dim; \
|
||||
} \
|
||||
@ -71,10 +193,6 @@ kernel void NAME( \
|
||||
} \
|
||||
|
||||
|
||||
REDUCE(x + y, fast_sum_f32, float)
|
||||
REDUCE(x * y, fast_mul_f32, float)
|
||||
REDUCE(max(x, y), fast_max_f32, float)
|
||||
|
||||
#define SOFTMAX(NAME, T) \
|
||||
kernel void NAME( \
|
||||
constant size_t &src_numel, \
|
||||
@ -142,8 +260,33 @@ kernel void NAME(
|
||||
} \
|
||||
} \
|
||||
|
||||
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||
|
||||
SOFTMAX(softmax_f32, float)
|
||||
SOFTMAX(softmax_f16, half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
|
||||
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
|
||||
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
|
||||
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
|
||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||
SOFTMAX(softmax_bf16, bfloat)
|
||||
#endif
|
||||
|
@ -574,12 +574,16 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
||||
|
||||
let options = MTLResourceOptions::StorageModeManaged;
|
||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||
call_reduce_contiguous(
|
||||
let num_dims = 1;
|
||||
let dims = vec![v.len()];
|
||||
let strides = vec![1];
|
||||
call_reduce_strided(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
v.len(),
|
||||
&dims,
|
||||
&strides,
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
@ -623,7 +627,7 @@ fn reduce_sum() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 1;
|
||||
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32");
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![21.0]);
|
||||
}
|
||||
|
||||
@ -632,7 +636,7 @@ fn reduce_sum2() {
|
||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||
let out_length = 2;
|
||||
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32");
|
||||
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user