Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-17 11:08:52 +00:00
Add flash-attn support. (#912)
* Add flash-attn support.
* Add the use-flash-attn flag.
* Re-enable flash-attn.
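With both changes in place, flash attention is opt-in twice over: the kernels are compiled in via the Cargo feature and selected at run time via the new flag. A plausible invocation, assuming the wuerstchen example touched below (flags other than --use-flash-attn are illustrative):

    cargo run --example wuerstchen --release --features flash-attn -- --use-flash-attn

Note that the flash-attn feature also pulls in cuda (see the feature definitions below), so this only builds against a CUDA toolchain; without the feature, passing --use-flash-attn hits the unimplemented!() stub shown later in this diff.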
@@ -13,7 +13,6 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
 candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
 cudarc = { workspace = true, optional = true }

@@ -51,7 +50,7 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "dep:candle-flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]

@@ -41,6 +41,9 @@ struct Args {
     #[arg(long)]
     tracing: bool,
 
+    #[arg(long)]
+    use_flash_attn: bool,
+
     /// The height in pixels of the generated image.
     #[arg(long)]
     height: Option<usize>,
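For context, a minimal sketch of how a clap-derived boolean flag like this behaves; everything here other than the field itself is assumed scaffolding, not part of the diff:

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Enable the flash-attn kernels (requires building with '--features flash-attn').
        #[arg(long)]
        use_flash_attn: bool,
    }

    fn main() {
        // A bare `--use-flash-attn` sets the field to true; omitting it defaults to false.
        let args = Args::parse();
        println!("use_flash_attn = {}", args.use_flash_attn);
    }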

@@ -289,8 +292,14 @@ fn run(args: Args) -> Result<()> {
         let weights = weights.deserialize()?;
         let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
         wuerstchen::prior::WPrior::new(
-            /* c_in */ PRIOR_CIN, /* c */ 1536, /* c_cond */ 1280,
-            /* c_r */ 64, /* depth */ 32, /* nhead */ 24, vb,
+            /* c_in */ PRIOR_CIN,
+            /* c */ 1536,
+            /* c_cond */ 1280,
+            /* c_r */ 64,
+            /* depth */ 32,
+            /* nhead */ 24,
+            args.use_flash_attn,
+            vb,
         )?
     };
     let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;

@@ -337,6 +346,7 @@ fn run(args: Args) -> Result<()> {
             /* c_cond */ 1024,
             /* clip_embd */ 1024,
             /* patch_size */ 2,
+            args.use_flash_attn,
             vb,
         )?
     };

@@ -12,6 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }

@@ -26,4 +27,5 @@ wav = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda"]
+flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]

@@ -11,10 +11,33 @@ pub struct Attention {
     to_out: Linear,
     heads: usize,
     scale: f64,
+    use_flash_attn: bool,
 }
 
+#[cfg(feature = "flash-attn")]
+fn flash_attn(
+    q: &Tensor,
+    k: &Tensor,
+    v: &Tensor,
+    softmax_scale: f32,
+    causal: bool,
+) -> Result<Tensor> {
+    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
+}
+
+#[cfg(not(feature = "flash-attn"))]
+fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
+    unimplemented!("compile with '--features flash-attn'")
+}
+
 impl Attention {
-    pub fn new(query_dim: usize, heads: usize, dim_head: usize, vb: VarBuilder) -> Result<Self> {
+    pub fn new(
+        query_dim: usize,
+        heads: usize,
+        dim_head: usize,
+        use_flash_attn: bool,
+        vb: VarBuilder,
+    ) -> Result<Self> {
         let inner_dim = dim_head * heads;
         let scale = 1.0 / f64::sqrt(dim_head as f64);
         let to_q = linear(query_dim, inner_dim, vb.pp("to_q"))?;
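The wrapper above is the usual feature-gated dispatch pattern: the real kernel is only compiled under the flash-attn feature, while a stub with the same signature keeps call sites compiling without it and fails loudly if reached. A self-contained sketch of the same pattern with illustrative names (not from the diff):

    // Compiled only when the crate is built with '--features fast-path'.
    #[cfg(feature = "fast-path")]
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    // Same-signature fallback so callers type-check either way; reaching it
    // at runtime is a hard error, mirroring the stub in the diff.
    #[cfg(not(feature = "fast-path"))]
    fn dot(_: &[f32], _: &[f32]) -> f32 {
        unimplemented!("compile with '--features fast-path'")
    }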

@@ -28,6 +51,7 @@ impl Attention {
             to_out,
             scale,
             heads,
+            use_flash_attn,
         })
     }
 

@@ -62,8 +86,28 @@ impl Attention {
         let key = self.head_to_batch_dim(&key)?;
         let value = self.head_to_batch_dim(&value)?;
 
-        let attn_prs = self.get_attention_scores(&query, &key)?;
-        let xs = attn_prs.matmul(&value)?;
+        let xs = if self.use_flash_attn {
+            let init_dtype = query.dtype();
+            let q = query
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let k = key
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            let v = value
+                .to_dtype(candle::DType::F16)?
+                .unsqueeze(0)?
+                .transpose(1, 2)?;
+            flash_attn(&q, &k, &v, self.scale as f32, false)?
+                .transpose(1, 2)?
+                .squeeze(0)?
+                .to_dtype(init_dtype)?
+        } else {
+            let attn_prs = self.get_attention_scores(&query, &key)?;
+            attn_prs.matmul(&value)?
+        };
         let xs = self.batch_to_head_dim(&xs)?;
 
         self.to_out
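A note on the dtype and shape juggling above: the flash-attn kernel wants f16 tensors shaped (batch, seq_len, num_heads, head_dim), while head_to_batch_dim has already folded the heads into the batch dimension as (batch * heads, seq_len, head_dim). The unsqueeze/transpose pair therefore presents the folded tensor as a single batch carrying batch * heads heads, and the inverse transform plus the cast back to init_dtype restores the layout the rest of the model expects. A shape-only sketch on CPU with toy sizes (illustrative, not from the diff):

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let (batch_x_heads, seq_len, head_dim) = (8, 77, 64);
        // Layout produced by head_to_batch_dim: (batch * heads, seq, head_dim).
        let query = Tensor::zeros((batch_x_heads, seq_len, head_dim), DType::F32, &Device::Cpu)?;

        // Re-expressed as (batch, seq, heads, head_dim) for the kernel.
        let q = query.unsqueeze(0)?.transpose(1, 2)?;
        assert_eq!(q.dims(), &[1, seq_len, batch_x_heads, head_dim]);

        // Inverse transform applied to the kernel output.
        let back = q.transpose(1, 2)?.squeeze(0)?;
        assert_eq!(back.dims(), &[batch_x_heads, seq_len, head_dim]);
        Ok(())
    }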

@@ -174,10 +174,11 @@ impl AttnBlock {
         c_cond: usize,
         nhead: usize,
         self_attn: bool,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         let norm = WLayerNorm::new(c)?;
-        let attention = Attention::new(c, nhead, c / nhead, vb.pp("attention"))?;
+        let attention = Attention::new(c, nhead, c / nhead, use_flash_attn, vb.pp("attention"))?;
         let kv_mapper_lin = candle_nn::linear(c_cond, c, vb.pp("kv_mapper.1"))?;
         Ok(Self {
             self_attn,

@@ -88,6 +88,7 @@ pub struct WDiffNeXt {
 }
 
 impl WDiffNeXt {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         c_in: usize,
         c_out: usize,
@@ -95,6 +96,7 @@ impl WDiffNeXt {
         c_cond: usize,
         clip_embd: usize,
         patch_size: usize,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];

@@ -169,8 +171,14 @@
             let attn_block = if i == 0 {
                 None
             } else {
-                let attn_block =
-                    AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                let attn_block = AttnBlock::new(
+                    c_hidden,
+                    c_cond,
+                    NHEAD[i],
+                    true,
+                    use_flash_attn,
+                    vb.pp(layer_i),
+                )?;
                 layer_i += 1;
                 Some(attn_block)
             };

@@ -208,8 +216,14 @@
             let attn_block = if i == 0 {
                 None
             } else {
-                let attn_block =
-                    AttnBlock::new(c_hidden, c_cond, NHEAD[i], true, vb.pp(layer_i))?;
+                let attn_block = AttnBlock::new(
+                    c_hidden,
+                    c_cond,
+                    NHEAD[i],
+                    true,
+                    use_flash_attn,
+                    vb.pp(layer_i),
+                )?;
                 layer_i += 1;
                 Some(attn_block)
             };

@@ -21,6 +21,7 @@ pub struct WPrior {
 }
 
 impl WPrior {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         c_in: usize,
         c: usize,
@@ -28,6 +29,7 @@ impl WPrior {
         c_r: usize,
         depth: usize,
         nhead: usize,
+        use_flash_attn: bool,
         vb: VarBuilder,
     ) -> Result<Self> {
         let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
@@ -44,6 +46,7 @@ impl WPrior {
             c,
             nhead,
             true,
+            use_flash_attn,
             vb.pp(format!("blocks.{}", 3 * index + 2)),
         )?;
         blocks.push(Block {