diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 61980a58..3395bd0d 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -233,8 +233,8 @@ impl FlashAttnVarLen {
         let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
         let seqlens_q = match &*seqlens_q {
-            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+            _ => candle::bail!("seqlens_q must be a cuda tensor"),
         };
         let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_q.slice(o1..o2),
@@ -243,8 +243,8 @@ impl FlashAttnVarLen {
         let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
         let seqlens_k = match &*seqlens_k {
-            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+            _ => candle::bail!("seqlens_k must be a cuda tensor"),
         };
         let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_k.slice(o1..o2),
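
For context, a minimal sketch of the pattern this patch applies, using a hypothetical stand-in enum rather than candle's actual Storage type: listing the accepted Cuda arm first and collapsing the rejection into a wildcard `_` arm means every non-CUDA backend (Cpu, plus a Metal variant assumed here for illustration) hits the same bail path, and the match stays exhaustive without naming each rejected variant.

// Sketch of the match-arm reordering above; `Storage` and its Metal
// variant are illustrative stand-ins, not candle's real types.
enum Storage {
    Cpu(Vec<u32>),
    Cuda(Vec<u32>),  // stand-in for a device-side slice
    Metal(Vec<u32>), // assumed extra backend, to show what the wildcard catches
}

fn seqlens_as_cuda(storage: &Storage) -> Result<&[u32], String> {
    match storage {
        // Accepted case first: only CUDA-backed storage passes through.
        Storage::Cuda(c) => Ok(c.as_slice()),
        // Wildcard arm: Cpu, Metal, and any future variant all produce the
        // same error, mirroring the `_ => candle::bail!(...)` arm in the diff.
        _ => Err("seqlens_q must be a cuda tensor".to_string()),
    }
}

fn main() {
    assert!(seqlens_as_cuda(&Storage::Cuda(vec![0, 4, 9])).is_ok());
    assert!(seqlens_as_cuda(&Storage::Cpu(vec![0, 4, 9])).is_err());
    assert!(seqlens_as_cuda(&Storage::Metal(vec![0, 4, 9])).is_err());
}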