Mirror of https://github.com/huggingface/candle.git
Add flash-attn support. (#912)
* Add flash-attn support.
* Add the use-flash-attn flag.
* Re-enable flash-attn.
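Taken together, the changes gate flash attention twice: the kernels are compiled in through a Cargo feature, and the new clap field use_flash_attn (which clap exposes as --use-flash-attn) selects the code path at runtime. A plausible invocation of the example, assuming it is named wuerstchen as the module paths below suggest and that it takes a --prompt flag (not shown in this diff):

    cargo run --example wuerstchen --release --features flash-attn -- \
        --use-flash-attn --prompt "an astronaut riding a horse"

Since flash-attn lists cuda among its feature dependencies, this also pulls in the CUDA backend.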
@@ -13,7 +13,6 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
 candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
 cudarc = { workspace = true, optional = true }
@@ -51,7 +50,7 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
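Note that the examples crate's flash-attn feature changes meaning here: it no longer pulls candle-flash-attn into the examples directly (that dependency is removed above) but forwards to candle-transformers/flash-attn, where the attention code that actually calls the kernels now lives; see the candle-transformers hunks below.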
@@ -41,6 +41,9 @@ struct Args {
     #[arg(long)]
     tracing: bool,

+    #[arg(long)]
+    use_flash_attn: bool,
+
     /// The height in pixels of the generated image.
     #[arg(long)]
     height: Option<usize>,
@@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
-            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN,
+            /* c */ 1536,
+            /* c_cond */ 1280,
+            /* c_r */ 64,
+            /* depth */ 32,
+            /* nhead */ 24,
+            args.use_flash_attn,
+            vb,
         )?
     };
     let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
@@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
             /* c_cond */ 1024,
             /* clip_embd */ 1024,
             /* patch_size */ 2,
+            args.use_flash_attn,
             vb,
         )?
     };
@@ -12,6 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }

@@ -26,4 +27,5 @@ wav = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
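With the optional dependency and feature defined here, the gated code can be checked on the library alone, without building the examples. A sketch of the invocation from the workspace root (assumes a CUDA toolchain is available, since flash-attn transitively enables cuda):

    cargo build -p candle-transformers --features flash-attn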
@@ -11,10 +11,33 @@ pub struct Attention {
     to_out: Linear,
     heads: usize,
     scale: f64,
+    use_flash_attn: bool,
 }

+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 impl Attention {
-    pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(
+        query_dim: usize,
+        heads: usize,
+        dim_head: usize,
+        use_flash_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let inner_dim = dim_head * heads;
         let scale = 1.0 / f64::sqrt(dim_head as f64);
         let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
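The paired flash_attn definitions above are a compile-time gate with a same-signature stub: call sites always type-check, and the stub only panics if the runtime flag actually routes execution into it. A minimal self-contained sketch of the same pattern, with toy names and a made-up computation rather than candle's API:

    // Real implementation, only compiled when the feature is enabled.
    #[cfg(feature = "flash-attn")]
    fn fast_path(xs: &[f32]) -> Vec<f32> {
        xs.iter().map(|x| x * 2.0).collect()
    }

    // Same-signature stub so callers compile without the feature; it is
    // only reached if the runtime flag selects the fast path anyway.
    #[cfg(not(feature = "flash-attn"))]
    fn fast_path(_: &[f32]) -> Vec<f32> {
        unimplemented!("compile with '--features flash-attn'")
    }

    // The runtime switch, mirroring the use_flash_attn field threaded
    // through the constructors in this commit.
    fn forward(xs: &[f32], use_fast: bool) -> Vec<f32> {
        if use_fast {
            fast_path(xs)
        } else {
            xs.iter().map(|x| x + x).collect()
        }
    }

    fn main() {
        println!("{:?}", forward(&[1.0, 2.0], false));
    }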
@@ -28,6 +51,7 @@ impl Attention {
             to_out,
             scale,
             heads,
+            use_flash_attn,
         })
     }
@@ -62,8 +86,28 @@ impl Attention {
         let key = self.head_to_batch_dim(&key)?;
         let value = self.head_to_batch_dim(&value)?;

-        let attn_prs = self.get_attention_scores(&query, &key)?;
-        let xs = attn_prs.matmul(&value)?;
+        let xs = if self.use_flash_attn {
+            let init_dtype = query.dtype();
+            let q = query
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let k = key
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let v = value
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            flash_attn(&q, &k, &v, self.scale as f32, false)?
+                .transpose(1, 2)?
+                .squeeze(0)?
+                .to_dtype(init_dtype)?
+        } else {
+            let attn_prs = self.get_attention_scores(&query, &key)?;
+            attn_prs.matmul(&value)?
+        };
         let xs = self.batch_to_head_dim(&xs)?;

         self.to_out
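The layout shuffling in this hunk is the subtle part: head_to_batch_dim leaves query, key, and value as (batch * heads, seq_len, head_dim), while candle_flash_attn::flash_attn expects f16 tensors of shape (batch, seq_len, num_heads, head_dim). A commented walkthrough of the round trip; the shapes are inferred from the surrounding code, so treat them as an assumption:

    // query: (batch * heads, seq_len, head_dim)    batched-heads layout
    //   .unsqueeze(0)    -> (1, batch * heads, seq_len, head_dim)
    //   .transpose(1, 2) -> (1, seq_len, batch * heads, head_dim)
    // flash_attn thus sees batch = 1 and num_heads = batch * heads, which is
    // sound because attention is computed independently per head.
    // The result is transposed and squeezed back to (batch * heads, seq_len,
    // head_dim) and cast from F16 to the original dtype, so batch_to_head_dim
    // and the to_out projection behave identically on both paths.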
@@ -174,10 +174,11 @@ impl AttnBlock {
         c_cond: usize,
         nhead: usize,
         self_attn: bool,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm = WLayerNorm::new(c)?;
-        let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
+        let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
             self_attn,
@@ -88,6 +88,7 @@ pub struct WDiffNeXt {
 }

 impl WDiffNeXt {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         c_in: usize,
         c_out: usize,

@@ -95,6 +96,7 @@ impl WDiffNeXt {
         c_cond: usize,
         clip_embd: usize,
         patch_size: usize,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
@@ -169,8 +171,14 @@ impl WDiffNeXt {
                 let attn_block = if i == 0 {
                     None
                 } else {
-                    let attn_block =
-                        AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                    let attn_block = AttnBlock::new(
+                        c_hidden,
+                        c_cond,
+                        NHEAD[i],
+                        true,
+                        use_flash_attn,
+                        vb.pp(layer_i),
+                    )?;
                     layer_i += 1;
                     Some(attn_block)
                 };
@@ -208,8 +216,14 @@ impl WDiffNeXt {
                 let attn_block = if i == 0 {
                     None
                 } else {
-                    let attn_block =
-                        AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                    let attn_block = AttnBlock::new(
+                        c_hidden,
+                        c_cond,
+                        NHEAD[i],
+                        true,
+                        use_flash_attn,
+                        vb.pp(layer_i),
+                    )?;
                     layer_i += 1;
                     Some(attn_block)
                 };
|
@ -21,6 +21,7 @@ pub struct WPrior {
|
||||
}
|
||||
|
||||
impl WPrior {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
c_in: usize,
|
||||
c: usize,
|
||||
@ -28,6 +29,7 @@ impl WPrior {
|
||||
c_r: usize,
|
||||
depth: usize,
|
||||
nhead: usize,
|
||||
use_flash_attn: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
|
||||
@@ -44,6 +46,7 @@ impl WPrior {
                 c,
                 nhead,
                 true,
+                use_flash_attn,
                 vb.pp(format!("blocks.{}", 3 * index + 2)),
             )?;
             blocks.push(Block {