Support for UG kernels. (#2579)

* Support for UG kernels.

* Add a dedicated test.
This commit is contained in:
Laurent Mazare
2024-10-27 13:37:19 +01:00
committed by GitHub
parent 37e0ab8c64
commit 594d984f9c
8 changed files with 139 additions and 2 deletions

View File

@ -51,6 +51,27 @@ impl CudaDevice {
self.device.clone()
}
/// Compiles a `ug` SSA kernel and returns the resulting CUDA function.
///
/// The CUDA source for `kernel` is produced by `ug_cuda::code_gen::gen`, compiled
/// to PTX through nvrtc with `use_fast_math` enabled, and loaded on this device
/// under the module name `"ug"` with `func_name` as its only registered function.
///
/// # Errors
///
/// Returns an error if code generation fails, if the generated source is not
/// valid UTF-8, if nvrtc compilation or PTX loading fails, or if `func_name`
/// cannot be looked up in the `"ug"` module once it has been loaded.
pub fn compile(
    &self,
    func_name: &'static str,
    kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
    // Module name under which the generated kernel is registered on the device.
    const MODULE_NAME: &str = "ug";

    // The code generator writes raw bytes; nvrtc wants the source as a string.
    let mut source_bytes = Vec::new();
    ug_cuda::code_gen::gen(&mut source_bytes, func_name, &kernel)?;
    let cuda_source = String::from_utf8(source_bytes)?;

    let nvrtc_opts = cudarc::nvrtc::CompileOptions {
        use_fast_math: Some(true),
        ..Default::default()
    };
    let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_source, nvrtc_opts).w()?;

    // NOTE(review): every call loads under the same module name — confirm what
    // happens to previously compiled kernels when the module is loaded again.
    self.device.load_ptx(ptx, MODULE_NAME, &[func_name]).w()?;
    match self.device.get_func(MODULE_NAME, func_name) {
        Some(func) => Ok(func),
        None => crate::bail!("unknown function ug::{func_name}"),
    }
}
/// Returns the `DeviceId` stored on this device.
pub fn id(&self) -> DeviceId {
self.id
}