mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Finish reduce kernels.
This commit is contained in:
@ -482,20 +482,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
if sum_dims.len() != 1 {
|
|
||||||
crate::bail!("reduce {op:?} over multiple dimensions is not implemented yet.");
|
|
||||||
}
|
|
||||||
if sum_dims[0] != layout.shape().rank() - 1 {
|
|
||||||
crate::bail!("Non last dim reduce op {op:?} not implemented yet");
|
|
||||||
}
|
|
||||||
if layout.stride()[sum_dims[0]] != 1 {
|
|
||||||
crate::bail!("Non contiguous reduce op {op:?} not implemented yet");
|
|
||||||
}
|
|
||||||
|
|
||||||
let device = self.device.clone();
|
let device = self.device.clone();
|
||||||
let src_stride = layout.stride();
|
let src_stride = layout.stride();
|
||||||
let src_dims = layout.shape().dims();
|
let src_dims = layout.shape().dims();
|
||||||
let src_el: usize = src_dims.iter().product();
|
|
||||||
// Source dims and strides with the sum dims at the end.
|
// Source dims and strides with the sum dims at the end.
|
||||||
let mut dims = vec![];
|
let mut dims = vec![];
|
||||||
let mut stride = vec![];
|
let mut stride = vec![];
|
||||||
@ -515,28 +504,41 @@ impl BackendStorage for MetalStorage {
|
|||||||
// The reduction loop requires the shared array to be properly initialized and for
|
// The reduction loop requires the shared array to be properly initialized and for
|
||||||
// this we want the number of threads to be a power of two.
|
// this we want the number of threads to be a power of two.
|
||||||
let (name, check_empty, return_index) = match (op, self.dtype) {
|
let (name, check_empty, return_index) = match (op, self.dtype) {
|
||||||
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
|
(ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
|
||||||
(ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
|
(ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
|
||||||
(ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
|
(ReduceOp::Max, DType::F32) => ("fast_max_f32_strided", true, false),
|
||||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
|
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32_strided", true, true),
|
||||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
|
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32_strided", true, true),
|
||||||
_ => crate::bail!("Reduce op for non float"),
|
(ReduceOp::Sum, DType::U32) => ("fast_sum_u32_strided", false, false),
|
||||||
|
(ReduceOp::Min, DType::U32) => ("fast_min_u32_strided", true, false),
|
||||||
|
(ReduceOp::Max, DType::U32) => ("fast_max_u32_strided", true, false),
|
||||||
|
(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32_strided", true, true),
|
||||||
|
(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32_strided", true, true),
|
||||||
|
(ReduceOp::Sum, DType::F16) => ("fast_sum_f16_strided", false, false),
|
||||||
|
(ReduceOp::Min, DType::F16) => ("fast_min_f16_strided", true, false),
|
||||||
|
(ReduceOp::Max, DType::F16) => ("fast_max_f16_strided", true, false),
|
||||||
|
(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16_strided", true, true),
|
||||||
|
(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16_strided", true, true),
|
||||||
|
(ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16_strided", false, false),
|
||||||
|
(ReduceOp::Min, DType::BF16) => ("fast_min_bf16_strided", true, false),
|
||||||
|
(ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false),
|
||||||
|
(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true),
|
||||||
|
(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true),
|
||||||
|
(k, dtype) => crate::bail!("Reduce op for non float {k:?} {dtype:?}"),
|
||||||
};
|
};
|
||||||
if check_empty && layout.shape().elem_count() == 0 {
|
if check_empty && layout.shape().elem_count() == 0 {
|
||||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||||
}
|
}
|
||||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||||
if dtype == DType::U32 {
|
|
||||||
crate::bail!("reduce op {name} is not implemented yet.");
|
|
||||||
}
|
|
||||||
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
candle_metal_kernels::call_reduce_contiguous(
|
candle_metal_kernels::call_reduce_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
src_el,
|
&dims,
|
||||||
|
&stride,
|
||||||
dst_el,
|
dst_el,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
@ -730,7 +732,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("sub", DType::F16) => contiguous::sub::HALF,
|
("sub", DType::F16) => contiguous::sub::HALF,
|
||||||
("mul", DType::F16) => contiguous::mul::HALF,
|
("mul", DType::F16) => contiguous::mul::HALF,
|
||||||
("div", DType::F16) => contiguous::div::HALF,
|
("div", DType::F16) => contiguous::div::HALF,
|
||||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_contiguous(
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -751,11 +753,15 @@ impl BackendStorage for MetalStorage {
|
|||||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||||
|
("bminimum", DType::F32) => strided::min::FLOAT,
|
||||||
|
("bmaximum", DType::F32) => strided::max::FLOAT,
|
||||||
("badd", DType::F16) => strided::add::HALF,
|
("badd", DType::F16) => strided::add::HALF,
|
||||||
("bsub", DType::F16) => strided::sub::HALF,
|
("bsub", DType::F16) => strided::sub::HALF,
|
||||||
("bmul", DType::F16) => strided::mul::HALF,
|
("bmul", DType::F16) => strided::mul::HALF,
|
||||||
("bdiv", DType::F16) => strided::div::HALF,
|
("bdiv", DType::F16) => strided::div::HALF,
|
||||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
("bminimum", DType::F16) => strided::min::HALF,
|
||||||
|
("bmaximum", DType::F16) => strided::max::HALF,
|
||||||
|
(name, dtype) => crate::bail!("Binary {name} - {dtype:?} not implemented"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_strided(
|
candle_metal_kernels::call_binary_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
|
@ -543,6 +543,7 @@ fn argmax(device: &Device) -> Result<()> {
|
|||||||
let t1 = tensor.reshape((190, 5, 4))?;
|
let t1 = tensor.reshape((190, 5, 4))?;
|
||||||
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
|
let t2 = t1.transpose(0, 2)?.contiguous()?.transpose(0, 2)?;
|
||||||
for tensor in [t1, t2] {
|
for tensor in [t1, t2] {
|
||||||
|
println!("{}", tensor.argmax_keepdim(0)?.argmax_keepdim(2)?);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tensor
|
tensor
|
||||||
.argmax_keepdim(0)?
|
.argmax_keepdim(0)?
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
|
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant size_t &num_dims,
|
||||||
@ -63,10 +66,14 @@ BINARY_OP(x + y, add)
|
|||||||
BINARY_OP(x - y, sub)
|
BINARY_OP(x - y, sub)
|
||||||
BINARY_OP(x * y, mul)
|
BINARY_OP(x * y, mul)
|
||||||
BINARY_OP(x / y, div)
|
BINARY_OP(x / y, div)
|
||||||
|
BINARY_OP(MIN(x, y), min)
|
||||||
|
BINARY_OP(MAX(x, y), max)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
BFLOAT_BINARY_OP(x + y, add)
|
BFLOAT_BINARY_OP(x + y, add)
|
||||||
BFLOAT_BINARY_OP(x - y, sub)
|
BFLOAT_BINARY_OP(x - y, sub)
|
||||||
BFLOAT_BINARY_OP(x * y, mul)
|
BFLOAT_BINARY_OP(x * y, mul)
|
||||||
BFLOAT_BINARY_OP(x / y, div)
|
BFLOAT_BINARY_OP(x / y, div)
|
||||||
|
BFLOAT_BINARY_OP(MIN(x, y), min)
|
||||||
|
BFLOAT_BINARY_OP(MAX(x, y), max)
|
||||||
#endif
|
#endif
|
||||||
|
@ -166,7 +166,7 @@ pub mod unary {
|
|||||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
ops!(add, sub, mul, div);
|
ops!(add, sub, mul, div, min, max);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -588,6 +588,64 @@ pub fn call_reduce_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_reduce_strided(
|
||||||
|
device: &Device,
|
||||||
|
command_buffer: &CommandBufferRef,
|
||||||
|
kernels: &Kernels,
|
||||||
|
kernel_name: &'static str,
|
||||||
|
shape: &[usize],
|
||||||
|
strides: &[usize],
|
||||||
|
out_length: usize,
|
||||||
|
input: &Buffer,
|
||||||
|
input_offset: usize,
|
||||||
|
output: &Buffer,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let length: usize = shape.iter().product();
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
|
let elements_to_sum = length / out_length;
|
||||||
|
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.wait_for_fence(&kernels.fence);
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(
|
||||||
|
encoder,
|
||||||
|
(
|
||||||
|
shape.len(),
|
||||||
|
shape,
|
||||||
|
strides,
|
||||||
|
elements_to_sum,
|
||||||
|
(input, input_offset),
|
||||||
|
output
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: out_length as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let width = std::cmp::min(
|
||||||
|
pipeline.max_total_threads_per_threadgroup(),
|
||||||
|
elements_to_sum as u64,
|
||||||
|
)
|
||||||
|
.next_power_of_two();
|
||||||
|
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
encoder.update_fence(&kernels.fence);
|
||||||
|
encoder.end_encoding();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_last_softmax(
|
pub fn call_last_softmax(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
|
#define MIN(x, y) ((x) < (y) ? (x) : (y))
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
@ -20,9 +21,130 @@ METAL_FUNC uint get_strided_index(
|
|||||||
|
|
||||||
constant int THREADGROUP_SIZE = 2048;
|
constant int THREADGROUP_SIZE = 2048;
|
||||||
|
|
||||||
# define REDUCE(FN, NAME, T) \
|
|
||||||
|
// ARGMIN(NAME, T, MAXVALUE): defines kernel NAME that, for each threadgroup
// `dst_id`, scans the `el_to_sum_per_block` logical elements starting at
// `dst_id * el_to_sum_per_block` and writes to `dst[dst_id]` the position
// (modulo the last dimension) of the smallest value. Ties resolve to the
// lowest position. MAXVALUE seeds the slots of threads that see no element.
// The tree reduction halves `block_dim` each step, so the host is expected to
// launch a power-of-two number of threads per threadgroup.
#define ARGMIN(NAME, T, MAXVALUE) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device uint *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = MAXVALUE; \
    shared_indices[tid] = 0xFFFFFFFF; \
    bool notset = true; \
    /* Elements handled by this block range from dst_id * el_to_sum_per_block */ \
    /* to (dst_id + 1) * el_to_sum_per_block. */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    while (idx < stop_idx) { \
        /* TODO: Fast version for the contiguous case. */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        if (notset || src[strided_i] < shared_memory[tid]) { \
            shared_memory[tid] = src[strided_i]; \
            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
            shared_indices[tid] = idx % dims[num_dims - 1]; \
            notset = false; \
        } \
        idx += block_dim; \
    } \
    \
    /* mem_threadgroup (not mem_none) so the per-thread writes above are */ \
    /* visible to the other threads of the group before they are read. */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    /* Tree reduction in threadgroup memory. */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            T other_value = shared_memory[tid + s]; \
            uint other_index = shared_indices[tid + s]; \
            /* On equal values keep the smaller index so the first occurrence wins. */ \
            if (other_value < shared_memory[tid] || \
                (other_value == shared_memory[tid] && other_index < shared_indices[tid])) { \
                shared_indices[tid] = other_index; \
                shared_memory[tid] = other_value; \
            } \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    if (tid == 0) { \
        dst[dst_id] = shared_indices[0]; \
    } \
}
|
||||||
|
|
||||||
|
|
||||||
|
// ARGMAX(NAME, T, MINVALUE): defines kernel NAME that, for each threadgroup
// `dst_id`, scans the `el_to_sum_per_block` logical elements starting at
// `dst_id * el_to_sum_per_block` and writes to `dst[dst_id]` the position
// (modulo the last dimension) of the largest value. Ties resolve to the
// lowest position. MINVALUE seeds the slots of threads that see no element.
// The tree reduction halves `block_dim` each step, so the host is expected to
// launch a power-of-two number of threads per threadgroup.
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant size_t &el_to_sum_per_block, \
    device const T *src, \
    device uint *dst, \
    uint id [[ thread_position_in_grid ]], \
    uint tid [[ thread_index_in_threadgroup ]], \
    uint dst_id [[ threadgroup_position_in_grid ]], \
    uint block_dim [[ threads_per_threadgroup ]] \
) { \
    \
    threadgroup T shared_memory[THREADGROUP_SIZE]; \
    threadgroup uint shared_indices[THREADGROUP_SIZE]; \
    \
    shared_memory[tid] = MINVALUE; \
    shared_indices[tid] = 0xFFFFFFFF; \
    /* Elements handled by this block range from dst_id * el_to_sum_per_block */ \
    /* to (dst_id + 1) * el_to_sum_per_block. */ \
    size_t start_idx = dst_id * el_to_sum_per_block; \
    size_t stop_idx = start_idx + el_to_sum_per_block; \
    size_t idx = start_idx + tid; \
    bool notset = true; \
    while (idx < stop_idx) { \
        /* TODO: Fast version for the contiguous case. */ \
        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
        if (notset || shared_memory[tid] < src[strided_i]) { \
            shared_memory[tid] = src[strided_i]; \
            /* Assume that the reduction takes place over the last dimension which is contiguous. */ \
            shared_indices[tid] = idx % dims[num_dims - 1]; \
            notset = false; \
        } \
        idx += block_dim; \
    } \
    \
    /* mem_threadgroup (not mem_none) so the per-thread writes above are */ \
    /* visible to the other threads of the group before they are read. */ \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    /* Tree reduction in threadgroup memory. */ \
    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
        if (tid < s) { \
            T other_value = shared_memory[tid + s]; \
            uint other_index = shared_indices[tid + s]; \
            /* On equal values keep the smaller index so the first occurrence wins. */ \
            if (other_value > shared_memory[tid] || \
                (other_value == shared_memory[tid] && other_index < shared_indices[tid])) { \
                shared_indices[tid] = other_index; \
                shared_memory[tid] = other_value; \
            } \
        } \
        threadgroup_barrier(mem_flags::mem_threadgroup); \
    } \
    \
    if (tid == 0) { \
        dst[dst_id] = shared_indices[0]; \
    } \
}
|
||||||
|
|
||||||
|
#define REDUCE(FN, NAME, T, START) \
|
||||||
|
kernel void NAME( \
|
||||||
|
constant size_t &num_dims, \
|
||||||
|
constant size_t *dims, \
|
||||||
|
constant size_t *strides, \
|
||||||
constant size_t &el_to_sum_per_block, \
|
constant size_t &el_to_sum_per_block, \
|
||||||
device const T *src, \
|
device const T *src, \
|
||||||
device T *dst, \
|
device T *dst, \
|
||||||
@ -34,21 +156,21 @@ kernel void NAME( \
|
|||||||
\
|
\
|
||||||
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
threadgroup T shared_memory[THREADGROUP_SIZE]; \
|
||||||
\
|
\
|
||||||
shared_memory[tid] = 0; \
|
shared_memory[tid] = START; \
|
||||||
/* \
|
/* \
|
||||||
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
// Elements summed in this block range from dst_id * el_to_sum_per_block \
|
||||||
// to (dst_id + 1) * el_to_sum_per_block. \
|
// to (dst_id + 1) * el_to_sum_per_block. \
|
||||||
*/ \
|
*/ \
|
||||||
size_t start_idx = dst_id * el_to_sum_per_block; \
|
size_t start_idx = dst_id * el_to_sum_per_block; \
|
||||||
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
|
size_t stop_idx = start_idx + el_to_sum_per_block; \
|
||||||
size_t idx = start_idx + tid; \
|
size_t idx = start_idx + tid; \
|
||||||
while (idx < stop_idx) { \
|
while (idx < stop_idx) { \
|
||||||
/* \
|
/* \
|
||||||
// TODO: Fast version for the contiguous case. \
|
// TODO: Fast version for the contiguous case. \
|
||||||
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
|
||||||
*/ \
|
*/ \
|
||||||
|
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
|
||||||
T x = shared_memory[tid]; \
|
T x = shared_memory[tid]; \
|
||||||
T y = src[idx]; \
|
T y = src[strided_i]; \
|
||||||
shared_memory[tid] = FN; \
|
shared_memory[tid] = FN; \
|
||||||
idx += block_dim; \
|
idx += block_dim; \
|
||||||
} \
|
} \
|
||||||
@ -71,10 +193,6 @@ kernel void NAME( \
|
|||||||
} \
|
} \
|
||||||
|
|
||||||
|
|
||||||
REDUCE(x + y, fast_sum_f32, float)
|
|
||||||
REDUCE(x * y, fast_mul_f32, float)
|
|
||||||
REDUCE(max(x, y), fast_max_f32, float)
|
|
||||||
|
|
||||||
#define SOFTMAX(NAME, T) \
|
#define SOFTMAX(NAME, T) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
constant size_t &src_numel, \
|
constant size_t &src_numel, \
|
||||||
@ -142,8 +260,33 @@ kernel void NAME(
|
|||||||
} \
|
} \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
REDUCE(x + y, fast_sum_f32_strided, float, 0)
|
||||||
|
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
|
||||||
|
REDUCE(x + y, fast_sum_f16_strided, half, 0)
|
||||||
|
REDUCE(x * y, fast_mul_f32_strided, float, 1)
|
||||||
|
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
|
||||||
|
REDUCE(x * y, fast_mul_f16_strided, half, 1)
|
||||||
|
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
|
||||||
|
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
|
||||||
|
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
|
||||||
|
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
|
||||||
|
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
|
||||||
|
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
|
||||||
|
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
|
||||||
|
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
|
||||||
|
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
|
||||||
|
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
|
||||||
|
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
|
||||||
|
ARGMAX(fast_argmax_u32_strided, uint, 0)
|
||||||
|
|
||||||
SOFTMAX(softmax_f32, float)
|
SOFTMAX(softmax_f32, float)
|
||||||
SOFTMAX(softmax_f16, half)
|
SOFTMAX(softmax_f16, half)
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
|
/* bfloat reductions. The kernels carry the `_strided` suffix because the */
/* host-side reduce dispatch looks them up as `fast_<op>_bf16_strided`. */
REDUCE(x + y, fast_sum_bf16_strided, bfloat, 0)
REDUCE(x * y, fast_mul_bf16_strided, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16_strided, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16_strided, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16_strided, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16_strided, bfloat, -HUGE_VALBF)
|
||||||
SOFTMAX(softmax_bf16, bfloat)
|
SOFTMAX(softmax_bf16, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
@ -574,12 +574,16 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
|||||||
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||||
call_reduce_contiguous(
|
let num_dims = 1;
|
||||||
|
let dims = vec![v.len()];
|
||||||
|
let strides = vec![1];
|
||||||
|
call_reduce_strided(
|
||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
v.len(),
|
&dims,
|
||||||
|
&strides,
|
||||||
out_length,
|
out_length,
|
||||||
&input,
|
&input,
|
||||||
0,
|
0,
|
||||||
@ -623,7 +627,7 @@ fn reduce_sum() {
|
|||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
let out_length = 1;
|
let out_length = 1;
|
||||||
|
|
||||||
let results = run_reduce(&v, out_length, "fast_sum_f32");
|
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||||
assert_eq!(approx(results, 4), vec![21.0]);
|
assert_eq!(approx(results, 4), vec![21.0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -632,7 +636,7 @@ fn reduce_sum2() {
|
|||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
let out_length = 2;
|
let out_length = 2;
|
||||||
|
|
||||||
let results = run_reduce(&v, out_length, "fast_sum_f32");
|
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
|
||||||
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user