Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.
* Adapt a couple more things.
* And a couple more fixes.
* More tweaks.
* And a couple more fixes.
* Bump the major version number.
* Proper module system for the cuda kernels.
* Proper ptx loading.
* Launch the sort kernel.
* Custom op.
* Start using the builder pattern.
* More builder.
* More builder.
* Get candle-core to compile.
* Get the tests to pass.
* Get candle-nn to work too.
* Support for custom cuda functions.
* cudnn fixes.
* Get flash attn to run.
* Switch the crate versions to be alpha.
* Bump the ug dependency.
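The recurring change in the source hunks below is the new shape of `device_ptr` in cudarc 0.14: instead of dereferencing a slice to obtain a raw device pointer, the call now takes the CUDA stream and returns the pointer together with a guard that should stay alive for as long as the raw pointer is in use. The toy sketch below only mirrors that (pointer, guard) shape; `FakeStream`, `FakeSlice`, `StreamGuard` and `ffi_call` are illustrative stand-ins and are not candle or cudarc APIs.

    // Illustrative stand-ins only; none of these types exist in candle or cudarc.
    struct FakeStream;
    struct StreamGuard;

    impl Drop for StreamGuard {
        fn drop(&mut self) {
            // In cudarc 0.14 this is where the pointer/stream bookkeeping is released.
            println!("guard dropped");
        }
    }

    struct FakeSlice {
        data: Vec<f32>,
    }

    impl FakeSlice {
        // Mimics the 0.14 shape used in the diff: take the stream, return (pointer, guard).
        fn device_ptr(&self, _stream: &FakeStream) -> (*const f32, StreamGuard) {
            (self.data.as_ptr(), StreamGuard)
        }
    }

    fn ffi_call(ptr: *const core::ffi::c_void) {
        println!("ffi called with {ptr:?}");
    }

    fn main() {
        let stream = FakeStream;
        let q = FakeSlice { data: vec![0.0; 8] };
        // Bind the guard to a named variable so it stays alive across the "FFI" call,
        // which is the pattern the hunks below adopt.
        let (q_ptr, _guard) = q.device_ptr(&stream);
        ffi_call(q_ptr as *const core::ffi::c_void);
        // _guard is dropped here, after the raw pointer is no longer in use.
    }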
--- a/candle-flash-attn/Cargo.toml
+++ b/candle-flash-attn/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.8.4"
+version = "0.9.0-alpha.1"
 edition = "2021"
 
 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
 half = { version = "2.3.1", features = ["num-traits"] }
 
 [build-dependencies]
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -88,6 +88,7 @@ impl FlashAttn {
             candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
         }
 
+        let stream = dev.cuda_stream();
         let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
             if alibi_slopes.dtype() != DType::F32 {
                 candle::bail!(
@@ -114,7 +115,9 @@ impl FlashAttn {
 
             let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
 
-            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+            // Dropping the guard here doesn't seem very safe.
+            let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
+            ptr as *const core::ffi::c_void
         } else {
             std::ptr::null()
         };
@@ -161,17 +164,17 @@ impl FlashAttn {
         }
 
         unsafe {
-            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
-            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
-            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
-            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
-            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
+            let (q_ptr, _guard) = q.device_ptr(&stream);
+            let (k_ptr, _guard) = k.device_ptr(&stream);
+            let (v_ptr, _guard) = v.device_ptr(&stream);
+            let (dst_ptr, _guard) = dst.device_ptr(&stream);
+            let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
             ffi::run_mha(
-                q_ptr,
-                k_ptr,
-                v_ptr,
-                dst_ptr,
-                softmax_lse_ptr,
+                q_ptr as *const core::ffi::c_void,
+                k_ptr as *const core::ffi::c_void,
+                v_ptr as *const core::ffi::c_void,
+                dst_ptr as *const core::ffi::c_void,
+                softmax_lse_ptr as *const core::ffi::c_void,
                 /* alibi_slopes_ptr */ alibi_slopes_ptr,
                 /* cu_seqlens_q_ptr */ std::ptr::null(),
                 /* cu_seqlens_k_ptr */ std::ptr::null(),
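In the hunks above the guards are deliberately bound to `_guard` rather than matched with `_`: a named binding lives until the end of its scope, so the guard outlives the `ffi::run_mha` call, whereas a value matched by the wildcard pattern is dropped at the end of its statement. The new comment `// Dropping the guard here doesn't seem very safe.` flags the one place where this is still shaky: in the alibi-slopes branch the guard goes out of scope at the end of the `if let` block even though the raw pointer escapes for later use. A self-contained sketch of the drop-timing difference (the `Guard` type is illustrative, not a cudarc type):

    // Illustrative type, not part of candle or cudarc.
    struct Guard(&'static str);

    impl Drop for Guard {
        fn drop(&mut self) {
            println!("dropped: {}", self.0);
        }
    }

    fn main() {
        // Named underscore binding: kept alive until the end of the scope.
        let _guard = Guard("named _guard");
        // Wildcard pattern: the value is dropped right away, at the end of this statement.
        let _ = Guard("wildcard _");
        println!("doing work (e.g. the FFI call) ...");
        // Output order:
        //   dropped: wildcard _
        //   doing work (e.g. the FFI call) ...
        //   dropped: named _guard
    }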
@@ -550,6 +553,7 @@ impl FlashAttnVarLen {
 
         let batch_size = nseqlens_q - 1;
 
+        let stream = dev.cuda_stream();
         let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
             if alibi_slopes.dtype() != DType::F32 {
                 candle::bail!(
@@ -576,7 +580,9 @@ impl FlashAttnVarLen {
 
             let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
 
-            *alibi_slopes.device_ptr() as *const core::ffi::c_void
+            // Dropping the guard here doesn't seem very safe.
+            let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
+            ptr as *const core::ffi::c_void
         } else {
             std::ptr::null()
         };
@@ -621,22 +627,22 @@ impl FlashAttnVarLen {
         }
 
         unsafe {
-            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
-            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
-            let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
-            let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
-            let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
-            let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
-            let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
+            let (q_ptr, _guard) = q.device_ptr(&stream);
+            let (k_ptr, _guard) = k.device_ptr(&stream);
+            let (v_ptr, _guard) = v.device_ptr(&stream);
+            let (dst_ptr, _guard) = dst.device_ptr(&stream);
+            let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
+            let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
+            let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
             ffi::run_mha(
-                q_ptr,
-                k_ptr,
-                v_ptr,
-                dst_ptr,
-                softmax_lse_ptr,
-                /* alibi_slopes_ptr */ alibi_slopes_ptr,
-                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
-                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
+                q_ptr as *const core::ffi::c_void,
+                k_ptr as *const core::ffi::c_void,
+                v_ptr as *const core::ffi::c_void,
+                dst_ptr as *const core::ffi::c_void,
+                softmax_lse_ptr as *const core::ffi::c_void,
+                /* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
+                /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
+                /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
                 /* q_batch_stride */ 0,
                 /* k_batch_stride */ 0,
                 /* v_batch_stride */ 0,
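In the variable-length hunks above, `cu_seqlens_q` / `cu_seqlens_k` hold cumulative sequence-length offsets, which is consistent with the existing `let batch_size = nseqlens_q - 1;` line. A small self-contained sketch of how such offsets relate to per-sequence lengths; the `cu_seqlens` helper is illustrative and not part of candle:

    /// Illustrative helper: build cumulative offsets ("cu_seqlens") from per-sequence lengths.
    /// For lengths [3, 5, 2] this returns [0, 3, 8, 10]; the batch size is offsets.len() - 1.
    fn cu_seqlens(lengths: &[i32]) -> Vec<i32> {
        let mut offsets = Vec::with_capacity(lengths.len() + 1);
        let mut total = 0;
        offsets.push(0);
        for &len in lengths {
            total += len;
            offsets.push(total);
        }
        offsets
    }

    fn main() {
        let offsets = cu_seqlens(&[3, 5, 2]);
        assert_eq!(offsets, vec![0, 3, 8, 10]);
        assert_eq!(offsets.len() - 1, 3); // matches `batch_size = nseqlens_q - 1`
        println!("cu_seqlens = {offsets:?}");
    }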