diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 7858e542..5706a2e6 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -23,6 +23,18 @@ extern "C" __global__ void affine_f32(
 }
 "#;
 
+const FILL_CU: &str = r#"
+template <typename T>
+__device__ void fill_with(T *buf, T value, const size_t numel) {
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
+        buf[i] = value;
+    }
+}
+extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
+extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
+"#;
+
 impl CudaDevice {
     pub(crate) fn new(ordinal: usize) -> Result<Self> {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -47,6 +59,45 @@ impl CudaDevice {
         }
     }
 
+    pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let dev = &self.0;
+        match dtype {
+            DType::F32 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let module_name = "fill_f32";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let params = (&data, v as f32, elem_count);
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(CudaStorage::F32(data))
+            }
+            DType::F64 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { dev.alloc::<f64>(elem_count) }?;
+                let module_name = "fill_f64";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                let params = (&data, v, elem_count);
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(CudaStorage::F64(data))
+            }
+        }
+    }
+
+    pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        self.const_impl(1., shape, dtype)
+    }
+
     pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
         match storage {
             CpuStorage::F32(storage) => {
diff --git a/src/device.rs b/src/device.rs
index e522cd42..ab7bad26 100644
--- a/src/device.rs
+++ b/src/device.rs
@@ -82,10 +82,7 @@ impl Device {
                 Ok(Storage::Cpu(storage))
             }
             Device::Cuda(device) => {
-                // TODO: Instead of allocating memory on the host and transfering it,
-                // allocate some zeros on the device and use a shader to set them to 1.
-                let storage = CpuStorage::ones_impl(shape, dtype);
-                let storage = device.cuda_from_cpu_storage(&storage)?;
+                let storage = device.ones_impl(shape, dtype)?;
                 Ok(Storage::Cuda(storage))
             }
         }
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index 85b5f598..2eb393c1 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -25,6 +25,10 @@ impl CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }