Support for UG kernels. (#2579)

* Support for UG kernels.

* Add a dedicated test.
This commit is contained in:
Laurent Mazare
2024-10-27 13:37:19 +01:00
committed by GitHub
parent 37e0ab8c64
commit 594d984f9c
8 changed files with 139 additions and 2 deletions

View File

@ -375,3 +375,70 @@ impl Tensor {
)
}
}
pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
}
impl UgIOp1 {
#[allow(unused)]
pub fn new(
name: &'static str,
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(feature = "cuda"))]
{
Ok(Self { name })
}
}
}
impl InplaceOp1 for UgIOp1 {
fn name(&self) -> &'static str {
self.name
}
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on cuda at the moment")
}
fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on cuda at the moment")
}
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
let elem_count = layout.shape().elem_count();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (g as u32, 1, 1),
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
Ok(())
}
}