diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 5706a2e6..8234eeac 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -1,7 +1,23 @@
-use crate::{CpuStorage, DType, Error, Result, Shape};
-use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
+use crate::{CpuStorage, DType, Shape};
+use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
 
-pub type CudaError = cudarc::driver::DriverError;
+/// cudarc related errors
+#[derive(thiserror::Error, Debug)]
+pub enum CudaError {
+    #[error(transparent)]
+    Cuda(#[from] cudarc::driver::DriverError),
+
+    #[error(transparent)]
+    Compiler(#[from] cudarc::nvrtc::CompileError),
+
+    #[error("{op} only supports contiguous tensors")]
+    RequiresContiguous { op: &'static str },
+
+    #[error("missing kernel '{module_name}'")]
+    MissingKernel { module_name: &'static str },
+}
+
+type Result<T> = std::result::Result<T, CudaError>;
 
 #[derive(Debug, Clone)]
 pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
@@ -21,6 +37,20 @@ extern "C" __global__ void affine_f32(
     }
     y[i] = x[i] * mul + add;
 }
+
+extern "C" __global__ void affine_f64(
+    const size_t numel,
+    const double *x,
+    double *y,
+    const double mul,
+    const double add
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= numel) {
+        return;
+    }
+    y[i] = x[i] * mul + add;
+}
 "#;
 
 const FILL_CU: &str = r#"
@@ -61,34 +91,23 @@ impl CudaDevice {
     pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
         let elem_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = &self.0;
         match dtype {
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { dev.alloc::<f32>(elem_count) }?;
-                let module_name = "fill_f32";
-                if !dev.has_func(module_name, module_name) {
-                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
-                    dev.load_ptx(ptx, module_name, &[module_name])?;
-                }
-                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
-                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let func = self.get_or_load_func("fill_f32", FILL_CU)?;
                 let params = (&data, v as f32, elem_count);
-                unsafe { fwd_fn.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }?;
                 Ok(CudaStorage::F32(data))
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { dev.alloc::<f64>(elem_count) }?;
-                let module_name = "fill_f64";
-                if !dev.has_func(module_name, module_name) {
-                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
-                    dev.load_ptx(ptx, module_name, &[module_name])?;
-                }
-                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
-                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let func = self.get_or_load_func("fill_f64", FILL_CU)?;
                 let params = (&data, v, elem_count);
-                unsafe { fwd_fn.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }?;
                 Ok(CudaStorage::F64(data))
             }
         }
     }
@@ -110,6 +129,23 @@ impl CudaDevice {
             }
         }
     }
+
+    fn get_or_load_func(
+        &self,
+        module_name: &'static str,
+        source: &'static str,
+    ) -> Result<CudaFunction> {
+        let dev = &self.0;
+        if !dev.has_func(module_name, module_name) {
+            // TODO: Pre-compile and load rather than compiling here.
+            let ptx = cudarc::nvrtc::compile_ptx(source)?;
+            dev.load_ptx(ptx, module_name, &[module_name])?;
+        }
+        dev.get_func(module_name, module_name)
+            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
+            // able to only build the error value if needed.
+            .ok_or(CudaError::MissingKernel { module_name })
+    }
 }
 
 #[derive(Debug, Clone)]
 pub enum CudaStorage {
@@ -140,30 +176,33 @@ impl CudaStorage {
         mul: f64,
         add: f64,
     ) -> Result<Self> {
+        if !shape.is_contiguous(stride) {
+            return Err(CudaError::RequiresContiguous { op: "affine" });
+        }
+
+        let elem_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+        let dev = self.device();
         match self {
             Self::F32(arg) => {
-                if !shape.is_contiguous(stride) {
-                    return Err(Error::RequiresContiguous { op: "affine" });
-                }
-                let dev = arg.device();
-                let module_name = "affine_f32";
-                if !dev.has_func(module_name, module_name) {
-                    let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU).unwrap();
-                    dev.load_ptx(ptx, module_name, &[module_name])?;
-                }
-                let elem_count = shape.elem_count();
-                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
-                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let func = dev.get_or_load_func("affine_f32", AFFINE_CU)?;
                 // SAFETY: if this function returns Ok(..), the kernel has been applied
                 // and has set the initially unset memory.
-                let out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
                 let params = (elem_count, arg, &out, mul as f32, add as f32);
                 // SAFETY: well, well, well...
-                unsafe { fwd_fn.launch(cfg, params) }?;
+                unsafe { func.launch(cfg, params) }?;
                 Ok(Self::F32(out))
             }
-            Self::F64(_) => {
-                todo!()
+            Self::F64(arg) => {
+                let func = dev.get_or_load_func("affine_f64", AFFINE_CU)?;
+                // SAFETY: if this function returns Ok(..), the kernel has been applied
+                // and has set the initially unset memory.
+                let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
+                let params = (elem_count, arg, &out, mul, add);
+                // SAFETY: well, well, well...
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F64(out))
+            }
         }
     }
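
// ---------------------------------------------------------------------------
// Reviewer note, not part of the patch: a minimal sketch of how the new
// `get_or_load_func` helper composes, mirroring the call pattern used in
// `affine_impl` above. The function name `scale_and_shift` and the constants
// 2.0/1.0 are illustrative assumptions, not code from this diff; only the
// cudarc calls that already appear in the patch are used.
fn scale_and_shift(dev: &CudaDevice, arg: &CudaSlice<f32>, len: usize) -> Result<CudaSlice<f32>> {
    // First call compiles AFFINE_CU with NVRTC and loads it onto the device;
    // subsequent calls find the cached module via `has_func` and skip compilation.
    let func = dev.get_or_load_func("affine_f32", AFFINE_CU)?;
    let cfg = LaunchConfig::for_num_elems(len as u32);
    // SAFETY: the kernel writes every element before the buffer is read back.
    let out = unsafe { dev.0.alloc::<f32>(len) }?;
    // Matches the affine_f32 signature: (numel, x, y, mul, add).
    let params = (len, arg, &out, 2.0f32, 1.0f32);
    // SAFETY: the parameter tuple matches the kernel's expected arguments.
    unsafe { func.launch(cfg, params) }?;
    Ok(out)
}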