From d2c3f1477397b6730fbef7225dd9e5fc0a9fa096 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 10 Nov 2023 10:27:27 +0100
Subject: [PATCH] Fix for flash-attn. (#1310)

Co-authored-by: laurent
---
 candle-flash-attn/src/lib.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index 61980a58..3395bd0d 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -233,8 +233,8 @@ impl FlashAttnVarLen {
 
         let (seqlens_q, seqlens_q_layout) = self.seqlens_q.storage_and_layout();
         let seqlens_q = match &*seqlens_q {
-            candle::Storage::Cpu(_) => candle::bail!("seqlens_q must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+            _ => candle::bail!("seqlens_q must be a cuda tensor"),
         };
         let seqlens_q = match seqlens_q_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_q.slice(o1..o2),
@@ -243,8 +243,8 @@ impl FlashAttnVarLen {
 
         let (seqlens_k, seqlens_k_layout) = self.seqlens_k.storage_and_layout();
         let seqlens_k = match &*seqlens_k {
-            candle::Storage::Cpu(_) => candle::bail!("seqlens_k must be a cuda tensor"),
             candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?, // Should be i32!
+            _ => candle::bail!("seqlens_k must be a cuda tensor"),
         };
         let seqlens_k = match seqlens_k_layout.contiguous_offsets() {
             Some((o1, o2)) => seqlens_k.slice(o1..o2),
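
Note on the fix: both hunks make the same change. The explicit
`candle::Storage::Cpu(_)` arm is dropped in favor of a trailing `_` wildcard,
with the `Cuda` arm listed first. A plausible reading (the commit message does
not say) is that `candle::Storage` had just gained a new backend variant
(Metal support landed in candle around this time), so a match listing only
`Cpu` and `Cuda` was no longer exhaustive and failed to compile; the wildcard
keeps rejecting every non-CUDA storage with a single arm. The sketch below
illustrates the pattern with toy types: `Storage`, the `Metal` variant, and
`as_cuda` here are hypothetical stand-ins, not candle's actual definitions.

    // Toy enum standing in for candle::Storage; the `Metal` variant plays
    // the role of a newly added backend.
    #[allow(dead_code)]
    enum Storage {
        Cpu(Vec<f32>),
        Cuda(Vec<f32>),
        Metal(Vec<f32>),
    }

    fn as_cuda<'a>(storage: &'a Storage, name: &str) -> Result<&'a [f32], String> {
        match storage {
            // Accept the one variant the kernel can consume...
            Storage::Cuda(c) => Ok(c.as_slice()),
            // ...and bail on everything else, including variants added later,
            // so the match stays exhaustive without further edits.
            _ => Err(format!("{name} must be a cuda tensor")),
        }
    }

    fn main() {
        assert!(as_cuda(&Storage::Cpu(vec![0.0; 4]), "seqlens_q").is_err());
        assert!(as_cuda(&Storage::Cuda(vec![0.0; 4]), "seqlens_q").is_ok());
        assert!(as_cuda(&Storage::Metal(vec![0.0; 4]), "seqlens_q").is_err());
    }

The trade-off is the usual one with wildcard arms: if a future variant should
be accepted as well, the compiler will no longer point at this match, so the
bail message has to do the signposting instead.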