Update to cudarc 0.14 (breaking change). (#2858)

* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
This commit is contained in:
Laurent Mazare
2025-04-03 09:12:19 +02:00
committed by GitHub
parent d6db305829
commit d9904a3baf
18 changed files with 924 additions and 590 deletions

View File

@ -396,7 +396,10 @@ impl UgIOp1 {
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
Ok(Self {
name,
func: func.into_cuda_function(),
})
}
#[cfg(feature = "metal")]
{
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
use cudarc::driver::PushKernelArg;
let elem_count = layout.shape().elem_count();
let stream = sto.device.cuda_stream();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
let mut builder = stream.launch_builder(&self.func);
builder.arg(&sto);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}