Add the fill kernel and use it for 'ones'.

laurent
2023-06-22 08:33:32 +01:00
parent fc26bab3ed
commit 0a758ffa05
3 changed files with 56 additions and 4 deletions


@@ -23,6 +23,18 @@ extern "C" __global__ void affine_f32(
}
"#;
const FILL_CU: &str = r#"
template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        buf[i] = value;
    }
}
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
"#;
impl CudaDevice {
    pub(crate) fn new(ordinal: usize) -> Result<Self> {
        let device = cudarc::driver::CudaDevice::new(ordinal)?;
@@ -47,6 +59,45 @@ impl CudaDevice {
        }
    }
    pub(crate) fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        let elem_count = shape.elem_count();
        let dev = &self.0;
        match dtype {
            DType::F32 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { dev.alloc::<f32>(elem_count) }?;
                let module_name = "fill_f32";
                if !dev.has_func(module_name, module_name) {
                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
                    dev.load_ptx(ptx, module_name, &[module_name])?;
                }
                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
                let params = (&data, v as f32, elem_count);
                unsafe { fwd_fn.launch(cfg, params) }?;
                Ok(CudaStorage::F32(data))
            }
            DType::F64 => {
                // SAFETY: Set later by running the fill kernel.
                let data = unsafe { dev.alloc::<f64>(elem_count) }?;
                let module_name = "fill_f64";
                if !dev.has_func(module_name, module_name) {
                    let ptx = cudarc::nvrtc::compile_ptx(FILL_CU).unwrap();
                    dev.load_ptx(ptx, module_name, &[module_name])?;
                }
                let fwd_fn = dev.get_func(module_name, module_name).unwrap();
                let cfg = LaunchConfig::for_num_elems(elem_count as u32);
                let params = (&data, v, elem_count);
                unsafe { fwd_fn.launch(cfg, params) }?;
                Ok(CudaStorage::F64(data))
            }
        }
    }
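const_impl compiles the kernels lazily: dev.has_func(module_name, module_name) guards the nvrtc call, so FILL_CU is JIT-compiled and loaded only the first time a fill of a given dtype runs, and every later call reuses the already-loaded function. Purely to illustrate that compile-once/reuse idea, and without touching the cudarc API, a minimal sketch with a made-up KernelCache type could look like:

use std::collections::HashMap;

// Illustration only: a made-up cache showing the compile-once pattern that
// const_impl gets from has_func/load_ptx. The "compile" closure stands in for
// cudarc::nvrtc::compile_ptx + load_ptx and runs only on the first lookup.
struct KernelCache {
    loaded: HashMap<String, String>,
}

impl KernelCache {
    fn new() -> Self {
        Self { loaded: HashMap::new() }
    }

    fn get_or_load(&mut self, name: &str, compile: impl FnOnce() -> String) -> &String {
        self.loaded.entry(name.to_string()).or_insert_with(compile)
    }
}

fn main() {
    let mut cache = KernelCache::new();
    let first = cache.get_or_load("fill_f32", || "compiled ptx".to_string()).clone();
    // The second lookup hits the cache, so the closure is never called.
    let second = cache.get_or_load("fill_f32", || unreachable!()).clone();
    assert_eq!(first, second);
}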
    pub(crate) fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }
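ones_impl is just const_impl with v = 1.0, so any other constant initializer can share the same device-side fill path. For instance, a hypothetical zeros_impl (not in this commit) in the same impl block would be:

    // Hypothetical, not part of this commit: zeros via the same fill kernel.
    pub(crate) fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(0., shape, dtype)
    }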
    pub(crate) fn cuda_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        match storage {
            CpuStorage::F32(storage) => {


@@ -82,10 +82,7 @@ impl Device {
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
-               // TODO: Instead of allocating memory on the host and transfering it,
-               // allocate some zeros on the device and use a shader to set them to 1.
-               let storage = CpuStorage::ones_impl(shape, dtype);
-               let storage = device.cuda_from_cpu_storage(&storage)?;
+               let storage = device.ones_impl(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
        }
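With this change the CUDA arm no longer builds a ones tensor in host memory and copies it over; it launches the fill kernel directly on the device. Reconstructed roughly from the context lines of this hunk (the function signature and the CPU arm are assumptions, not shown in the diff), the surrounding match now reads:

    // Rough reconstruction from the hunk's context; signature and CPU arm are assumptions.
    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
        match self {
            Device::Cpu => {
                let storage = CpuStorage::ones_impl(shape, dtype);
                Ok(Storage::Cpu(storage))
            }
            Device::Cuda(device) => {
                // Fill directly on the device instead of building a CPU buffer
                // and calling cuda_from_cpu_storage on it.
                let storage = device.ones_impl(shape, dtype)?;
                Ok(Storage::Cuda(storage))
            }
        }
    }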


@@ -25,6 +25,10 @@ impl CudaDevice {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub(crate) fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub(crate) fn cuda_from_cpu_storage(&self, _: &CpuStorage) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }
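The third file is the stub backend used when CUDA support is not compiled in: its ones_impl, like the other methods, simply returns Error::NotCompiledWithCudaSupport so the crate still builds without a GPU toolchain. Such a stub is typically selected by a Cargo feature; a hedged sketch of that wiring (the feature and module names are assumptions, not taken from this commit) is:

// Hypothetical sketch of how a real and a dummy CUDA backend are usually swapped in.
#[cfg(feature = "cuda")]
mod cuda_backend;
#[cfg(feature = "cuda")]
pub use cuda_backend::{CudaDevice, CudaStorage};

#[cfg(not(feature = "cuda"))]
mod dummy_cuda_backend;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};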