diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index cf8f0021..0e2e8093 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -13,7 +13,6 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
 candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
 cudarc = { workspace = true, optional = true }
@@ -51,7 +50,7 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index aaa9b78a..95f3b8f4 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -41,6 +41,9 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
+    #[arg(long)]
+    use_flash_attn: bool,
+
     /// The height in pixels of the generated image.
     #[arg(long)]
     height: Option<usize>,
@@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
-            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN,
+            /* c */ 1536,
+            /* c_cond */ 1280,
+            /* c_r */ 64,
+            /* depth */ 32,
+            /* nhead */ 24,
+            args.use_flash_attn,
+            vb,
         )?
     };
     let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
@@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
             /* c_cond */ 1024,
             /* clip_embd */ 1024,
             /* patch_size */ 2,
+            args.use_flash_attn,
             vb,
         )?
     };
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 2faadad9..a3115c2b 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -12,6 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
@@ -26,4 +27,5 @@ wav = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
diff --git a/candle-transformers/src/models/wuerstchen/attention_processor.rs b/candle-transformers/src/models/wuerstchen/attention_processor.rs
index 3f1a72eb..0b90cb9d 100644
--- a/candle-transformers/src/models/wuerstchen/attention_processor.rs
+++ b/candle-transformers/src/models/wuerstchen/attention_processor.rs
@@ -11,10 +11,33 @@ pub struct Attention {
     to_out: Linear,
     heads: usize,
     scale: f64,
+    use_flash_attn: bool,
+}
+
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
 }
 
 impl Attention {
-    pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(
+        query_dim: usize,
+        heads: usize,
+        dim_head: usize,
+        use_flash_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let inner_dim = dim_head * heads;
         let scale = 1.0 / f64::sqrt(dim_head as f64);
         let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
@@ -28,6 +51,7 @@ impl Attention {
             to_out,
             scale,
             heads,
+            use_flash_attn,
         })
     }
 
@@ -62,8 +86,28 @@ impl Attention {
         let key = self.head_to_batch_dim(&key)?;
         let value = self.head_to_batch_dim(&value)?;
 
-        let attn_prs = self.get_attention_scores(&query, &key)?;
-        let xs = attn_prs.matmul(&value)?;
+        let xs = if self.use_flash_attn {
+            let init_dtype = query.dtype();
+            let q = query
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let k = key
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let v = value
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            flash_attn(&q, &k, &v, self.scale as f32, false)?
+                .transpose(1, 2)?
+                .squeeze(0)?
+                .to_dtype(init_dtype)?
+        } else {
+            let attn_prs = self.get_attention_scores(&query, &key)?;
+            attn_prs.matmul(&value)?
+        };
         let xs = self.batch_to_head_dim(&xs)?;
 
         self.to_out
diff --git a/candle-transformers/src/models/wuerstchen/common.rs b/candle-transformers/src/models/wuerstchen/common.rs
index 8416a1f1..c89ec919 100644
--- a/candle-transformers/src/models/wuerstchen/common.rs
+++ b/candle-transformers/src/models/wuerstchen/common.rs
@@ -174,10 +174,11 @@ impl AttnBlock {
         c_cond: usize,
         nhead: usize,
         self_attn: bool,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm = WLayerNorm::new(c)?;
-        let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
+        let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
             self_attn,
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 501a2776..64a48c8a 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -88,6 +88,7 @@ pub struct WDiffNeXt {
 }
 
 impl WDiffNeXt {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         c_in: usize,
         c_out: usize,
@@ -95,6 +96,7 @@ impl WDiffNeXt {
         c_cond: usize,
         clip_embd: usize,
         patch_size: usize,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
@@ -169,8 +171,14 @@ impl WDiffNeXt {
             let attn_block = if i == 0 {
                 None
             } else {
-                let attn_block =
-                    AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                let attn_block = AttnBlock::new(
+                    c_hidden,
+                    c_cond,
+                    NHEAD[i],
+                    true,
+                    use_flash_attn,
+                    vb.pp(layer_i),
+                )?;
                 layer_i += 1;
                 Some(attn_block)
             };
@@ -208,8 +216,14 @@ impl WDiffNeXt {
             let attn_block = if i == 0 {
                 None
             } else {
-                let attn_block =
-                    AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                let attn_block = AttnBlock::new(
+                    c_hidden,
+                    c_cond,
+                    NHEAD[i],
+                    true,
+                    use_flash_attn,
+                    vb.pp(layer_i),
+                )?;
                 layer_i += 1;
                 Some(attn_block)
             };
diff --git a/candle-transformers/src/models/wuerstchen/prior.rs b/candle-transformers/src/models/wuerstchen/prior.rs
index 168b70a6..97ccf0e2 100644
--- a/candle-transformers/src/models/wuerstchen/prior.rs
+++ b/candle-transformers/src/models/wuerstchen/prior.rs
@@ -21,6 +21,7 @@ pub struct WPrior {
 }
 
 impl WPrior {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         c_in: usize,
         c: usize,
@@ -28,6 +29,7 @@ impl WPrior {
         c_r: usize,
         depth: usize,
         nhead: usize,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -44,6 +46,7 @@ impl WPrior {
                 c,
                 nhead,
                 true,
+                use_flash_attn,
                 vb.pp(format!("blocks.{}", 3 * index + 2)),
             )?;
         blocks.push(Block {
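
Note on the dtype/layout round trip in `attention_processor.rs`: `candle_flash_attn::flash_attn` expects half-precision tensors shaped `(batch, seq_len, num_heads, head_dim)`, while `head_to_batch_dim` produces `(batch * heads, seq_len, head_dim)`. Folding the batch dimension into the head dimension under a singleton batch is equivalent here because attention is computed independently per (batch, head) pair, which is why the new path only needs `to_dtype`/`unsqueeze(0)`/`transpose(1, 2)` before the kernel call and the inverse afterwards. A minimal sketch of that shape contract, runnable on CPU; the dimensions are illustrative (24 heads of size 64 match the prior's `c = 1536`, `nhead = 24`):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // What head_to_batch_dim hands us: (batch * heads, seq_len, head_dim).
    let q = Tensor::zeros((24, 77, 64), DType::F32, &Device::Cpu)?;
    // The flash-attn kernel wants (batch, seq_len, heads, head_dim) in f16.
    let q = q.to_dtype(DType::F16)?.unsqueeze(0)?.transpose(1, 2)?;
    assert_eq!(q.dims(), &[1, 77, 24, 64]);
    Ok(())
}
```

Without the `flash-attn` feature the `cfg`-gated stub panics via `unimplemented!`, so the `--use-flash-attn` CLI flag is only meaningful for binaries built with `--features flash-attn` (which in turn requires `cuda`).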