mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Lots of updates including some stack of command buffers.
This commit is contained in:
@ -29,9 +29,7 @@ kernel void FN_NAME( \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME m = TYPENAME(mul); \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[id] * m + a; \
|
||||
output[id] = TYPENAME(float(input[id]) * mul + add); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
@ -47,15 +45,80 @@ kernel void FN_NAME##_strided( \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME m = TYPENAME(mul); \
|
||||
const TYPENAME a = TYPENAME(add); \
|
||||
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
|
||||
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
|
||||
}
|
||||
|
||||
#define POWF(FN_NAME, TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \
|
||||
}
|
||||
|
||||
#define ELU(FN_NAME, TYPENAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant float &mul, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME x = input[id]; \
|
||||
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant float &mul, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint id [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (id >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
|
||||
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
|
||||
} \
|
||||
|
||||
|
||||
AFFINE(affine_float, float)
|
||||
AFFINE(affine_half, half)
|
||||
POWF(powf_float, float)
|
||||
POWF(powf_half, half)
|
||||
ELU(elu_float, float)
|
||||
ELU(elu_half, half)
|
||||
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
AFFINE(affine_bfloat, bfloat);
|
||||
POWF(powf_bfloat, bfloat);
|
||||
ELU(elu_bfloat, bfloat);
|
||||
#endif
|
||||
|
@ -153,7 +153,7 @@ macro_rules! ops{
|
||||
}
|
||||
|
||||
pub mod unary {
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf);
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
||||
}
|
||||
pub mod binary {
|
||||
ops!(add, sub, mul, div);
|
||||
@ -616,6 +616,130 @@ pub fn call_affine_strided(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_powf(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
size: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
mul: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_powf_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
input: &Buffer,
|
||||
input_stride: &[usize],
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
mul: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
size,
|
||||
shape.len(),
|
||||
shape,
|
||||
input_stride,
|
||||
mul,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_elu(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
size: usize,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
mul: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (size, mul, input, output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_elu_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
shape: &[usize],
|
||||
input: &Buffer,
|
||||
input_stride: &[usize],
|
||||
input_offset: usize,
|
||||
output: &Buffer,
|
||||
mul: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
|
||||
let size: usize = shape.iter().product();
|
||||
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
size,
|
||||
shape.len(),
|
||||
shape,
|
||||
input_stride,
|
||||
mul,
|
||||
(input, input_offset),
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_where_cond_strided(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
|
@ -18,7 +18,7 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 1024;
|
||||
constant int THREADGROUP_SIZE = 2048;
|
||||
|
||||
# define REDUCE(FN, NAME, T) \
|
||||
kernel void NAME( \
|
||||
|
@ -69,7 +69,7 @@ kernel void FN_NAME( \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(float(input[thread_position_in_grid]))); \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
@ -83,7 +83,7 @@ kernel void FN_NAME_STRIDED( \
|
||||
if (thread_position_in_grid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
|
||||
output[thread_position_in_grid] = TYPENAME(FN(float(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)]))); \
|
||||
}
|
||||
|
||||
#define UNARY_OP(NAME) \
|
||||
@ -107,6 +107,7 @@ UNARY_OP(floor)
|
||||
UNARY_OP(round)
|
||||
UNARY_OP(gelu_erf)
|
||||
UNARY_OP(erf)
|
||||
UNARY_OP(tanh)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -126,6 +127,7 @@ BFLOAT_UNARY_OP(floor)
|
||||
BFLOAT_UNARY_OP(round)
|
||||
BFLOAT_UNARY_OP(gelu_erf)
|
||||
BFLOAT_UNARY_OP(erf)
|
||||
BFLOAT_UNARY_OP(tanh)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user