Add flash-attn support. (#912)

* Add flash-attn support.

* Add the use-flash-attn flag.

* Re-enable flash-attn.
This commit is contained in:
Laurent Mazare
2023-09-20 14:07:55 +01:00
committed by GitHub
parent 728e167334
commit fb1c2ac535
7 changed files with 85 additions and 12 deletions

View File

@ -41,6 +41,9 @@ struct Args {
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
let weights = weights.deserialize()?;
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
/* c_in */ PRIOR_CIN,
/* c */ 1536,
/* c_cond */ 1280,
/* c_r */ 64,
/* depth */ 32,
/* nhead */ 24,
args.use_flash_attn,
vb,
)?
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
/* c_cond */ 1024,
/* clip_embd */ 1024,
/* patch_size */ 2,
args.use_flash_attn,
vb,
)?
};