Add the const-set op. (#2910)

* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
This commit is contained in:
Laurent Mazare
2025-04-19 10:07:02 +02:00
committed by GitHub
parent b2904a830b
commit a4c56a958e
20 changed files with 414 additions and 209 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.4"
version = "0.9.0-alpha.5"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -161,7 +161,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu, sign, sigmoid
tanh, recip, silu, sign, sigmoid, const_set
);
}
pub mod binary {
@ -419,6 +419,82 @@ pub fn call_copy2d(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous_tiled(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous_tiled::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = length.div_ceil(tile_size);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: impl EncoderParam,
strides: &[usize],
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,

View File

@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
#define TILE_SIZE 2
#define CONST_SET(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[tid] = input; \
} \
kernel void FN_NAME##_##strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
return; \
} \
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
} \
kernel void FN_NAME##_##tiled( \
constant size_t &dim, \
constant TYPENAME &input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
for (uint i = 0; i < TILE_SIZE; i++) { \
const uint idx = tid * TILE_SIZE + i; \
output[idx] = input; \
} \
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)
CONST_SET(float, const_set_f32)
CONST_SET(half, const_set_f16)
CONST_SET(uint8_t, const_set_u8)
CONST_SET(uint32_t, const_set_u32)
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
CONST_SET(int64_t, const_set_i64)
#endif
#if defined(__HAVE_BFLOAT__)
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
COPY2D(copy2d_bf16, bfloat)
CONST_SET(bfloat, const_set_bf16)
#endif