mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add the const-set op. (#2910)
* Add the const-set op. * Cuda implementation. * Bugfix. * Metal cleanup. * Add the metal kernels. * Add some testing. * Finish the metal implementation. * Bump the version.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.9.0-alpha.4"
|
||||
version = "0.9.0-alpha.5"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
@ -161,7 +161,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip, silu, sign, sigmoid
|
||||
tanh, recip, silu, sign, sigmoid, const_set
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -419,6 +419,82 @@ pub fn call_copy2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous_tiled(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous_tiled::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let tile_size = 2;
|
||||
let tiles = length.div_ceil(tile_size);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_contiguous(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: impl EncoderParam,
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, input, &output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_const_set_strided(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: unary::strided::Kernel,
|
||||
shape: &[usize],
|
||||
input: impl EncoderParam,
|
||||
strides: &[usize],
|
||||
output: BufferOffset,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous_tiled(
|
||||
device: &Device,
|
||||
|
@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
|
||||
|
||||
#define TILE_SIZE 2
|
||||
|
||||
#define CONST_SET(TYPENAME, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = input; \
|
||||
} \
|
||||
kernel void FN_NAME##_##strided( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
constant size_t *strides, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
if (tid >= dim) { \
|
||||
return; \
|
||||
} \
|
||||
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
|
||||
} \
|
||||
kernel void FN_NAME##_##tiled( \
|
||||
constant size_t &dim, \
|
||||
constant TYPENAME &input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
for (uint i = 0; i < TILE_SIZE; i++) { \
|
||||
const uint idx = tid * TILE_SIZE + i; \
|
||||
output[idx] = input; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &dim, \
|
||||
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
|
||||
COPY2D(copy2d_u8, uint8_t)
|
||||
COPY2D(copy2d_u32, uint32_t)
|
||||
|
||||
CONST_SET(float, const_set_f32)
|
||||
CONST_SET(half, const_set_f16)
|
||||
CONST_SET(uint8_t, const_set_u8)
|
||||
CONST_SET(uint32_t, const_set_u32)
|
||||
|
||||
UNARY_OP(cos)
|
||||
UNARY_OP(sin)
|
||||
UNARY_OP(sqr)
|
||||
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
|
||||
#if __METAL_VERSION__ >= 220
|
||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||
COPY2D(copy2d_i64, int64_t)
|
||||
CONST_SET(int64_t, const_set_i64)
|
||||
#endif
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
||||
|
||||
COPY2D(copy2d_bf16, bfloat)
|
||||
CONST_SET(bfloat, const_set_bf16)
|
||||
#endif
|
||||
|
Reference in New Issue
Block a user