Merge pull request #8 from LaurentMazare/fix_cuda

Backport.
This commit is contained in:
Nicolas Patry
2023-06-23 14:27:01 +02:00
committed by GitHub

View File

@ -79,6 +79,10 @@ impl CudaDevice {
pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count)?;
CudaStorageSlice::U32(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
CudaStorageSlice::F32(data)
@ -98,6 +102,14 @@ impl CudaDevice {
let elem_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
let slice = match dtype {
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }?;
@ -127,6 +139,10 @@ impl CudaDevice {
pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::U32(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage)?;
CudaStorageSlice::F32(data)
@ -159,6 +175,7 @@ impl CudaDevice {
/// Typed buffer held by a `CudaStorage`, with one variant per element type.
///
/// The variant tag records the element type of the wrapped `CudaSlice`; it is what
/// `CudaStorage::dtype` matches on to report `DType::U32`, `DType::F32` or `DType::F64`.
#[derive(Debug)]
enum CudaStorageSlice {
/// Buffer of `u32` elements.
U32(CudaSlice<u32>),
/// Buffer of `f32` elements.
F32(CudaSlice<f32>),
/// Buffer of `f64` elements.
F64(CudaSlice<f64>),
}
@ -205,6 +222,7 @@ fn gemm_config<T>(
impl CudaStorage {
pub fn try_clone(&self) -> Result<Self> {
let slice = match &self.slice {
CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
};
@ -214,6 +232,7 @@ impl CudaStorage {
pub fn dtype(&self) -> DType {
match self.slice {
CudaStorageSlice::U32(_) => DType::U32,
CudaStorageSlice::F32(_) => DType::F32,
CudaStorageSlice::F64(_) => DType::F64,
}
@ -236,6 +255,15 @@ impl CudaStorage {
let dev = self.device();
let ds = dev.htod_copy([dims, stride].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(arg) => {
let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u32>(el_count) }?;
let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }?;
CudaStorageSlice::U32(out)
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
@ -270,6 +298,9 @@ impl CudaStorage {
let dev = &self.device;
let ds = dev.htod_copy([dims, stride].concat())?;
let slice = match &self.slice {
CudaStorageSlice::U32(_arg) => {
todo!("No unary kernels for u32");
}
CudaStorageSlice::F32(arg) => {
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
@ -333,6 +364,11 @@ impl CudaStorage {
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
match &self.slice {
CudaStorageSlice::U32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
Ok(CpuStorage::U32(cpu_storage))
}
CudaStorageSlice::F32(slice) => {
let dev = slice.device();
let cpu_storage = dev.dtoh_sync_copy(slice)?;
@ -348,9 +384,9 @@ impl CudaStorage {
/// Embedding lookup for CUDA-backed storage; not implemented yet.
///
/// # Panics
///
/// Always panics through `todo!`, so no `Result` value is ever actually returned.
pub(crate) fn embedding_impl(
&self,
// NOTE(review): the next three parameters look like the pre-change lines of this diff
// hunk, repeated by the underscore-prefixed versions that follow — confirm which set the
// real file keeps before relying on this signature.
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
// The leading underscores mark these as intentionally unused while the body is a stub.
_rhs: &Self,
_hidden_size: usize,
_vocab_size: usize,
) -> Result<Self> {
todo!("Implement embedding for gpu");
}