Add an enum for scalar values. (#2909)

* Add a scalar enum type.

* Add a bit more to the scalar type.

* Small tweak.

* More scalar usage.
This commit is contained in:
Laurent Mazare
2025-04-18 22:13:38 +02:00
committed by GitHub
parent ce5f8dd129
commit 9dbaf958dc
10 changed files with 150 additions and 55 deletions

View File

@ -313,6 +313,46 @@ impl MetalDevice {
.map_err(MetalError::from)?;
Ok(())
}
/// Creates a Metal storage of the given `shape` in which every element is set to `v`.
///
/// The element dtype is taken from `T::DTYPE`. For every dtype except `F64` a
/// `fill_*` kernel is dispatched through `candle_metal_kernels::call_const_fill`
/// on a freshly allocated buffer. No fill kernel name is listed for `F64` here
/// (presumably because Metal has no 64-bit float support - TODO confirm), so in
/// that case the constant data is built on the host and uploaded with
/// `storage_from_cpu_storage`.
///
/// # Errors
///
/// Propagates any error from buffer allocation, command-buffer retrieval, the
/// kernel call (converted via `MetalError::from`), or the host-to-device upload.
pub(crate) fn const_impl<T: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
    &self,
    v: T,
    shape: &crate::Shape,
) -> Result<super::MetalStorage> {
    use crate::backend::BackendDevice;
    let dtype = T::DTYPE;
    let elem_count = shape.elem_count();
    let name = match dtype {
        DType::U8 => "fill_u8",
        DType::U32 => "fill_u32",
        DType::I64 => "fill_i64",
        DType::F16 => "fill_f16",
        DType::BF16 => "fill_bf16",
        DType::F32 => "fill_f32",
        DType::F64 => {
            // Host fallback: fill with the requested value `v`. The previous code
            // used `ones_impl` here, which silently ignored `v` and always
            // produced ones.
            let cpu_storage = T::to_cpu_storage_owned(vec![v; elem_count]);
            return self.storage_from_cpu_storage(&cpu_storage);
        }
    };
    // The buffer label is kept as-is for compatibility with existing tooling/logs,
    // even though the buffer may hold constants other than one.
    let buffer = self.new_buffer(elem_count, dtype, "alloc-ones")?;
    let command_buffer = self.command_buffer()?;
    candle_metal_kernels::call_const_fill(
        &self.device,
        &command_buffer,
        &self.kernels,
        name,
        elem_count,
        &buffer,
        v,
    )
    .map_err(MetalError::from)?;
    Ok(super::MetalStorage::new(
        buffer,
        self.clone(),
        elem_count,
        dtype,
    ))
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {