diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs
index a9647469..0a4825fa 100644
--- a/examples/cuda_basics.rs
+++ b/examples/cuda_basics.rs
@@ -4,9 +4,9 @@ use candle::{Device, Tensor};
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
-    let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
     println!("{:?}", x.to_vec1::<f32>()?);
-    let z = (x + y)?;
+    let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
+    let z = (y * 3.)?;
     println!("{:?}", z.to_vec1::<f32>()?);
     Ok(())
 }
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 5a80df08..06730251 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -1,11 +1,28 @@
 use crate::{CpuStorage, DType, Result, Shape};
-use cudarc::driver::CudaSlice;
+use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
 
 pub(crate) type Error = cudarc::driver::DriverError;
 
 #[derive(Debug, Clone)]
 pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
 
+// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
+const AFFINE_CU: &str = r#"
+extern "C" __global__ void affine_f32(
+    const size_t numel,
+    const float *x,
+    float *y,
+    const float mul,
+    const float add
+) {
+    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
+    if (i >= numel) {
+        return;
+    }
+    y[i] = x[i] * mul + add;
+}
+"#;
+
 impl CudaDevice {
     pub(crate) fn new(ordinal: usize) -> Result<Self> {
         let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -65,6 +82,62 @@ impl CudaStorage {
         }
     }
 
+    /// Applies the element-wise affine map `x * mul + add` on the device and returns the
+    /// result in a newly allocated buffer.
+    ///
+    /// `mul` and `add` are narrowed to `f32` before the launch. The kernel source is compiled
+    /// through NVRTC and loaded the first time it is needed, then looked up by name.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the kernel source fails to compile, if `shape` holds more than `u32::MAX`
+    /// elements, or if `self` is `F64` (not implemented yet).
+    pub(crate) fn affine_impl(
+        &self,
+        shape: &Shape,
+        _stride: &[usize],
+        mul: f64,
+        add: f64,
+    ) -> Result<Self> {
+        match self {
+            Self::F32(arg) => {
+                // TODO: Handle the stride.
+                let dev = arg.device();
+                let module_name = "affine_f32";
+                if !dev.has_func(module_name, module_name) {
+                    let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU)
+                        .expect("AFFINE_CU is a constant kernel source and must compile");
+                    dev.load_ptx(ptx, module_name, &[module_name])?;
+                }
+                let elem_count = shape.elem_count();
+                // The launch configuration and the kernel's thread index are 32-bit; a larger
+                // tensor would leave the tail of `out` unwritten, so refuse it up front.
+                assert!(
+                    elem_count <= u32::MAX as usize,
+                    "affine_f32: {} elements exceed the 32-bit launch limit",
+                    elem_count
+                );
+                let fwd_fn = dev
+                    .get_func(module_name, module_name)
+                    .expect("affine_f32 is loaded by the branch above");
+                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+                // SAFETY: if this function returns Ok(..), the kernel has been applied
+                // and has set the initially unset memory.
+                let out = unsafe { dev.alloc::<f32>(elem_count) }?;
+                let params = (elem_count, arg, &out, mul as f32, add as f32);
+                // SAFETY: `params` follows the `affine_f32` signature (numel, x, y, mul, add),
+                // `out` holds `elem_count` elements and the kernel skips indices `>= numel`.
+                // This assumes `arg` also holds at least `elem_count` elements, which the
+                // ignored stride does not guarantee -- TODO confirm when strides are handled.
+                unsafe { fwd_fn.launch(cfg, params) }?;
+                Ok(Self::F32(out))
+            }
+            Self::F64(_) => {
+                todo!()
+            }
+        }
+    }
+
     pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
         match self {
             Self::F32(slice) => {
diff --git a/src/storage.rs b/src/storage.rs
index 7230104e..573cf945 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -144,7 +144,10 @@ impl Storage {
                 let storage = storage.affine_impl(shape, stride, mul, add)?;
                 Ok(Self::Cpu(storage))
             }
-            Self::Cuda { .. } => todo!(),
+            Self::Cuda(storage) => {
+                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                Ok(Self::Cuda(storage))
+            }
         }
     }
 