mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add flash-attn support. (#912)
* Add flash-attn support. * Add the use-flash-attn flag. * Re-enable flash-attn.
This commit is contained in:
@ -41,6 +41,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// The height in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
height: Option<usize>,
|
||||
@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||
wuerstchen::prior::WPrior::new(
|
||||
/* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
|
||||
/* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
|
||||
/* c_in */ PRIOR_CIN,
|
||||
/* c */ 1536,
|
||||
/* c_cond */ 1280,
|
||||
/* c_r */ 64,
|
||||
/* depth */ 32,
|
||||
/* nhead */ 24,
|
||||
args.use_flash_attn,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
|
||||
/* c_cond */ 1024,
|
||||
/* clip_embd */ 1024,
|
||||
/* patch_size */ 2,
|
||||
args.use_flash_attn,
|
||||
vb,
|
||||
)?
|
||||
};
|
||||
|
Reference in New Issue
Block a user