mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Add flash-attn support. (#912)
* Add flash-attn support.
* Add the use-flash-attn flag.
* Re-enable flash-attn.
This commit is contained in:
@ -88,6 +88,7 @@ pub struct WDiffNeXt {
|
||||
}
|
||||
|
||||
impl WDiffNeXt {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
c_in: usize,
|
||||
c_out: usize,
|
||||
@ -95,6 +96,7 @@ impl WDiffNeXt {
|
||||
c_cond: usize,
|
||||
clip_embd: usize,
|
||||
patch_size: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
|
||||
@ -169,8 +171,14 @@ impl WDiffNeXt {
|
||||
let attn_block = if i == 0 {
|
||||
None
|
||||
} else {
|
||||
let attn_block =
|
||||
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
|
||||
let attn_block = AttnBlock::new(
|
||||
c_hidden,
|
||||
c_cond,
|
||||
NHEAD[i],
|
||||
true,
|
||||
use_flash_attn,
|
||||
vb.pp(layer_i),
|
||||
)?;
|
||||
layer_i += 1;
|
||||
Some(attn_block)
|
||||
};
|
||||
@ -208,8 +216,14 @@ impl WDiffNeXt {
|
||||
let attn_block = if i == 0 {
|
||||
None
|
||||
} else {
|
||||
let attn_block =
|
||||
AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
|
||||
let attn_block = AttnBlock::new(
|
||||
c_hidden,
|
||||
c_cond,
|
||||
NHEAD[i],
|
||||
true,
|
||||
use_flash_attn,
|
||||
vb.pp(layer_i),
|
||||
)?;
|
||||
layer_i += 1;
|
||||
Some(attn_block)
|
||||
};
|
||||
|
Reference in New Issue
Block a user