Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Properly set the is_bf16 flag. (#738)
@@ -26,6 +26,7 @@ impl FlashAttn {
         k_l: &Layout,
         v: &candle::CudaStorage,
         v_l: &Layout,
+        is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
         // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
         let dev = q.device();
@@ -94,6 +95,7 @@ impl FlashAttn {
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;

         let causal = if self.causal { 1 } else { 0 };
+        let is_bf16 = if is_bf16 { 1 } else { 0 };

         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -132,7 +134,7 @@ impl FlashAttn {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
-                /* is_bf16 */ 0,
+                /* is_bf16 */ is_bf16,
             )
         }

@@ -168,8 +170,8 @@ impl candle::CustomOp3 for FlashAttn {
         v_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
@@ -222,6 +224,7 @@ impl FlashAttnVarLen {
         k_l: &Layout,
         v: &candle::CudaStorage,
         v_l: &Layout,
+        is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
         // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
         let dev = q.device();
@@ -321,6 +324,7 @@ impl FlashAttnVarLen {
             .w()?;

         let causal = if self.causal { 1 } else { 0 };
+        let is_bf16 = if is_bf16 { 1 } else { 0 };

         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -361,7 +365,7 @@ impl FlashAttnVarLen {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
-                /* is_bf16 */ 0,
+                /* is_bf16 */ is_bf16,
             )
         }

@@ -397,8 +401,8 @@ impl candle::CustomOp3 for FlashAttnVarLen {
         v_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
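With this change, BF16 inputs reach the CUDA kernel with is_bf16 set to 1 instead of the previously hard-coded 0, while F16 inputs keep passing 0. A minimal usage sketch, assuming a CUDA device and the crate's public flash_attn helper; the shapes, head dimension, and softmax scale below are illustrative, not part of the commit:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    // (batch, seq_len, num_heads, head_dim); BF16 inputs now dispatch
    // cuda_fwd_t::<bf16> with is_bf16 = true.
    let q = Tensor::randn(0f32, 1f32, (1, 128, 8, 64), &dev)?.to_dtype(DType::BF16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1f32 / (64f32).sqrt();
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal */ true)?;
    println!("{:?}", out.dims());
    Ok(())
}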