Lots of updates including some stack of command buffers.

This commit is contained in:
nicolas
2023-12-12 17:41:56 +01:00
parent da0af3cb3e
commit 87dc559817
10 changed files with 537 additions and 117 deletions

View File

@ -29,9 +29,7 @@ kernel void FN_NAME( \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
output[id] = TYPENAME(float(input[id]) * mul + add); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
@ -47,15 +45,80 @@ kernel void FN_NAME##_strided( \
if (id >= dim) { \
return; \
} \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
}
#define POWF(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \
}
#define ELU(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME x = input[id]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
POWF(powf_float, float)
POWF(powf_half, half)
ELU(elu_float, float)
ELU(elu_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
POWF(powf_bfloat, bfloat);
ELU(elu_bfloat, bfloat);
#endif

View File

@ -153,7 +153,7 @@ macro_rules! ops{
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
}
pub mod binary {
ops!(add, sub, mul, div);
@ -616,6 +616,130 @@ pub fn call_affine_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_powf(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
size,
shape.len(),
shape,
input_stride,
mul,
(input, input_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_elu(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
size: usize,
input: &Buffer,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
input: &Buffer,
input_stride: &[usize],
input_offset: usize,
output: &Buffer,
mul: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
size,
shape.len(),
shape,
input_stride,
mul,
(input, input_offset),
output
)
);
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_where_cond_strided(
device: &Device,
command_buffer: &CommandBufferRef,

View File

@ -18,7 +18,7 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
constant int THREADGROUP_SIZE = 1024;
constant int THREADGROUP_SIZE = 2048;
# define REDUCE(FN, NAME, T) \
kernel void NAME( \

View File

@ -69,7 +69,7 @@ kernel void FN_NAME( \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -83,7 +83,7 @@ kernel void FN_NAME_STRIDED( \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
}
#define UNARY_OP(NAME) \
@ -107,6 +107,7 @@ UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@ -126,6 +127,7 @@ BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif