mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the kernels.
This commit is contained in:
@ -105,6 +105,10 @@ impl CudaDevice {
|
||||
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
||||
let elem_count = shape.elem_count();
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let data = self.alloc_zeros::<u8>(elem_count)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
@ -136,6 +140,14 @@ impl CudaDevice {
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u8>(elem_count) }?;
|
||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
||||
let params = (&data, v as u8, elem_count);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
DType::U32 => {
|
||||
// SAFETY: Set later by running the fill kernel.
|
||||
let data = unsafe { self.alloc::<u32>(elem_count) }?;
|
||||
@ -189,6 +201,10 @@ impl CudaDevice {
|
||||
|
||||
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorage::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage)?;
|
||||
CudaStorageSlice::U32(data)
|
||||
@ -238,6 +254,7 @@ impl CudaDevice {
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CudaStorageSlice {
|
||||
U8(CudaSlice<u8>),
|
||||
U32(CudaSlice<u32>),
|
||||
BF16(CudaSlice<bf16>),
|
||||
F16(CudaSlice<f16>),
|
||||
@ -256,6 +273,7 @@ trait Map1 {
|
||||
|
||||
fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
|
||||
let out = match s {
|
||||
S::U8(s) => S::U8(self.f(s, d, l)?),
|
||||
S::U32(s) => S::U32(self.f(s, d, l)?),
|
||||
S::BF16(s) => S::BF16(self.f(s, d, l)?),
|
||||
S::F16(s) => S::F16(self.f(s, d, l)?),
|
||||
@ -278,6 +296,7 @@ trait Map2 {
|
||||
|
||||
fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
|
||||
let out = match (s1, s2) {
|
||||
(S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
|
||||
(S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
|
||||
@ -596,6 +615,7 @@ impl CudaStorage {
|
||||
|
||||
pub fn dtype(&self) -> DType {
|
||||
match self.slice {
|
||||
CudaStorageSlice::U8(_) => DType::U8,
|
||||
CudaStorageSlice::U32(_) => DType::U32,
|
||||
CudaStorageSlice::BF16(_) => DType::BF16,
|
||||
CudaStorageSlice::F16(_) => DType::F16,
|
||||
@ -621,6 +641,7 @@ impl CudaStorage {
|
||||
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
|
||||
// is used.
|
||||
let inp = match &self.slice {
|
||||
CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||
@ -632,6 +653,12 @@ impl CudaStorage {
|
||||
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
||||
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
||||
let slice = match dtype {
|
||||
DType::U8 => {
|
||||
let out = unsafe { dev.alloc::<u8>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
CudaStorageSlice::U8(out)
|
||||
}
|
||||
DType::U32 => {
|
||||
let out = unsafe { dev.alloc::<u32>(el) }?;
|
||||
let params = (el, dims.len(), &ds, *inp, &out);
|
||||
@ -706,6 +733,11 @@ impl CudaStorage {
|
||||
|
||||
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
match &self.slice {
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
Ok(CpuStorage::U8(cpu_storage))
|
||||
}
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
let dev = slice.device();
|
||||
let cpu_storage = dev.dtoh_sync_copy(slice)?;
|
||||
@ -857,6 +889,18 @@ impl CudaStorage {
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
let func = dev.get_or_load_func("ucopy_u8", kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let params = (el_count, dims.len(), &ds, &src, &mut dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||
if src_l.is_contiguous() {
|
||||
|
Reference in New Issue
Block a user