From 8add5a5f49a1f02beb3fff58334aaa5bc31c4516 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 23 Jun 2023 14:17:39 +0200 Subject: [PATCH] Backport. --- src/cuda_backend.rs | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 50b8c7ff..d9958e3b 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -79,6 +79,10 @@ impl CudaDevice { pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { let elem_count = shape.elem_count(); let slice = match dtype { + DType::U32 => { + let data = self.alloc_zeros::(elem_count)?; + CudaStorageSlice::U32(data) + } DType::F32 => { let data = self.alloc_zeros::(elem_count)?; CudaStorageSlice::F32(data) @@ -98,6 +102,14 @@ impl CudaDevice { let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let slice = match dtype { + DType::U32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }?; + let func = self.get_or_load_func("fill_u32", kernels::FILL)?; + let params = (&data, v as u32, elem_count); + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::U32(data) + } DType::F32 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }?; @@ -127,6 +139,10 @@ impl CudaDevice { pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result { let slice = match storage { + CpuStorage::U32(storage) => { + let data = self.htod_sync_copy(storage)?; + CudaStorageSlice::U32(data) + } CpuStorage::F32(storage) => { let data = self.htod_sync_copy(storage)?; CudaStorageSlice::F32(data) @@ -159,6 +175,7 @@ impl CudaDevice { #[derive(Debug)] enum CudaStorageSlice { + U32(CudaSlice), F32(CudaSlice), F64(CudaSlice), } @@ -205,6 +222,7 @@ fn gemm_config( impl CudaStorage { pub fn try_clone(&self) -> Result { let slice = match &self.slice { + CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?), CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?), CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?), }; @@ -214,6 +232,7 @@ impl CudaStorage { pub fn dtype(&self) -> DType { match self.slice { + CudaStorageSlice::U32(_) => DType::U32, CudaStorageSlice::F32(_) => DType::F32, CudaStorageSlice::F64(_) => DType::F64, } @@ -236,6 +255,15 @@ impl CudaStorage { let dev = self.device(); let ds = dev.htod_copy([dims, stride].concat())?; let slice = match &self.slice { + CudaStorageSlice::U32(arg) => { + let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { dev.alloc::(el_count) }?; + let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }?; + CudaStorageSlice::U32(out) + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. @@ -270,6 +298,9 @@ impl CudaStorage { let dev = &self.device; let ds = dev.htod_copy([dims, stride].concat())?; let slice = match &self.slice { + CudaStorageSlice::U32(_arg) => { + todo!("No unary kernels for u32"); + } CudaStorageSlice::F32(arg) => { let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. @@ -333,6 +364,11 @@ impl CudaStorage { pub(crate) fn to_cpu_storage(&self) -> Result { match &self.slice { + CudaStorageSlice::U32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice)?; + Ok(CpuStorage::U32(cpu_storage)) + } CudaStorageSlice::F32(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice)?; @@ -348,9 +384,9 @@ impl CudaStorage { pub(crate) fn embedding_impl( &self, - rhs: &Self, - hidden_size: usize, - vocab_size: usize, + _rhs: &Self, + _hidden_size: usize, + _vocab_size: usize, ) -> Result { todo!("Implement embedding for gpu"); }