Efficient implementation of Tensor::ones() for metal (#2512)

* WIP: hopefully better const impl

* with GPU

* More tests on

* Reverting primitive for

* Incorporating review changes - added check elem count check in kerner, using  for call strategy

* rustfmt ran
This commit is contained in:
Anubhab Bandyopadhyay
2024-10-01 22:41:59 +05:30
committed by GitHub
parent def4c6cdee
commit a2bcc227df
5 changed files with 194 additions and 4 deletions

View File

@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
@ -31,6 +32,7 @@ pub enum Source {
Binary,
Cast,
Conv,
Fill,
Gemm,
Indexing,
Mfa,
@ -196,6 +198,7 @@ impl Kernels {
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm(
Ok(())
}
pub fn call_const_fill(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (output, v, length));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
Ok(())
}
#[cfg(test)]
mod tests;