Update to cudarc 0.14 (breaking change). (#2858)

* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.

This commit is contained in:
Laurent Mazare
2025-04-03 09:12:19 +02:00
committed by GitHub
parent d6db305829
commit d9904a3baf
18 changed files with 924 additions and 590 deletions

View File

@@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
};
use candle::cuda_backend::SlicePtrOrNull;
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
@@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let params = (el_count, dims.len(), &ds, src, &out);
let mut builder = func.builder();
candle::builder_arg!(builder, el_count, dims.len());
ds.builder_arg(&mut builder);
builder.arg(src);
builder.arg(&out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(out)
}
}
@@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
use candle::{CudaDevice, WithDType};
@@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
block_dim: (1, 32, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, n_cols as i32);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
candle::builder_arg!(builder, n_cols as i32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
@@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
l2: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
use candle::{CudaDevice, WithDType};
@@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
@@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
use candle::{CudaDevice, WithDType};
@@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
let func =
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
builder.arg(&beta);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}

View File

@@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
@@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&cos,
&sin,
&dst,
(b * h) as u32,
(t * d) as u32,
d as u32,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
@@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
let (b, t, h, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}