Update to cudarc 0.14 (breaking change). (#2858)

* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
This commit is contained in:
Laurent Mazare
2025-04-03 09:12:19 +02:00
committed by GitHub
parent d6db305829
commit d9904a3baf
18 changed files with 924 additions and 590 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.4"
version = "0.9.0-alpha.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -7,5 +7,5 @@ fn main() {
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write("src/lib.rs").unwrap();
bindings.write("src/ptx.rs").unwrap();
}

View File

@ -1,11 +1,78 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
mod ptx;
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Id {
Affine,
Binary,
Cast,
Conv,
Fill,
Indexing,
Quantized,
Reduce,
Sort,
Ternary,
Unary,
}
pub const ALL_IDS: [Id; 11] = [
Id::Affine,
Id::Binary,
Id::Cast,
Id::Conv,
Id::Fill,
Id::Indexing,
Id::Quantized,
Id::Reduce,
Id::Sort,
Id::Ternary,
Id::Unary,
];
pub struct Module {
index: usize,
ptx: &'static str,
}
impl Module {
pub fn index(&self) -> usize {
self.index
}
pub fn ptx(&self) -> &'static str {
self.ptx
}
}
const fn module_index(id: Id) -> usize {
let mut i = 0;
while i < ALL_IDS.len() {
if ALL_IDS[i] as u32 == id as u32 {
return i;
}
i += 1;
}
panic!("id not found")
}
macro_rules! mdl {
($cst:ident, $id:ident) => {
pub const $cst: Module = Module {
index: module_index(Id::$id),
ptx: ptx::$cst,
};
};
}
mdl!(AFFINE, Affine);
mdl!(BINARY, Binary);
mdl!(CAST, Cast);
mdl!(CONV, Conv);
mdl!(FILL, Fill);
mdl!(INDEXING, Indexing);
mdl!(QUANTIZED, Quantized);
mdl!(REDUCE, Reduce);
mdl!(SORT, Sort);
mdl!(TERNARY, Ternary);
mdl!(UNARY, Unary);

11
candle-kernels/src/ptx.rs Normal file
View File

@ -0,0 +1,11 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));