Add the const-set op. (#2910)

* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
This commit is contained in:
Laurent Mazare
2025-04-19 10:07:02 +02:00
committed by GitHub
parent b2904a830b
commit a4c56a958e
20 changed files with 414 additions and 209 deletions

View File

@ -313,46 +313,6 @@ impl MetalDevice {
.map_err(MetalError::from)?;
Ok(())
}
pub(crate) fn const_impl<T: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
&self,
v: T,
shape: &crate::Shape,
) -> Result<super::MetalStorage> {
use crate::backend::BackendDevice;
let dtype = T::DTYPE;
let name = match dtype {
DType::U8 => "fill_u8",
DType::U32 => "fill_u32",
DType::I64 => "fill_i64",
DType::F16 => "fill_f16",
DType::BF16 => "fill_bf16",
DType::F32 => "fill_f32",
DType::F64 => {
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
return self.storage_from_cpu_storage(&cpu_storage);
}
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_const_fill(
&self.device,
&command_buffer,
&self.kernels,
name,
shape.elem_count(),
&buffer,
v,
)
.map_err(MetalError::from)?;
Ok(super::MetalStorage::new(
buffer,
self.clone(),
shape.elem_count(),
dtype,
))
}
}
fn buf_size(size: NSUInteger) -> NSUInteger {