mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add an enum for scalar values. (#2909)
* Add a scalar enum type. * Add a bit more to the scalar type. * Small tweak. * More scalar usage.
This commit is contained in:
@ -313,6 +313,46 @@ impl MetalDevice {
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn const_impl<T: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
|
||||
&self,
|
||||
v: T,
|
||||
shape: &crate::Shape,
|
||||
) -> Result<super::MetalStorage> {
|
||||
use crate::backend::BackendDevice;
|
||||
let dtype = T::DTYPE;
|
||||
let name = match dtype {
|
||||
DType::U8 => "fill_u8",
|
||||
DType::U32 => "fill_u32",
|
||||
DType::I64 => "fill_i64",
|
||||
DType::F16 => "fill_f16",
|
||||
DType::BF16 => "fill_bf16",
|
||||
DType::F32 => "fill_f32",
|
||||
DType::F64 => {
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
||||
return self.storage_from_cpu_storage(&cpu_storage);
|
||||
}
|
||||
};
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
candle_metal_kernels::call_const_fill(
|
||||
&self.device,
|
||||
&command_buffer,
|
||||
&self.kernels,
|
||||
name,
|
||||
shape.elem_count(),
|
||||
&buffer,
|
||||
v,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(super::MetalStorage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn buf_size(size: NSUInteger) -> NSUInteger {
|
||||
|
Reference in New Issue
Block a user