Add an enum for scalar values. (#2909)

* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
This commit is contained in:
Laurent Mazare
2025-04-18 22:13:38 +02:00
committed by GitHub
parent ce5f8dd129
commit 9dbaf958dc
10 changed files with 150 additions and 55 deletions

View File

@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -4,20 +4,20 @@ using namespace metal;
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant T &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
out[tid] = value;
}
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant T &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \

View File

@ -2570,7 +2570,7 @@ pub fn call_const_fill(
name: &'static str,
length: usize,
output: &Buffer,
v: f32,
v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
let encoder = ep.encoder();

View File

@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
#[test]
fn const_fill() {
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
let dev = device();
let kernels = Kernels::new();
let command_queue = dev.new_command_queue();
@ -2357,11 +2357,15 @@ fn const_fill() {
command_buffer.wait_until_completed();
read_to_vec::<T>(&buffer, len)
}
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
name: &'static str,
f: F,
) {
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
let value = rand::thread_rng().gen_range(1. ..19.);
let value = f(value);
let v = constant_fill::<T>(name, len, value);
assert_eq!(v, vec![f(value); len])
assert_eq!(v, vec![value; len])
}
test::<u8, _>("fill_u8", |v| v as u8);
test::<u32, _>("fill_u32", |v| v as u32);

View File

@ -88,9 +88,13 @@ primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);
pub struct BufferOffset<'a> {
pub buffer: &'a Buffer,