Add flash-attn support. (#912)

* Add flash-attn support.

* Add the use-flash-attn flag.

* Re-enable flash-attn.
This commit is contained in:
Laurent Mazare
2023-09-20 14:07:55 +01:00
committed by GitHub
parent 728e167334
commit fb1c2ac535
7 changed files with 85 additions and 12 deletions

View File

@ -88,6 +88,7 @@ pub struct WDiffNeXt {
}
impl WDiffNeXt {
#[allow(clippy::too_many_arguments)]
pub fn new(
c_in: usize,
c_out: usize,
@ -95,6 +96,7 @@ impl WDiffNeXt {
c_cond: usize,
clip_embd: usize,
patch_size: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
@ -169,8 +171,14 @@ impl WDiffNeXt {
let attn_block = if i == 0 {
None
} else {
let attn_block =
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
let attn_block = AttnBlock::new(
c_hidden,
c_cond,
NHEAD[i],
true,
use_flash_attn,
vb.pp(layer_i),
)?;
layer_i += 1;
Some(attn_block)
};
@ -208,8 +216,14 @@ impl WDiffNeXt {
let attn_block = if i == 0 {
None
} else {
let attn_block =
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
let attn_block = AttnBlock::new(
c_hidden,
c_cond,
NHEAD[i],
true,
use_flash_attn,
vb.pp(layer_i),
)?;
layer_i += 1;
Some(attn_block)
};