mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the const-set op. (#2910)
* Add the const-set op.
* Cuda implementation.
* Bugfix.
* Metal cleanup.
* Add the metal kernels.
* Add some testing.
* Finish the metal implementation.
* Bump the version.
This commit is contained in:
@ -1,9 +1,8 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::scalar::Scalar;
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||
use cudarc::driver::CudaFunction;
|
||||
use half::{bf16, f16};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -189,94 +188,6 @@ impl CudaDevice {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn const_impl(&self, v: Scalar, shape: &Shape) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match v {
|
||||
Scalar::U8(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
Scalar::U32(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
Scalar::I64(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<i64>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
Scalar::BF16(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<bf16>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
Scalar::F16(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f16>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
Scalar::F32(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
Scalar::F64(v) => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||
let mut builder = self.stream.launch_builder(&func);
|
||||
builder.arg(&data);
|
||||
builder.arg(&v);
|
||||
builder.arg(&elem_count);
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_or_load_custom_func(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
@ -499,10 +410,6 @@ impl BackendDevice for CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
self.const_impl(Scalar::one(dtype), shape)
|
||||
}
|
||||
|
||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
|
@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::scalar::Scalar {
|
||||
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
|
||||
use crate::scalar::Scalar;
|
||||
match self {
|
||||
Scalar::U8(v) => builder.arg(v),
|
||||
Scalar::U32(v) => builder.arg(v),
|
||||
Scalar::I64(v) => builder.arg(v),
|
||||
Scalar::F32(v) => builder.arg(v),
|
||||
Scalar::F64(v) => builder.arg(v),
|
||||
Scalar::F16(v) => builder.arg(v),
|
||||
Scalar::BF16(v) => builder.arg(v),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl SlicePtrOrNull<usize> {
|
||||
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||
let ds = if l.is_contiguous() {
|
||||
@ -1235,6 +1250,36 @@ impl BackendStorage for CudaStorage {
|
||||
&self.device
|
||||
}
|
||||
|
||||
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
|
||||
let dev = &self.device;
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src_o = layout.start_offset();
|
||||
let ((src, _guard_src), kernel_name) = match &mut self.slice {
|
||||
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
|
||||
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
|
||||
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
|
||||
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
|
||||
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
|
||||
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
|
||||
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
|
||||
};
|
||||
|
||||
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
|
||||
let mut builder = func.builder();
|
||||
barg!(builder, el_count);
|
||||
barg!(builder, dims.len());
|
||||
ds.builder_arg(&mut builder);
|
||||
s.builder_arg(&mut builder);
|
||||
barg!(builder, src);
|
||||
// SAFETY: ffi.
|
||||
unsafe { builder.launch(cfg) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
|
Reference in New Issue
Block a user