mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Efficient implementation of Tensor::ones()
for metal
(#2512)
* WIP: hopefully better const impl * with GPU * More tests on * Reverting primitive for * Incorporating review changes - added elem count check in kernel, using for call strategy * rustfmt ran
This commit is contained in:

committed by
GitHub

parent
def4c6cdee
commit
a2bcc227df
39
candle-metal-kernels/src/fill.metal
Normal file
39
candle-metal-kernels/src/fill.metal
Normal file
@ -0,0 +1,39 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
// Writes `value`, converted to `T`, into `out[tid]`.
//
// Threads whose grid position is at or past `numel` write nothing; this keeps
// the store in bounds when the dispatched grid is larger than `numel`.
template <typename T>
METAL_FUNC void fill_with(
    device T *out,
    constant float &value,
    constant size_t &numel,
    uint tid [[thread_position_in_grid]]
) {
    if (tid < numel) {
        out[tid] = static_cast<T>(value);
    }
}
|
||||
|
||||
// Declares the compute kernel `fill_<NAME>` for element type `T`.
// The kernel forwards its arguments unchanged to `fill_with<T>`, which does
// the bounds check and the store.
#define FILL_OP(NAME, T)                        \
    kernel void fill_##NAME(                    \
        device T *out,                          \
        constant float &value,                  \
        constant size_t &numel,                 \
        uint tid [[thread_position_in_grid]]) { \
        fill_with<T>(out, value, numel, tid);   \
    }
|
||||
|
||||
|
||||
// Thin alias over FILL_OP: emits the fill kernel for one (NAME, T) pair.
#define FILL_OPS(NAME, T) FILL_OP(NAME, T)

// Kernels emitted here: fill_u8, fill_u32, fill_i64, fill_f16, fill_f32.
FILL_OPS(u8, uchar)
FILL_OPS(u32, uint)
FILL_OPS(i64, long)
FILL_OPS(f16, half)
FILL_OPS(f32, float)

// fill_bf16 is only compiled when __METAL_VERSION__ is at least 310.
#if __METAL_VERSION__ >= 310
FILL_OPS(bf16, bfloat)
#endif
|
@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const FILL: &str = include_str!("fill.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
@ -31,6 +32,7 @@ pub enum Source {
|
||||
Binary,
|
||||
Cast,
|
||||
Conv,
|
||||
Fill,
|
||||
Gemm,
|
||||
Indexing,
|
||||
Mfa,
|
||||
@ -196,6 +198,7 @@ impl Kernels {
|
||||
Source::Binary => BINARY,
|
||||
Source::Cast => CAST,
|
||||
Source::Conv => CONV,
|
||||
Source::Fill => FILL,
|
||||
Source::Gemm => MLX_GEMM,
|
||||
Source::Indexing => INDEXING,
|
||||
Source::Quantized => QUANTIZED,
|
||||
@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn call_const_fill(
|
||||
device: &Device,
|
||||
ep: impl EncoderProvider,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
length: usize,
|
||||
output: &Buffer,
|
||||
v: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (output, v, length));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -1,6 +1,7 @@
|
||||
use super::*;
|
||||
use half::{bf16, f16};
|
||||
use metal::MTLResourceOptions;
|
||||
use rand::Rng;
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
@ -2307,3 +2308,67 @@ fn conv_transpose1d_u32() {
|
||||
let expected = vec![1, 4, 10, 20, 25, 24, 16];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
||||
let dev = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = dev.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let buffer = dev.new_buffer(
|
||||
(len * std::mem::size_of::<T>()) as u64,
|
||||
MTLResourceOptions::StorageModePrivate,
|
||||
);
|
||||
|
||||
call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec::<T>(&buffer, len)
|
||||
}
|
||||
|
||||
/// Checks every `fill_*` kernel: each one is run with a random length and a
/// random fill value, and the result must equal a vector of that value
/// converted to the kernel's element type.
#[test]
fn const_fill() {
    /// Fills a random-length buffer with a random value using kernel `name`
    /// and compares it against `convert(value)` repeated `len` times.
    fn check<T>(name: &'static str, convert: impl Fn(f32) -> T)
    where
        T: Clone + PartialEq + std::fmt::Debug,
    {
        let mut rng = rand::thread_rng();
        let len = rng.gen_range(2..16) * rng.gen_range(4..16);
        let value = rng.gen_range(1.0..19.0);

        let filled = constant_fill::<T>(name, len, value);
        assert_eq!(filled, vec![convert(value); len]);
    }

    check::<u8>("fill_u8", |v| v as u8);
    check::<u32>("fill_u32", |v| v as u32);
    check::<i64>("fill_i64", |v| v as i64);
    check::<f16>("fill_f16", f16::from_f32);
    check::<bf16>("fill_bf16", bf16::from_f32);
    check::<f32>("fill_f32", |v| v);
}
|
||||
|
Reference in New Issue
Block a user