Add the const-set op. (#2910)

* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
This commit is contained in:
Laurent Mazare
2025-04-19 10:07:02 +02:00
committed by GitHub
parent b2904a830b
commit a4c56a958e
20 changed files with 414 additions and 209 deletions

View File

@ -161,7 +161,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip, silu, sign, sigmoid
tanh, recip, silu, sign, sigmoid, const_set
);
}
pub mod binary {
@ -419,6 +419,82 @@ pub fn call_copy2d(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous_tiled(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous_tiled::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let tile_size = 2;
let tiles = length.div_ceil(tile_size);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: impl EncoderParam,
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, &output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: impl EncoderParam,
strides: &[usize],
output: BufferOffset,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
let length: usize = shape.iter().product();
let num_dims: usize = shape.len();
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
device: &Device,