Add the const-set op. (#2910)

* Add the const-set op.

* Cuda implementation.

* Bugfix.

* Metal cleanup.

* Add the metal kernels.

* Add some testing.

* Finish the metal implementation.

* Bump the version.
This commit is contained in:
Laurent Mazare
2025-04-19 10:07:02 +02:00
committed by GitHub
parent b2904a830b
commit a4c56a958e
20 changed files with 414 additions and 209 deletions

View File

@ -1,9 +1,8 @@
use crate::backend::BackendDevice;
use crate::scalar::Scalar;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use cudarc::driver::CudaFunction;
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -189,94 +188,6 @@ impl CudaDevice {
self.id
}
fn const_impl(&self, v: Scalar, shape: &Shape) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match v {
Scalar::U8(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
Scalar::U32(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
Scalar::I64(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
Scalar::BF16(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
Scalar::F16(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
Scalar::F32(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
Scalar::F64(v) => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F64(data)
}
};
Ok(CudaStorage {
slice,
device: self.clone(),
})
}
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
@ -499,10 +410,6 @@ impl BackendDevice for CudaDevice {
})
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
self.const_impl(Scalar::one(dtype), shape)
}
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
let elem_count = shape.elem_count();
let slice = match dtype {

View File

@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
}
}
impl crate::scalar::Scalar {
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
use crate::scalar::Scalar;
match self {
Scalar::U8(v) => builder.arg(v),
Scalar::U32(v) => builder.arg(v),
Scalar::I64(v) => builder.arg(v),
Scalar::F32(v) => builder.arg(v),
Scalar::F64(v) => builder.arg(v),
Scalar::F16(v) => builder.arg(v),
Scalar::BF16(v) => builder.arg(v),
};
}
}
impl SlicePtrOrNull<usize> {
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
let ds = if l.is_contiguous() {
@ -1235,6 +1250,36 @@ impl BackendStorage for CudaStorage {
&self.device
}
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
let dev = &self.device;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src_o = layout.start_offset();
let ((src, _guard_src), kernel_name) = match &mut self.slice {
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
};
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
let mut builder = func.builder();
barg!(builder, el_count);
barg!(builder, dims.len());
ds.builder_arg(&mut builder);
s.builder_arg(&mut builder);
barg!(builder, src);
// SAFETY: ffi.
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let shape = layout.shape();
let dims = shape.dims();