Add a first kernel.

This commit is contained in:
laurent
2023-06-21 20:48:22 +01:00
parent fcb4e6b84f
commit 97d9142dee
3 changed files with 57 additions and 4 deletions

View File

@ -4,9 +4,9 @@ use candle::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
println!("{:?}", x.to_vec1::<f32>()?);
let z = (x + y)?;
let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
let z = (y * 3.)?;
println!("{:?}", z.to_vec1::<f32>()?);
Ok(())
}

View File

@ -1,11 +1,28 @@
use crate::{CpuStorage, DType, Result, Shape};
use cudarc::driver::CudaSlice;
use cudarc::driver::{CudaSlice, LaunchAsync, LaunchConfig};
pub(crate) type Error = cudarc::driver::DriverError;
#[derive(Debug, Clone)]
pub struct CudaDevice(std::sync::Arc<cudarc::driver::CudaDevice>);
// TODO: Switch to pre-compiled PTX kernels rather than compiling on the fly.
const AFFINE_CU: &str = r#"
extern "C" __global__ void affine_f32(
const size_t numel,
const float *x,
float *y,
const float mul,
const float add
) {
unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= numel) {
return;
}
y[i] = x[i] * mul + add;
}
"#;
impl CudaDevice {
pub(crate) fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal)?;
@ -65,6 +82,39 @@ impl CudaStorage {
}
}
pub(crate) fn affine_impl(
&self,
shape: &Shape,
_stride: &[usize],
mul: f64,
add: f64,
) -> Result<Self> {
match self {
Self::F32(arg) => {
// TODO: Handle the stride.
let dev = arg.device();
let module_name = "affine_f32";
if !dev.has_func(module_name, module_name) {
let ptx = cudarc::nvrtc::compile_ptx(AFFINE_CU).unwrap();
dev.load_ptx(ptx, module_name, &[module_name])?;
}
let elem_count = shape.elem_count();
let fwd_fn = dev.get_func(module_name, module_name).unwrap();
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
// SAFETY: if this function returns Ok(..), the kernel has been applied
// and has set the initially unset memory.
let out = unsafe { dev.alloc::<f32>(elem_count) }?;
let params = (elem_count, arg, &out, mul as f32, add as f32);
// SAFETY: well, well, well...
unsafe { fwd_fn.launch(cfg, params) }?;
Ok(Self::F32(out))
}
Self::F64(_) => {
todo!()
}
}
}
pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self {
Self::F32(slice) => {

View File

@ -144,7 +144,10 @@ impl Storage {
let storage = storage.affine_impl(shape, stride, mul, add)?;
Ok(Self::Cpu(storage))
}
Self::Cuda { .. } => todo!(),
Self::Cuda(storage) => {
let storage = storage.affine_impl(shape, stride, mul, add)?;
Ok(Self::Cuda(storage))
}
}
}