mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00

* Start updating to cudarc 0.14. * Adapt a couple more things. * And a couple more fixes. * More tweaks. * And a couple more fixes. * Bump the major version number. * Proper module system for the cuda kernels. * Proper ptx loading. * Launch the sort kernel. * Custom op. * Start using the builder pattern. * More builder. * More builder. * Get candle-core to compile. * Get the tests to pass. * Get candle-nn to work too. * Support for custom cuda functions. * cudnn fixes. * Get flash attn to run. * Switch the crate versions to be alpha. * Bump the ug dependency.
79 lines
1.2 KiB
Rust
79 lines
1.2 KiB
Rust
mod ptx;
|
|
|
|
/// Identifier for one of the kernel modules known to this crate.
///
/// The enum is `#[repr(u32)]` and uses implicit discriminants, so the
/// variants are numbered `0..=10` in declaration order and `id as u32` is a
/// well-defined cast. [`ALL_IDS`] lists the same variants in the same order.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Id {
    Affine,
    Binary,
    Cast,
    Conv,
    Fill,
    Indexing,
    Quantized,
    Reduce,
    Sort,
    Ternary,
    Unary,
}
|
|
|
|
pub const ALL_IDS: [Id; 11] = [
|
|
Id::Affine,
|
|
Id::Binary,
|
|
Id::Cast,
|
|
Id::Conv,
|
|
Id::Fill,
|
|
Id::Indexing,
|
|
Id::Quantized,
|
|
Id::Reduce,
|
|
Id::Sort,
|
|
Id::Ternary,
|
|
Id::Unary,
|
|
];
|
|
|
|
/// A kernel module: a slot index paired with a static string from the `ptx`
/// submodule.
///
/// Both fields are private; instances are only created in this file through
/// the `mdl!` macro, which produces one `pub const` per module.
pub struct Module {
    /// Position of this module's [`Id`] inside [`ALL_IDS`].
    index: usize,
    /// Contents of the matching constant in the `ptx` submodule
    /// (presumably the PTX text of the kernels; confirm against `ptx`).
    ptx: &'static str,
}

impl Module {
    /// Returns the position of this module's [`Id`] inside [`ALL_IDS`].
    pub fn index(&self) -> usize {
        self.index
    }

    /// Returns the static string associated with this module.
    pub fn ptx(&self) -> &'static str {
        self.ptx
    }
}
|
|
|
|
const fn module_index(id: Id) -> usize {
|
|
let mut i = 0;
|
|
while i < ALL_IDS.len() {
|
|
if ALL_IDS[i] as u32 == id as u32 {
|
|
return i;
|
|
}
|
|
i += 1;
|
|
}
|
|
panic!("id not found")
|
|
}
|
|
|
|
/// Declares a public [`Module`] constant.
///
/// `mdl!(NAME, Variant)` expands to `pub const NAME: Module`, whose `index`
/// is `module_index(Id::Variant)` (evaluated at compile time) and whose `ptx`
/// is the constant of the same name in the `ptx` submodule (`ptx::NAME`).
macro_rules! mdl {
    ($cst:ident, $id:ident) => {
        #[doc = concat!("Kernel module registered under [`Id::", stringify!($id), "`].")]
        pub const $cst: Module = Module {
            index: module_index(Id::$id),
            ptx: ptx::$cst,
        };
    };
}
|
|
|
|
mdl!(AFFINE, Affine);
|
|
mdl!(BINARY, Binary);
|
|
mdl!(CAST, Cast);
|
|
mdl!(CONV, Conv);
|
|
mdl!(FILL, Fill);
|
|
mdl!(INDEXING, Indexing);
|
|
mdl!(QUANTIZED, Quantized);
|
|
mdl!(REDUCE, Reduce);
|
|
mdl!(SORT, Sort);
|
|
mdl!(TERNARY, Ternary);
|
|
mdl!(UNARY, Unary);
|