Add flash-attn support. (#912)

* Add flash-attn support.

* Add the use-flash-attn flag.

* Re-enable flash-attn.
Author: Laurent Mazare
Committed by: GitHub
Date: 2023-09-20 14:07:55 +01:00
Parent: 728e167334
Commit: fb1c2ac535
7 changed files with 85 additions and 12 deletions


@@ -11,10 +11,33 @@ pub struct Attention {
to_out: Linear,
heads: usize,
scale: f64,
use_flash_attn: bool,
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
impl Attention {
pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
pub fn new(
query_dim: usize,
heads: usize,
dim_head: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let inner_dim = dim_head * heads;
let scale = 1.0 / f64::sqrt(dim_head as f64);
let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
@@ -28,6 +51,7 @@ impl Attention {
to_out,
scale,
heads,
use_flash_attn,
})
}
@@ -62,8 +86,28 @@ impl Attention {
let key = self.head_to_batch_dim(&key)?;
let value = self.head_to_batch_dim(&value)?;
let attn_prs = self.get_attention_scores(&query, &key)?;
let xs = attn_prs.matmul(&value)?;
let xs = if self.use_flash_attn {
let init_dtype = query.dtype();
let q = query
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let k = key
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
let v = value
.to_dtype(candle::DType::F16)?
.unsqueeze(0)?
.transpose(1, 2)?;
flash_attn(&q, &k, &v, self.scale as f32, false)?
.transpose(1, 2)?
.squeeze(0)?
.to_dtype(init_dtype)?
} else {
let attn_prs = self.get_attention_scores(&query, &key)?;
attn_prs.matmul(&value)?
};
let xs = self.batch_to_head_dim(&xs)?;
self.to_out
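
A note on the reshaping above (not part of the diff): `head_to_batch_dim` yields tensors of shape (batch * heads, seq_len, head_dim), while the flash-attn kernel wants (batch, seq_len, num_heads, head_dim) in f16. The `unsqueeze(0)` / `transpose(1, 2)` pair therefore treats the fused batch*heads axis as the heads axis with a batch of 1, and the inverse transform plus `to_dtype(init_dtype)` restores the original layout and precision. A small sketch of just the forward reshape, with illustrative dimensions and the input layout assumed from `head_to_batch_dim`:

```rust
// Sketch only; `candle` is the crate alias used in this file.
use candle::{DType, Device, Result, Tensor};

fn reshape_for_flash_attn(xs: &Tensor) -> Result<Tensor> {
    // (batch * heads, seq_len, head_dim)
    //   -> unsqueeze(0):    (1, batch * heads, seq_len, head_dim)
    //   -> transpose(1, 2): (1, seq_len, batch * heads, head_dim)
    xs.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)
}

fn main() -> Result<()> {
    // e.g. batch = 2, heads = 8, seq_len = 77, head_dim = 64
    let q = Tensor::zeros((2 * 8, 77, 64), DType::F32, &Device::Cpu)?;
    assert_eq!(reshape_for_flash_attn(&q)?.dims(), &[1, 77, 16, 64]);
    Ok(())
}
```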


@@ -174,10 +174,11 @@ impl AttnBlock {
c_cond: usize,
nhead: usize,
self_attn: bool,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let norm = WLayerNorm::new(c)?;
let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
Ok(Self {
self_attn,


@@ -88,6 +88,7 @@ pub struct WDiffNeXt {
}
impl WDiffNeXt {
#[allow(clippy::too_many_arguments)]
pub fn new(
c_in: usize,
c_out: usize,
@@ -95,6 +96,7 @@ impl WDiffNeXt {
c_cond: usize,
clip_embd: usize,
patch_size: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
@@ -169,8 +171,14 @@
let attn_block = if i == 0 {
None
} else {
let attn_block =
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
let attn_block = AttnBlock::new(
c_hidden,
c_cond,
NHEAD[i],
true,
use_flash_attn,
vb.pp(layer_i),
)?;
layer_i += 1;
Some(attn_block)
};
@@ -208,8 +216,14 @@
let attn_block = if i == 0 {
None
} else {
let attn_block =
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
let attn_block = AttnBlock::new(
c_hidden,
c_cond,
NHEAD[i],
true,
use_flash_attn,
vb.pp(layer_i),
)?;
layer_i += 1;
Some(attn_block)
};


@@ -21,6 +21,7 @@ pub struct WPrior {
}
impl WPrior {
#[allow(clippy::too_many_arguments)]
pub fn new(
c_in: usize,
c: usize,
@@ -28,6 +29,7 @@ impl WPrior {
c_r: usize,
depth: usize,
nhead: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -44,6 +46,7 @@
c,
nhead,
true,
use_flash_attn,
vb.pp(format!("blocks.{}", 3 * index + 2)),
)?;
blocks.push(Block {
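
Not shown in this excerpt: the commit message also mentions adding a `use-flash-attn` flag, and the new `use_flash_attn` parameters threaded through `Attention`, `AttnBlock`, `WDiffNeXt`, and `WPrior` let the caller toggle the kernel at runtime. A hedged sketch of how such a flag might be surfaced with clap; the struct and field names are assumptions, not the actual example code:

```rust
// Sketch only; clap v4 derive API, names are illustrative.
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Use flash attention (requires building with `--features flash-attn`).
    #[arg(long)]
    use_flash_attn: bool,
}

fn main() {
    let args = Args::parse();
    // The value would then be passed down through the constructors touched by this
    // commit, e.g. WPrior::new(c_in, c, c_cond, c_r, depth, nhead, args.use_flash_attn, vb)
    // and similarly for WDiffNeXt::new.
    println!("use_flash_attn = {}", args.use_flash_attn);
}
```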