mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Fix for flash-attn. (#1310)
Co-authored-by: laurent <laurent@par2dc5-ai-prd-cl01dgx02.cm.cluster>
This commit is contained in:
@@ -233,8 +233,8 @@ impl FlashAttnVarLen {
|
||||
|
||||
let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
|
||||
let seqlens_q = match &*seqlens_q {
|
||||
candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
|
||||
_ => candle::bail!("seqlens_q must be a cuda tensor"),
|
||||
};
|
||||
let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
|
||||
Some((o1, o2)) => seqlens_q.slice(o1..o2),
|
||||
@@ -243,8 +243,8 @@ impl FlashAttnVarLen {
|
||||
|
||||
let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
|
||||
let seqlens_k = match &*seqlens_k {
|
||||
candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
|
||||
candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
|
||||
_ => candle::bail!("seqlens_k must be a cuda tensor"),
|
||||
};
|
||||
let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
|
||||
Some((o1, o2)) => seqlens_k.slice(o1..o2),
|
||||
|
Reference in New Issue
Block a user