mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.8.4"
|
||||
version = "0.9.0-alpha.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -7,5 +7,5 @@ fn main() {
|
||||
let builder = bindgen_cuda::Builder::default();
|
||||
println!("cargo:info={builder:?}");
|
||||
let bindings = builder.build_ptx().unwrap();
|
||||
bindings.write("src/lib.rs").unwrap();
|
||||
bindings.write("src/ptx.rs").unwrap();
|
||||
}
|
||||
|
@ -1,11 +1,78 @@
|
||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
mod ptx;
|
||||
|
||||
/// Identifies one of the kernel modules bundled with this crate.
///
/// The discriminants are written out because the numeric value is observable:
/// the enum is `#[repr(u32)]` and ids are compared through `as u32` casts.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Id {
    Affine = 0,
    Binary = 1,
    Cast = 2,
    Conv = 3,
    Fill = 4,
    Indexing = 5,
    Quantized = 6,
    Reduce = 7,
    Sort = 8,
    Ternary = 9,
    Unary = 10,
}

/// Every [`Id`] variant, listed in declaration order.
///
/// The position of an id inside this array is the index that `module_index`
/// reports for it, so a new variant must be appended here as well.
pub const ALL_IDS: [Id; 11] = [
    Id::Affine,
    Id::Binary,
    Id::Cast,
    Id::Conv,
    Id::Fill,
    Id::Indexing,
    Id::Quantized,
    Id::Reduce,
    Id::Sort,
    Id::Ternary,
    Id::Unary,
];
|
||||
|
||||
/// A kernel module: the PTX source text together with the slot assigned to it.
///
/// Both fields are private; values are read through [`Module::index`] and
/// [`Module::ptx`].
#[derive(Clone, Copy)]
pub struct Module {
    /// Slot assigned to this module.
    index: usize,
    /// PTX source text of the module.
    ptx: &'static str,
}

impl Module {
    /// Returns the slot assigned to this module.
    pub const fn index(&self) -> usize {
        self.index
    }

    /// Returns the PTX source text of this module.
    pub const fn ptx(&self) -> &'static str {
        self.ptx
    }
}

// Hand-written rather than derived: the PTX text can be very large, so the
// debug output reports its length instead of echoing the whole source.
impl std::fmt::Debug for Module {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Module")
            .field("index", &self.index)
            .field("ptx_len", &self.ptx.len())
            .finish()
    }
}
|
||||
|
||||
const fn module_index(id: Id) -> usize {
|
||||
let mut i = 0;
|
||||
while i < ALL_IDS.len() {
|
||||
if ALL_IDS[i] as u32 == id as u32 {
|
||||
return i;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
panic!("id not found")
|
||||
}
|
||||
|
||||
macro_rules! mdl {
|
||||
($cst:ident, $id:ident) => {
|
||||
pub const $cst: Module = Module {
|
||||
index: module_index(Id::$id),
|
||||
ptx: ptx::$cst,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
mdl!(AFFINE, Affine);
|
||||
mdl!(BINARY, Binary);
|
||||
mdl!(CAST, Cast);
|
||||
mdl!(CONV, Conv);
|
||||
mdl!(FILL, Fill);
|
||||
mdl!(INDEXING, Indexing);
|
||||
mdl!(QUANTIZED, Quantized);
|
||||
mdl!(REDUCE, Reduce);
|
||||
mdl!(SORT, Sort);
|
||||
mdl!(TERNARY, Ternary);
|
||||
mdl!(UNARY, Unary);
|
||||
|
11
candle-kernels/src/ptx.rs
Normal file
11
candle-kernels/src/ptx.rs
Normal file
@ -0,0 +1,11 @@
|
||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
Reference in New Issue
Block a user