/****************************************************************************** * Copyright (c) 2023, Tri Dao. ******************************************************************************/ #pragma once namespace flash { //////////////////////////////////////////////////////////////////////////////////////////////////// template struct BlockInfo { template __device__ BlockInfo(const Params ¶ms, const int bidb) : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]) , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) { } template inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; } const int sum_s_q; const int sum_s_k; const uint32_t actual_seqlen_q; const uint32_t actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace flash