diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index b610915b..61980a58 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -26,6 +26,7 @@ impl FlashAttn {
         k_l: &Layout,
         v: &candle::CudaStorage,
         v_l: &Layout,
+        is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
         // https://github.com/Dao-AILab/flash-attention/blob/b252072409e69c25f2b9d473cc534e49b24decd2/csrc/flash_attn/flash_api.cpp#L187
         let dev = q.device();
@@ -94,6 +95,7 @@ impl FlashAttn {
         let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
 
         let causal = if self.causal { 1 } else { 0 };
+        let is_bf16 = if is_bf16 { 1 } else { 0 };
 
         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -132,7 +134,7 @@ impl FlashAttn {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
-                /* is_bf16 */ 0,
+                /* is_bf16 */ is_bf16,
             )
         }
 
@@ -168,8 +170,8 @@ impl candle::CustomOp3 for FlashAttn {
         v_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
@@ -222,6 +224,7 @@ impl FlashAttnVarLen {
         k_l: &Layout,
         v: &candle::CudaStorage,
         v_l: &Layout,
+        is_bf16: bool,
     ) -> Result<(candle::CudaStorage, Shape)> {
         // https://github.com/Dao-AILab/flash-attention/blob/184b992dcb2a0890adaa19eb9b541c3e4f9d2a08/csrc/flash_attn/flash_api.cpp#L327
         let dev = q.device();
@@ -321,6 +324,7 @@ impl FlashAttnVarLen {
             .w()?;
 
         let causal = if self.causal { 1 } else { 0 };
+        let is_bf16 = if is_bf16 { 1 } else { 0 };
 
         unsafe {
             let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
@@ -361,7 +365,7 @@ impl FlashAttnVarLen {
                 /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                 /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                 /* is_causal */ causal,
-                /* is_bf16 */ 0,
+                /* is_bf16 */ is_bf16,
             )
         }
 
@@ -397,8 +401,8 @@ impl candle::CustomOp3 for FlashAttnVarLen {
         v_l: &Layout,
     ) -> Result<(candle::CudaStorage, Shape)> {
         match q.dtype() {
-            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l),
-            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l),
+            candle::DType::F16 => self.cuda_fwd_t::<f16>(q, q_l, k, k_l, v, v_l, false),
+            candle::DType::BF16 => self.cuda_fwd_t::<bf16>(q, q_l, k, k_l, v, v_l, true),
             dt => candle::bail!("flash-attn is only supported for f16/bf16 ({dt:?})"),
         }
     }
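
For reference, a minimal usage sketch of the BF16 path this patch enables; it is not part of the diff. It assumes `candle` aliases `candle-core`, that `candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)` is the public entry point of this crate, and the usual (batch, seq_len, num_heads, head_dim) layout. With the hunks above applied, BF16 inputs now reach the kernel with `is_bf16` set to 1 instead of the previously hard-coded 0.

// Usage sketch (assumptions noted above): run flash-attn on BF16 tensors.
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    // Dummy q/k/v in (batch=1, seq_len=8, num_heads=4, head_dim=64), cast to BF16.
    let q = Tensor::randn(0f32, 1.0, (1, 8, 4, 64), &device)?.to_dtype(DType::BF16)?;
    let k = q.clone();
    let v = q.clone();
    let softmax_scale = 1.0 / (64f32).sqrt();
    // With this patch the dtype is forwarded to the kernel (is_bf16 = 1) rather
    // than always passing 0, so BF16 buffers are no longer treated as f16.
    let out = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, /* causal */ false)?;
    assert_eq!(out.dtype(), DType::BF16);
    Ok(())
}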