Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 11:56:45 +00:00
chore: update flash attention kernels (#1518)
* chore: update flash attention kernels
* fmt
* remove unused kernels
* force f32
* correct stride
@@ -3,12 +3,14 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};

pub struct FlashAttn {
    pub softmax_scale: f32,
    pub causal: bool,
    pub alibi_slopes: Option<Tensor>,
    pub window_size_left: Option<usize>,
    pub window_size_right: Option<usize>,
}

fn round_multiple(x: usize, m: usize) -> usize {
@@ -85,6 +87,51 @@ impl FlashAttn {
            candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
        }

        let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
            if alibi_slopes.dtype() != DType::F32 {
                candle::bail!(
                    "DType mismatch alibi_slopes {:?}, expected {:?}",
                    alibi_slopes.dtype(),
                    DType::F32
                );
            }

            let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();

            if num_heads != alibi_slopes_layout.shape().dims1()? {
                candle::bail!(
                    "shape mismatch alibi_slopes {:?}, expected {:?}",
                    alibi_slopes_layout.shape(),
                    (num_heads)
                );
            }

            let alibi_slopes = match &*alibi_slopes {
                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
                _ => candle::bail!("alibi_slopes must be a cuda tensor"),
            };

            let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

            *alibi_slopes.device_ptr() as *const core::ffi::c_void
        } else {
            std::ptr::null()
        };

        // if window_size_left > self.max_seqlen_k or None => -1
        let mut window_size_left = self
            .window_size_left
            .filter(|v| v <= &seqlen_k)
            .map(|v| v as i32)
            .unwrap_or(-1);

        // if window_size_right > self.max_seqlen_k or None => -1
        let mut window_size_right = self
            .window_size_right
            .filter(|v| v <= &seqlen_k)
            .map(|v| v as i32)
            .unwrap_or(-1);

        let head_size = round_multiple(head_size_og, 8);
        let head_size_rounded = round_multiple(head_size, 32);
        let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -94,9 +141,22 @@ impl FlashAttn {
        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;

        let causal = if self.causal { 1 } else { 0 };
        let is_bf16 = if is_bf16 { 1 } else { 0 };

        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        let is_causal = if window_size_left < 0 && window_size_right == 0 {
            1
        } else {
            0
        };
        if window_size_left < 0 && window_size_right >= 0 {
            window_size_left = seqlen_k as i32;
        }
        if window_size_left >= 0 && window_size_right < 0 {
            window_size_right = seqlen_k as i32;
        }
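        // Illustrative note, not part of this diff: with the defaults used by `flash_attn`,
        // a causal call arrives here as (window_size_left, window_size_right) = (-1, 0), so
        // `is_causal` is 1 and no local window is applied. A one-sided window such as
        // `window_size_left = Some(64)` with `window_size_right = None` becomes (64, -1), and
        // the open side is widened to `seqlen_k` so only the left limit takes effect.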

        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -109,12 +169,14 @@ impl FlashAttn {
                v_ptr,
                dst_ptr,
                softmax_lse_ptr,
                /* alibi_slopes_ptr */ alibi_slopes_ptr,
                /* cu_seqlens_q_ptr */ std::ptr::null(),
                /* cu_seqlens_k_ptr */ std::ptr::null(),
                /* q_batch_stride */ q_stride[0] as u32,
                /* k_batch_stride */ k_stride[0] as u32,
                /* v_batch_stride */ v_stride[0] as u32,
                /* o_batch_stride */ o_stride[0] as u32,
                /* alibi_slopes_batch_stride */ 0,
                /* q_row_stride */ q_stride[q_rank - 3] as u32,
                /* k_row_stride */ k_stride[k_rank - 3] as u32,
                /* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -133,8 +195,10 @@ impl FlashAttn {
                /* seqlen_k */ seqlen_k as u32,
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_causal */ causal,
                /* is_bf16 */ is_bf16,
                /* is_causal */ is_causal,
                /* window_size_left */ window_size_left,
                /* window_size_right */ window_size_right,
            )
        }

@@ -197,20 +261,137 @@ pub fn flash_attn(
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let window_size_left = None;
    let window_size_right = if causal { Some(0) } else { None };

    let op = FlashAttn {
        softmax_scale,
        causal,
        alibi_slopes: None,
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
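
// A minimal usage sketch for `flash_attn`, not part of this diff: it assumes a CUDA device and
// f16 inputs shaped (batch, seq_len, num_heads, head_size); shapes and values are illustrative.
//
//     use candle::{DType, Device, Tensor};
//     let dev = Device::new_cuda(0)?;
//     let q = Tensor::randn(0f32, 1f32, (1, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
//     let k = Tensor::randn(0f32, 1f32, (1, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
//     let v = Tensor::randn(0f32, 1f32, (1, 128, 8, 64), &dev)?.to_dtype(DType::F16)?;
//     let scale = 1f32 / (64f32).sqrt();
//     let out = candle_flash_attn::flash_attn(&q, &k, &v, scale, /* causal */ true)?;
//     assert_eq!(out.dims(), &[1, 128, 8, 64]);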

/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer
/// heads than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `window_size_left` - Limit left attention to `window_size_left` tokens (`None` means no limit).
/// * `window_size_right` - Limit right attention to `window_size_right` tokens (`None` means no limit).
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_windowed(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
) -> Result<Tensor> {
    let op = FlashAttn {
        softmax_scale,
        alibi_slopes: None,
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
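
// A sliding-window sketch, assuming the same CUDA/f16 setup as the `flash_attn` example above:
// limiting attention to the previous 256 tokens while masking the future gives local causal
// attention.
//
//     let out = candle_flash_attn::flash_attn_windowed(
//         &q, &k, &v, scale,
//         /* window_size_left */ Some(256),
//         /* window_size_right */ Some(0),
//     )?;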

/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer
/// heads than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    alibi_slopes: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let window_size_left = None;
    let window_size_right = if causal { Some(0) } else { None };

    let op = FlashAttn {
        softmax_scale,
        alibi_slopes: Some(alibi_slopes.clone()),
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
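
// A sketch of building the `alibi_slopes` input, assuming 8 query heads and the geometric slope
// schedule from the ALiBi paper (2^(-8 * i / num_heads) for head i); as checked in the code
// above, the tensor must be f32, on the same CUDA device, with shape (num_heads_q).
//
//     let n_heads = 8usize;
//     let slopes: Vec<f32> = (1..=n_heads)
//         .map(|i| 2f32.powf(-8.0 * i as f32 / n_heads as f32))
//         .collect();
//     let alibi_slopes = Tensor::new(slopes, &dev)?; // dtype f32, shape (8,)
//     let out = candle_flash_attn::flash_attn_alibi(&q, &k, &v, &alibi_slopes, scale, true)?;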

/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer
/// heads than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `window_size_left` - Limit left attention to `window_size_left` tokens (`None` means no limit).
/// * `window_size_right` - Limit right attention to `window_size_right` tokens (`None` means no limit).
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    alibi_slopes: &Tensor,
    softmax_scale: f32,
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
) -> Result<Tensor> {
    let op = FlashAttn {
        softmax_scale,
        alibi_slopes: Some(alibi_slopes.clone()),
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
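
// Combined sketch, reusing the `alibi_slopes` tensor built above: ALiBi bias plus a local causal
// window of 256 tokens to the left.
//
//     let out = candle_flash_attn::flash_attn_alibi_windowed(
//         &q, &k, &v, &alibi_slopes, scale, Some(256), Some(0),
//     )?;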

struct FlashAttnVarLen {
    softmax_scale: f32,
    causal: bool,
    max_seqlen_q: usize,
    max_seqlen_k: usize,
    seqlens_q: Tensor,
    seqlens_k: Tensor,
    pub softmax_scale: f32,
    pub max_seqlen_q: usize,
    pub max_seqlen_k: usize,
    pub seqlens_q: Tensor,
    pub seqlens_k: Tensor,
    pub alibi_slopes: Option<Tensor>,
    pub window_size_left: Option<usize>,
    pub window_size_right: Option<usize>,
}

impl FlashAttnVarLen {
@@ -311,7 +492,54 @@ impl FlashAttnVarLen {
        if nseqlens_k != nseqlens_q {
            candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
        }

        let batch_size = nseqlens_q - 1;

        let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
            if alibi_slopes.dtype() != DType::F32 {
                candle::bail!(
                    "DType mismatch alibi_slopes {:?}, expected {:?}",
                    alibi_slopes.dtype(),
                    DType::F32
                );
            }

            let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();

            if num_heads != alibi_slopes_layout.shape().dims1()? {
                candle::bail!(
                    "shape mismatch alibi_slopes {:?}, expected {:?}",
                    alibi_slopes_layout.shape(),
                    (num_heads)
                );
            }

            let alibi_slopes = match &*alibi_slopes {
                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
                _ => candle::bail!("alibi_slopes must be a cuda tensor"),
            };

            let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

            *alibi_slopes.device_ptr() as *const core::ffi::c_void
        } else {
            std::ptr::null()
        };

        // if window_size_left > self.max_seqlen_k or None => -1
        let mut window_size_left = self
            .window_size_left
            .filter(|v| v <= &self.max_seqlen_k)
            .map(|v| v as i32)
            .unwrap_or(-1);

        // if window_size_right > self.max_seqlen_k or None => -1
        let mut window_size_right = self
            .window_size_right
            .filter(|v| v <= &self.max_seqlen_k)
            .map(|v| v as i32)
            .unwrap_or(-1);

        let head_size = round_multiple(head_size_og, 8);
        let head_size_rounded = round_multiple(head_size, 32);
        let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@@ -323,9 +551,22 @@ impl FlashAttnVarLen {
            .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
            .w()?;

        let causal = if self.causal { 1 } else { 0 };
        let is_bf16 = if is_bf16 { 1 } else { 0 };

        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        let is_causal = if window_size_left < 0 && window_size_right == 0 {
            1
        } else {
            0
        };
        if window_size_left < 0 && window_size_right >= 0 {
            window_size_left = self.max_seqlen_k as i32;
        }
        if window_size_left >= 0 && window_size_right < 0 {
            window_size_right = self.max_seqlen_k as i32;
        }

        unsafe {
            let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
            let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -340,12 +581,14 @@ impl FlashAttnVarLen {
                v_ptr,
                dst_ptr,
                softmax_lse_ptr,
                /* alibi_slopes_ptr */ alibi_slopes_ptr,
                /* cu_seqlens_q_ptr */ seqlens_q_ptr,
                /* cu_seqlens_k_ptr */ seqlens_k_ptr,
                /* q_batch_stride */ 0,
                /* k_batch_stride */ 0,
                /* v_batch_stride */ 0,
                /* o_batch_stride */ 0,
                /* alibi_slopes_batch_stride */ 0,
                /* q_row_stride */ q_stride[q_rank - 3] as u32,
                /* k_row_stride */ k_stride[k_rank - 3] as u32,
                /* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -364,8 +607,10 @@ impl FlashAttnVarLen {
                /* seqlen_k */ self.max_seqlen_k as u32,
                /* seqlen_q_rounded */ seqlen_q_rounded as u32,
                /* seqlen_k_rounded */ seqlen_k_rounded as u32,
                /* is_causal */ causal,
                /* is_bf16 */ is_bf16,
                /* is_causal */ is_causal,
                /* window_size_left */ window_size_left,
                /* window_size_right */ window_size_right,
            )
        }

@@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let window_size_left = None;
    let window_size_right = if causal { Some(0) } else { None };

    let op = FlashAttnVarLen {
        softmax_scale,
        causal,
        max_seqlen_q,
        max_seqlen_k,
        seqlens_q: seqlens_q.clone(),
        seqlens_k: seqlens_k.clone(),
        alibi_slopes: None,
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
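
// A variable-length sketch, assuming two packed sequences of 3 and 5 tokens (total_q = total_kv
// = 8), 8 heads and head_size 64, all on the same CUDA device as above. `seqlens_q`/`seqlens_k`
// hold the cumulative offsets `[0, 3, 8]` (batch_size + 1 entries).
//
//     let q = Tensor::randn(0f32, 1f32, (8, 8, 64), &dev)?.to_dtype(DType::F16)?;
//     let k = q.clone();
//     let v = q.clone();
//     let seqlens = Tensor::new(&[0u32, 3, 8], &dev)?;
//     let out = candle_flash_attn::flash_attn_varlen(
//         &q, &k, &v, &seqlens, &seqlens, /* max_seqlen_q */ 5, /* max_seqlen_k */ 5,
//         scale, /* causal */ true,
//     )?;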

#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer
/// heads than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum key/value sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to `window_size_left` tokens (`None` means no limit).
/// * `window_size_right` - Limit right attention to `window_size_right` tokens (`None` means no limit).
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
pub fn flash_attn_varlen_windowed(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    seqlens_q: &Tensor,
    seqlens_k: &Tensor,
    max_seqlen_q: usize,
    max_seqlen_k: usize,
    softmax_scale: f32,
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
) -> Result<Tensor> {
    let op = FlashAttnVarLen {
        softmax_scale,
        max_seqlen_q,
        max_seqlen_k,
        seqlens_q: seqlens_q.clone(),
        seqlens_k: seqlens_k.clone(),
        alibi_slopes: None,
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
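
// Windowed variant of the packed-batch sketch above: same inputs, but attention is limited to a
// causal window of 4 tokens to the left of each query position.
//
//     let out = candle_flash_attn::flash_attn_varlen_windowed(
//         &q, &k, &v, &seqlens, &seqlens, 5, 5, scale, Some(4), Some(0),
//     )?;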

#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer
/// heads than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum key/value sequence length for k and v in the batch.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
pub fn flash_attn_varlen_alibi(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    alibi_slopes: &Tensor,
    seqlens_q: &Tensor,
    seqlens_k: &Tensor,
    max_seqlen_q: usize,
    max_seqlen_k: usize,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    let window_size_left = None;
    let window_size_right = if causal { Some(0) } else { None };

    let op = FlashAttnVarLen {
        softmax_scale,
        max_seqlen_q,
        max_seqlen_k,
        seqlens_q: seqlens_q.clone(),
        seqlens_k: seqlens_k.clone(),
        alibi_slopes: Some(alibi_slopes.clone()),
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
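
// ALiBi variant of the packed-batch sketch above; `alibi_slopes` must satisfy the dtype and shape
// checks shown earlier (f32, on the CUDA device, shape (num_heads_q)).
//
//     let out = candle_flash_attn::flash_attn_varlen_alibi(
//         &q, &k, &v, &alibi_slopes, &seqlens, &seqlens, 5, 5, scale, true,
//     )?;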

#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer
/// heads than q; the number of heads in q must be divisible by the number of heads in k and v.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum key/value sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to `window_size_left` tokens (`None` means no limit).
/// * `window_size_right` - Limit right attention to `window_size_right` tokens (`None` means no limit).
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`.
pub fn flash_attn_varlen_alibi_windowed(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    alibi_slopes: &Tensor,
    seqlens_q: &Tensor,
    seqlens_k: &Tensor,
    max_seqlen_q: usize,
    max_seqlen_k: usize,
    softmax_scale: f32,
    window_size_left: Option<usize>,
    window_size_right: Option<usize>,
) -> Result<Tensor> {
    let op = FlashAttnVarLen {
        softmax_scale,
        max_seqlen_q,
        max_seqlen_k,
        seqlens_q: seqlens_q.clone(),
        seqlens_k: seqlens_k.clone(),
        alibi_slopes: Some(alibi_slopes.clone()),
        window_size_left,
        window_size_right,
    };
    q.apply_op3(k, v, op)
}
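
// Full-featured sketch combining the pieces above: packed batch, ALiBi slopes, and a local causal
// window.
//
//     let out = candle_flash_attn::flash_attn_varlen_alibi_windowed(
//         &q, &k, &v, &alibi_slopes, &seqlens, &seqlens, 5, 5, scale, Some(4), Some(0),
//     )?;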