mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add the flux model for image generation. (#2390)
* Add the flux autoencoder. * Add the encoder down-blocks. * Upsampling in the decoder. * Sketch the flow matching model. * More flux model. * Add some of the positional embeddings. * Add the rope embeddings. * Add the sampling functions. * Add the flux example. * Fix the T5 bits. * Proper T5 tokenizer. * Clip encoder path fix. * Get the clip embeddings. * No configurable weights in layer norm. * More weights related fixes. * Yet another shape fix. * DType fix. * Fix a couple more shape issues. * DType fixes. * Fix the latent dims. * Fix more shape issues. * Autoencoder fixes. * Get some generations out. * Bugfix. * T5 padding. * Clippy fix. * Add the decode only mode. * Fix. * More fixes. * Finally get some generations to work. * Add readme.
This commit is contained in:
19
candle-examples/examples/flux/README.md
Normal file
19
candle-examples/examples/flux/README.md
Normal file
@ -0,0 +1,19 @@
|
||||
# candle-flux: image generation with latent rectified flow transformers
|
||||
|
||||

|
||||
|
||||
Flux is a 12B rectified flow transformer capable of generating images from text
|
||||
descriptions,
|
||||
[huggingface](https://huggingface.co/black-forest-labs/FLUX.1-schnell),
|
||||
[github](https://github.com/black-forest-labs/flux),
|
||||
[blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
|
||||
|
||||
|
||||
## Running the model
|
||||
|
||||
```bash
|
||||
cargo run --features cuda --example flux -r -- \
|
||||
--height 1024 --width 1024
|
||||
--prompt "a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k"
|
||||
```
|
||||
|
BIN
candle-examples/examples/flux/assets/flux-robot.jpg
Normal file
BIN
candle-examples/examples/flux/assets/flux-robot.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 90 KiB |
182
candle-examples/examples/flux/main.rs
Normal file
182
candle-examples/examples/flux/main.rs
Normal file
@ -0,0 +1,182 @@
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use candle_transformers::models::{clip, flux, t5};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{IndexOp, Module, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The prompt to be used for image generation.
|
||||
#[arg(long, default_value = "A rusty robot walking on a beach")]
|
||||
prompt: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The height in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
height: Option<usize>,
|
||||
|
||||
/// The width in pixels of the generated image.
|
||||
#[arg(long)]
|
||||
width: Option<usize>,
|
||||
|
||||
#[arg(long)]
|
||||
decode_only: Option<String>,
|
||||
}
|
||||
|
||||
fn run(args: Args) -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let Args {
|
||||
prompt,
|
||||
cpu,
|
||||
height,
|
||||
width,
|
||||
tracing,
|
||||
decode_only,
|
||||
} = args;
|
||||
let width = width.unwrap_or(1360);
|
||||
let height = height.unwrap_or(768);
|
||||
|
||||
let _guard = if tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let bf_repo = api.repo(hf_hub::Repo::model(
|
||||
"black-forest-labs/FLUX.1-schnell".to_string(),
|
||||
));
|
||||
let device = candle_examples::device(cpu)?;
|
||||
let dtype = device.bf16_default_to_f32();
|
||||
let img = match decode_only {
|
||||
None => {
|
||||
let t5_emb = {
|
||||
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||
"google/t5-v1_1-xxl".to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
"refs/pr/2".to_string(),
|
||||
));
|
||||
let model_file = repo.get("model.safetensors")?;
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: t5::Config = serde_json::from_str(&config)?;
|
||||
let mut model = t5::T5EncoderModel::load(vb, &config)?;
|
||||
let tokenizer_filename = api
|
||||
.model("lmz/mt5-tokenizers".to_string())
|
||||
.get("t5-v1_1-xxl.tokenizer.json")?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
tokens.resize(256, 0);
|
||||
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
||||
println!("{input_token_ids}");
|
||||
model.forward(&input_token_ids)?
|
||||
};
|
||||
println!("T5\n{t5_emb}");
|
||||
let clip_emb = {
|
||||
let repo = api.repo(hf_hub::Repo::model(
|
||||
"openai/clip-vit-large-patch14".to_string(),
|
||||
));
|
||||
let model_file = repo.get("model.safetensors")?;
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
||||
let config = clip::text_model::ClipTextConfig {
|
||||
vocab_size: 49408,
|
||||
projection_dim: 768,
|
||||
activation: clip::text_model::Activation::QuickGelu,
|
||||
intermediate_size: 3072,
|
||||
embed_dim: 768,
|
||||
max_position_embeddings: 77,
|
||||
pad_with: None,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
};
|
||||
let model =
|
||||
clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?;
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(prompt.as_str(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
|
||||
println!("{input_token_ids}");
|
||||
model.forward(&input_token_ids)?
|
||||
};
|
||||
println!("CLIP\n{clip_emb}");
|
||||
let img = {
|
||||
let model_file = bf_repo.get("flux1-schnell.sft")?;
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = flux::model::Config::schnell();
|
||||
let model = flux::model::Flux::new(&cfg, vb)?;
|
||||
|
||||
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
|
||||
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
|
||||
println!("{state:?}");
|
||||
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
|
||||
println!("{timesteps:?}");
|
||||
flux::sampling::denoise(
|
||||
&model,
|
||||
&state.img,
|
||||
&state.img_ids,
|
||||
&state.txt,
|
||||
&state.txt_ids,
|
||||
&state.vec,
|
||||
×teps,
|
||||
4.,
|
||||
)?
|
||||
};
|
||||
flux::sampling::unpack(&img, height, width)?
|
||||
}
|
||||
Some(file) => {
|
||||
let mut st = candle::safetensors::load(file, &device)?;
|
||||
st.remove("img").unwrap().to_dtype(dtype)?
|
||||
}
|
||||
};
|
||||
println!("latent img\n{img}");
|
||||
|
||||
let img = {
|
||||
let model_file = bf_repo.get("ae.sft")?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
|
||||
let cfg = flux::autoencoder::Config::schnell();
|
||||
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
|
||||
model.decode(&img)?
|
||||
};
|
||||
println!("img\n{img}");
|
||||
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
|
||||
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
run(args)
|
||||
}
|
440
candle-transformers/src/models/flux/autoencoder.rs
Normal file
440
candle-transformers/src/models/flux/autoencoder.rs
Normal file
@ -0,0 +1,440 @@
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder};
|
||||
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub resolution: usize,
|
||||
pub in_channels: usize,
|
||||
pub ch: usize,
|
||||
pub out_ch: usize,
|
||||
pub ch_mult: Vec<usize>,
|
||||
pub num_res_blocks: usize,
|
||||
pub z_channels: usize,
|
||||
pub scale_factor: f64,
|
||||
pub shift_factor: f64,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47
|
||||
pub fn dev() -> Self {
|
||||
Self {
|
||||
resolution: 256,
|
||||
in_channels: 3,
|
||||
ch: 128,
|
||||
out_ch: 3,
|
||||
ch_mult: vec![1, 2, 4, 4],
|
||||
num_res_blocks: 2,
|
||||
z_channels: 16,
|
||||
scale_factor: 0.3611,
|
||||
shift_factor: 0.1159,
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79
|
||||
pub fn schnell() -> Self {
|
||||
Self {
|
||||
resolution: 256,
|
||||
in_channels: 3,
|
||||
ch: 128,
|
||||
out_ch: 3,
|
||||
ch_mult: vec![1, 2, 4, 4],
|
||||
num_res_blocks: 2,
|
||||
z_channels: 16,
|
||||
scale_factor: 0.3611,
|
||||
shift_factor: 0.1159,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let dim = q.dim(D::Minus1)?;
|
||||
let scale_factor = 1.0 / (dim as f64).sqrt();
|
||||
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
|
||||
candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct AttnBlock {
|
||||
q: Conv2d,
|
||||
k: Conv2d,
|
||||
v: Conv2d,
|
||||
proj_out: Conv2d,
|
||||
norm: GroupNorm,
|
||||
}
|
||||
|
||||
impl AttnBlock {
|
||||
fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
|
||||
let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
|
||||
let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
|
||||
let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
|
||||
let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?;
|
||||
Ok(Self {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
proj_out,
|
||||
norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for AttnBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let init_xs = xs;
|
||||
let xs = xs.apply(&self.norm)?;
|
||||
let q = xs.apply(&self.q)?;
|
||||
let k = xs.apply(&self.k)?;
|
||||
let v = xs.apply(&self.v)?;
|
||||
let (b, c, h, w) = q.dims4()?;
|
||||
let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
|
||||
let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
|
||||
let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
|
||||
let xs = scaled_dot_product_attention(&q, &k, &v)?;
|
||||
let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
|
||||
xs.apply(&self.proj_out)? + init_xs
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ResnetBlock {
|
||||
norm1: GroupNorm,
|
||||
conv1: Conv2d,
|
||||
norm2: GroupNorm,
|
||||
conv2: Conv2d,
|
||||
nin_shortcut: Option<Conv2d>,
|
||||
}
|
||||
|
||||
impl ResnetBlock {
|
||||
fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let conv_cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?;
|
||||
let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
|
||||
let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?;
|
||||
let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
|
||||
let nin_shortcut = if in_c == out_c {
|
||||
None
|
||||
} else {
|
||||
Some(conv2d(
|
||||
in_c,
|
||||
out_c,
|
||||
1,
|
||||
Default::default(),
|
||||
vb.pp("nin_shortcut"),
|
||||
)?)
|
||||
};
|
||||
Ok(Self {
|
||||
norm1,
|
||||
conv1,
|
||||
norm2,
|
||||
conv2,
|
||||
nin_shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for ResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let h = xs
|
||||
.apply(&self.norm1)?
|
||||
.apply(&candle_nn::Activation::Swish)?
|
||||
.apply(&self.conv1)?
|
||||
.apply(&self.norm2)?
|
||||
.apply(&candle_nn::Activation::Swish)?
|
||||
.apply(&self.conv2)?;
|
||||
match self.nin_shortcut.as_ref() {
|
||||
None => xs + h,
|
||||
Some(c) => xs.apply(c)? + h,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Downsample {
|
||||
conv: Conv2d,
|
||||
}
|
||||
|
||||
impl Downsample {
|
||||
fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let conv_cfg = candle_nn::Conv2dConfig {
|
||||
stride: 2,
|
||||
..Default::default()
|
||||
};
|
||||
let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
|
||||
Ok(Self { conv })
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for Downsample {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
|
||||
let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
|
||||
xs.apply(&self.conv)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Upsample {
|
||||
conv: Conv2d,
|
||||
}
|
||||
|
||||
impl Upsample {
|
||||
fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let conv_cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
|
||||
Ok(Self { conv })
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for Upsample {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_, _, h, w) = xs.dims4()?;
|
||||
xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DownBlock {
|
||||
block: Vec<ResnetBlock>,
|
||||
downsample: Option<Downsample>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Encoder {
|
||||
conv_in: Conv2d,
|
||||
mid_block_1: ResnetBlock,
|
||||
mid_attn_1: AttnBlock,
|
||||
mid_block_2: ResnetBlock,
|
||||
norm_out: GroupNorm,
|
||||
conv_out: Conv2d,
|
||||
down: Vec<DownBlock>,
|
||||
}
|
||||
|
||||
impl Encoder {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let conv_cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let mut block_in = cfg.ch;
|
||||
let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
|
||||
|
||||
let mut down = Vec::with_capacity(cfg.ch_mult.len());
|
||||
let vb_d = vb.pp("down");
|
||||
for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
|
||||
let mut block = Vec::with_capacity(cfg.num_res_blocks);
|
||||
let vb_d = vb_d.pp(i_level);
|
||||
let vb_b = vb_d.pp("block");
|
||||
let in_ch_mult = if i_level == 0 {
|
||||
1
|
||||
} else {
|
||||
cfg.ch_mult[i_level - 1]
|
||||
};
|
||||
block_in = cfg.ch * in_ch_mult;
|
||||
let block_out = cfg.ch * ch_mult;
|
||||
for i_block in 0..cfg.num_res_blocks {
|
||||
let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
|
||||
block.push(b);
|
||||
block_in = block_out;
|
||||
}
|
||||
let downsample = if i_level != cfg.ch_mult.len() - 1 {
|
||||
Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let block = DownBlock { block, downsample };
|
||||
down.push(block)
|
||||
}
|
||||
|
||||
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
|
||||
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
|
||||
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
|
||||
let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?;
|
||||
let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
mid_block_1,
|
||||
mid_attn_1,
|
||||
mid_block_2,
|
||||
norm_out,
|
||||
conv_out,
|
||||
down,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle_nn::Module for Encoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut h = xs.apply(&self.conv_in)?;
|
||||
for block in self.down.iter() {
|
||||
for b in block.block.iter() {
|
||||
h = h.apply(b)?
|
||||
}
|
||||
if let Some(ds) = block.downsample.as_ref() {
|
||||
h = h.apply(ds)?
|
||||
}
|
||||
}
|
||||
h.apply(&self.mid_block_1)?
|
||||
.apply(&self.mid_attn_1)?
|
||||
.apply(&self.mid_block_2)?
|
||||
.apply(&self.norm_out)?
|
||||
.apply(&candle_nn::Activation::Swish)?
|
||||
.apply(&self.conv_out)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct UpBlock {
|
||||
block: Vec<ResnetBlock>,
|
||||
upsample: Option<Upsample>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Decoder {
|
||||
conv_in: Conv2d,
|
||||
mid_block_1: ResnetBlock,
|
||||
mid_attn_1: AttnBlock,
|
||||
mid_block_2: ResnetBlock,
|
||||
norm_out: GroupNorm,
|
||||
conv_out: Conv2d,
|
||||
up: Vec<UpBlock>,
|
||||
}
|
||||
|
||||
impl Decoder {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let conv_cfg = candle_nn::Conv2dConfig {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
|
||||
let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
|
||||
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
|
||||
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
|
||||
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
|
||||
|
||||
let mut up = Vec::with_capacity(cfg.ch_mult.len());
|
||||
let vb_u = vb.pp("up");
|
||||
for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() {
|
||||
let block_out = cfg.ch * ch_mult;
|
||||
let vb_u = vb_u.pp(i_level);
|
||||
let vb_b = vb_u.pp("block");
|
||||
let mut block = Vec::with_capacity(cfg.num_res_blocks + 1);
|
||||
for i_block in 0..=cfg.num_res_blocks {
|
||||
let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
|
||||
block.push(b);
|
||||
block_in = block_out;
|
||||
}
|
||||
let upsample = if i_level != 0 {
|
||||
Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let block = UpBlock { block, upsample };
|
||||
up.push(block)
|
||||
}
|
||||
up.reverse();
|
||||
|
||||
let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
|
||||
let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
|
||||
Ok(Self {
|
||||
conv_in,
|
||||
mid_block_1,
|
||||
mid_attn_1,
|
||||
mid_block_2,
|
||||
norm_out,
|
||||
conv_out,
|
||||
up,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle_nn::Module for Decoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let h = xs.apply(&self.conv_in)?;
|
||||
let mut h = h
|
||||
.apply(&self.mid_block_1)?
|
||||
.apply(&self.mid_attn_1)?
|
||||
.apply(&self.mid_block_2)?;
|
||||
for block in self.up.iter().rev() {
|
||||
for b in block.block.iter() {
|
||||
h = h.apply(b)?
|
||||
}
|
||||
if let Some(us) = block.upsample.as_ref() {
|
||||
h = h.apply(us)?
|
||||
}
|
||||
}
|
||||
h.apply(&self.norm_out)?
|
||||
.apply(&candle_nn::Activation::Swish)?
|
||||
.apply(&self.conv_out)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DiagonalGaussian {
|
||||
sample: bool,
|
||||
chunk_dim: usize,
|
||||
}
|
||||
|
||||
impl DiagonalGaussian {
|
||||
pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
|
||||
Ok(Self { sample, chunk_dim })
|
||||
}
|
||||
}
|
||||
|
||||
impl candle_nn::Module for DiagonalGaussian {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let chunks = xs.chunk(2, self.chunk_dim)?;
|
||||
if self.sample {
|
||||
let std = (&chunks[1] * 0.5)?.exp()?;
|
||||
&chunks[0] + (std * chunks[0].randn_like(0., 1.))?
|
||||
} else {
|
||||
Ok(chunks[0].clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutoEncoder {
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
reg: DiagonalGaussian,
|
||||
shift_factor: f64,
|
||||
scale_factor: f64,
|
||||
}
|
||||
|
||||
impl AutoEncoder {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
|
||||
let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
|
||||
let reg = DiagonalGaussian::new(true, 1)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
reg,
|
||||
scale_factor: cfg.scale_factor,
|
||||
shift_factor: cfg.shift_factor,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
|
||||
(z - self.shift_factor)? * self.scale_factor
|
||||
}
|
||||
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
|
||||
xs.apply(&self.decoder)
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for AutoEncoder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.decode(&self.encode(xs)?)
|
||||
}
|
||||
}
|
3
candle-transformers/src/models/flux/mod.rs
Normal file
3
candle-transformers/src/models/flux/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub mod autoencoder;
|
||||
pub mod model;
|
||||
pub mod sampling;
|
582
candle-transformers/src/models/flux/model.rs
Normal file
582
candle-transformers/src/models/flux/model.rs
Normal file
@ -0,0 +1,582 @@
|
||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
|
||||
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub in_channels: usize,
|
||||
pub vec_in_dim: usize,
|
||||
pub context_in_dim: usize,
|
||||
pub hidden_size: usize,
|
||||
pub mlp_ratio: f64,
|
||||
pub num_heads: usize,
|
||||
pub depth: usize,
|
||||
pub depth_single_blocks: usize,
|
||||
pub axes_dim: Vec<usize>,
|
||||
pub theta: usize,
|
||||
pub qkv_bias: bool,
|
||||
pub guidance_embed: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
|
||||
pub fn dev() -> Self {
|
||||
Self {
|
||||
in_channels: 64,
|
||||
vec_in_dim: 768,
|
||||
context_in_dim: 4096,
|
||||
hidden_size: 3072,
|
||||
mlp_ratio: 4.0,
|
||||
num_heads: 24,
|
||||
depth: 19,
|
||||
depth_single_blocks: 38,
|
||||
axes_dim: vec![16, 56, 56],
|
||||
theta: 10_000,
|
||||
qkv_bias: true,
|
||||
guidance_embed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
|
||||
pub fn schnell() -> Self {
|
||||
Self {
|
||||
in_channels: 64,
|
||||
vec_in_dim: 768,
|
||||
context_in_dim: 4096,
|
||||
hidden_size: 3072,
|
||||
mlp_ratio: 4.0,
|
||||
num_heads: 24,
|
||||
depth: 19,
|
||||
depth_single_blocks: 38,
|
||||
axes_dim: vec![16, 56, 56],
|
||||
theta: 10_000,
|
||||
qkv_bias: true,
|
||||
guidance_embed: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let ws = Tensor::ones(dim, vb.dtype(), vb.device())?;
|
||||
Ok(LayerNorm::new_no_bias(ws, 1e-6))
|
||||
}
|
||||
|
||||
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
|
||||
let dim = q.dim(D::Minus1)?;
|
||||
let scale_factor = 1.0 / (dim as f64).sqrt();
|
||||
let mut batch_dims = q.dims().to_vec();
|
||||
batch_dims.pop();
|
||||
batch_dims.pop();
|
||||
let q = q.flatten_to(batch_dims.len() - 1)?;
|
||||
let k = k.flatten_to(batch_dims.len() - 1)?;
|
||||
let v = v.flatten_to(batch_dims.len() - 1)?;
|
||||
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
|
||||
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
|
||||
batch_dims.push(attn_scores.dim(D::Minus2)?);
|
||||
batch_dims.push(attn_scores.dim(D::Minus1)?);
|
||||
attn_scores.reshape(batch_dims)
|
||||
}
|
||||
|
||||
fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
|
||||
if dim % 2 == 1 {
|
||||
candle::bail!("dim {dim} is odd")
|
||||
}
|
||||
let dev = pos.device();
|
||||
let theta = theta as f64;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
|
||||
let inv_freq = inv_freq.to_dtype(pos.dtype())?;
|
||||
let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
|
||||
let cos = freqs.cos()?;
|
||||
let sin = freqs.sin()?;
|
||||
let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
|
||||
let (b, n, d, _ij) = out.dims4()?;
|
||||
out.reshape((b, n, d, 2, 2))
|
||||
}
|
||||
|
||||
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
|
||||
let dims = x.dims();
|
||||
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
|
||||
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
|
||||
let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
|
||||
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
|
||||
}
|
||||
|
||||
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let q = apply_rope(q, pe)?.contiguous()?;
|
||||
let k = apply_rope(k, pe)?.contiguous()?;
|
||||
let x = scaled_dot_product_attention(&q, &k, v)?;
|
||||
x.transpose(1, 2)?.flatten_from(2)
|
||||
}
|
||||
|
||||
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
|
||||
const TIME_FACTOR: f64 = 1000.;
|
||||
const MAX_PERIOD: f64 = 10000.;
|
||||
if dim % 2 == 1 {
|
||||
candle::bail!("{dim} is odd")
|
||||
}
|
||||
let dev = t.device();
|
||||
let half = dim / 2;
|
||||
let t = (t * TIME_FACTOR)?;
|
||||
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
|
||||
let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
|
||||
let args = t
|
||||
.unsqueeze(1)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.broadcast_mul(&freqs.unsqueeze(0)?)?;
|
||||
let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
|
||||
Ok(emb)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EmbedNd {
|
||||
#[allow(unused)]
|
||||
dim: usize,
|
||||
theta: usize,
|
||||
axes_dim: Vec<usize>,
|
||||
}
|
||||
|
||||
impl EmbedNd {
|
||||
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
|
||||
Self {
|
||||
dim,
|
||||
theta,
|
||||
axes_dim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for EmbedNd {
|
||||
fn forward(&self, ids: &Tensor) -> Result<Tensor> {
|
||||
let n_axes = ids.dim(D::Minus1)?;
|
||||
let mut emb = Vec::with_capacity(n_axes);
|
||||
for idx in 0..n_axes {
|
||||
let r = rope(
|
||||
&ids.get_on_dim(D::Minus1, idx)?,
|
||||
self.axes_dim[idx],
|
||||
self.theta,
|
||||
)?;
|
||||
emb.push(r)
|
||||
}
|
||||
let emb = Tensor::cat(&emb, 2)?;
|
||||
emb.unsqueeze(1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MlpEmbedder {
|
||||
in_layer: Linear,
|
||||
out_layer: Linear,
|
||||
}
|
||||
|
||||
impl MlpEmbedder {
|
||||
fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
|
||||
let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
|
||||
Ok(Self {
|
||||
in_layer,
|
||||
out_layer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for MlpEmbedder {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QkNorm {
|
||||
query_norm: RmsNorm,
|
||||
key_norm: RmsNorm,
|
||||
}
|
||||
|
||||
impl QkNorm {
|
||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let query_norm = vb.get(dim, "query_norm.scale")?;
|
||||
let query_norm = RmsNorm::new(query_norm, 1e-6);
|
||||
let key_norm = vb.get(dim, "key_norm.scale")?;
|
||||
let key_norm = RmsNorm::new(key_norm, 1e-6);
|
||||
Ok(Self {
|
||||
query_norm,
|
||||
key_norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Modulation {
|
||||
lin: Linear,
|
||||
multiplier: usize,
|
||||
}
|
||||
|
||||
impl Modulation {
|
||||
fn new(dim: usize, double: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let multiplier = if double { 6 } else { 3 };
|
||||
let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?;
|
||||
Ok(Self { lin, multiplier })
|
||||
}
|
||||
|
||||
fn forward(&self, vec_: &Tensor) -> Result<Vec<Tensor>> {
|
||||
vec_.silu()?
|
||||
.apply(&self.lin)?
|
||||
.unsqueeze(1)?
|
||||
.chunk(self.multiplier, D::Minus1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SelfAttention {
|
||||
qkv: Linear,
|
||||
norm: QkNorm,
|
||||
proj: Linear,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl SelfAttention {
|
||||
fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
|
||||
let head_dim = dim / num_heads;
|
||||
let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
|
||||
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
|
||||
let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;
|
||||
Ok(Self {
|
||||
qkv,
|
||||
norm,
|
||||
proj,
|
||||
num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let qkv = xs.apply(&self.qkv)?;
|
||||
let (b, l, _khd) = qkv.dims3()?;
|
||||
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
|
||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
||||
let q = q.apply(&self.norm.query_norm)?;
|
||||
let k = k.apply(&self.norm.key_norm)?;
|
||||
Ok((q, k, v))
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let (q, k, v) = self.qkv(xs)?;
|
||||
attention(&q, &k, &v, pe)?.apply(&self.proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
lin1: Linear,
|
||||
lin2: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
|
||||
let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
|
||||
Ok(Self { lin1, lin2 })
|
||||
}
|
||||
}
|
||||
|
||||
impl candle::Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DoubleStreamBlock {
|
||||
img_mod: Modulation,
|
||||
img_norm1: LayerNorm,
|
||||
img_attn: SelfAttention,
|
||||
img_norm2: LayerNorm,
|
||||
img_mlp: Mlp,
|
||||
txt_mod: Modulation,
|
||||
txt_norm1: LayerNorm,
|
||||
txt_attn: SelfAttention,
|
||||
txt_norm2: LayerNorm,
|
||||
txt_mlp: Mlp,
|
||||
}
|
||||
|
||||
impl DoubleStreamBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let h_sz = cfg.hidden_size;
|
||||
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
|
||||
let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?;
|
||||
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
|
||||
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
|
||||
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
|
||||
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
|
||||
let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?;
|
||||
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
|
||||
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
|
||||
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
|
||||
let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
|
||||
Ok(Self {
|
||||
img_mod,
|
||||
img_norm1,
|
||||
img_attn,
|
||||
img_norm2,
|
||||
img_mlp,
|
||||
txt_mod,
|
||||
txt_norm1,
|
||||
txt_attn,
|
||||
txt_norm2,
|
||||
txt_mlp,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
txt: &Tensor,
|
||||
vec_: &Tensor,
|
||||
pe: &Tensor,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate
|
||||
let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate
|
||||
let img_modulated = img.apply(&self.img_norm1)?;
|
||||
let img_modulated = img_modulated
|
||||
.broadcast_mul(&(&img_mod[1] + 1.)?)?
|
||||
.broadcast_add(&img_mod[0])?;
|
||||
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
|
||||
|
||||
let txt_modulated = txt.apply(&self.txt_norm1)?;
|
||||
let txt_modulated = txt_modulated
|
||||
.broadcast_mul(&(&txt_mod[1] + 1.)?)?
|
||||
.broadcast_add(&txt_mod[0])?;
|
||||
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
|
||||
|
||||
let q = Tensor::cat(&[txt_q, img_q], 2)?;
|
||||
let k = Tensor::cat(&[txt_k, img_k], 2)?;
|
||||
let v = Tensor::cat(&[txt_v, img_v], 2)?;
|
||||
|
||||
let attn = attention(&q, &k, &v, pe)?;
|
||||
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
|
||||
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
|
||||
|
||||
let img = (img
|
||||
+ img_attn
|
||||
.apply(&self.img_attn.proj)?
|
||||
.broadcast_mul(&img_mod[2]))?;
|
||||
let img = (&img
|
||||
+ &img_mod[5].broadcast_mul(
|
||||
&img.apply(&self.img_norm2)?
|
||||
.broadcast_mul(&(&img_mod[4] + 1.0)?)?
|
||||
.broadcast_add(&img_mod[3])?
|
||||
.apply(&self.img_mlp)?,
|
||||
)?)?;
|
||||
|
||||
let txt = (txt
|
||||
+ txt_attn
|
||||
.apply(&self.txt_attn.proj)?
|
||||
.broadcast_mul(&txt_mod[2]))?;
|
||||
let txt = (&txt
|
||||
+ &txt_mod[5].broadcast_mul(
|
||||
&txt.apply(&self.txt_norm2)?
|
||||
.broadcast_mul(&(&txt_mod[4] + 1.0)?)?
|
||||
.broadcast_add(&txt_mod[3])?
|
||||
.apply(&self.txt_mlp)?,
|
||||
)?)?;
|
||||
|
||||
Ok((img, txt))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SingleStreamBlock {
|
||||
linear1: Linear,
|
||||
linear2: Linear,
|
||||
norm: QkNorm,
|
||||
pre_norm: LayerNorm,
|
||||
modulation: Modulation,
|
||||
h_sz: usize,
|
||||
mlp_sz: usize,
|
||||
num_heads: usize,
|
||||
}
|
||||
|
||||
impl SingleStreamBlock {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let h_sz = cfg.hidden_size;
|
||||
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
|
||||
let head_dim = h_sz / cfg.num_heads;
|
||||
let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
|
||||
let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
|
||||
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
|
||||
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
|
||||
let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?;
|
||||
Ok(Self {
|
||||
linear1,
|
||||
linear2,
|
||||
norm,
|
||||
pre_norm,
|
||||
modulation,
|
||||
h_sz,
|
||||
mlp_sz,
|
||||
num_heads: cfg.num_heads,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
|
||||
let mod_ = self.modulation.forward(vec_)?;
|
||||
let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]);
|
||||
let x_mod = xs
|
||||
.apply(&self.pre_norm)?
|
||||
.broadcast_mul(&(scale + 1.0)?)?
|
||||
.broadcast_add(shift)?;
|
||||
let x_mod = x_mod.apply(&self.linear1)?;
|
||||
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
|
||||
let (b, l, _khd) = qkv.dims3()?;
|
||||
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
|
||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
||||
let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
|
||||
let q = q.apply(&self.norm.query_norm)?;
|
||||
let k = k.apply(&self.norm.key_norm)?;
|
||||
let attn = attention(&q, &k, &v, pe)?;
|
||||
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
|
||||
xs + gate.broadcast_mul(&output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LastLayer {
|
||||
norm_final: LayerNorm,
|
||||
linear: Linear,
|
||||
ada_ln_modulation: Linear,
|
||||
}
|
||||
|
||||
impl LastLayer {
|
||||
fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
|
||||
let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
|
||||
let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
|
||||
Ok(Self {
|
||||
norm_final,
|
||||
linear,
|
||||
ada_ln_modulation,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
|
||||
let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
|
||||
let (shift, scale) = (&chunks[0], &chunks[1]);
|
||||
let xs = xs
|
||||
.apply(&self.norm_final)?
|
||||
.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
|
||||
.broadcast_add(&shift.unsqueeze(1)?)?;
|
||||
xs.apply(&self.linear)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Flux {
|
||||
img_in: Linear,
|
||||
txt_in: Linear,
|
||||
time_in: MlpEmbedder,
|
||||
vector_in: MlpEmbedder,
|
||||
guidance_in: Option<MlpEmbedder>,
|
||||
pe_embedder: EmbedNd,
|
||||
double_blocks: Vec<DoubleStreamBlock>,
|
||||
single_blocks: Vec<SingleStreamBlock>,
|
||||
final_layer: LastLayer,
|
||||
}
|
||||
|
||||
impl Flux {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
|
||||
let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
|
||||
let mut double_blocks = Vec::with_capacity(cfg.depth);
|
||||
let vb_d = vb.pp("double_blocks");
|
||||
for idx in 0..cfg.depth {
|
||||
let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
|
||||
double_blocks.push(db)
|
||||
}
|
||||
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
|
||||
let vb_s = vb.pp("single_blocks");
|
||||
for idx in 0..cfg.depth_single_blocks {
|
||||
let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
|
||||
single_blocks.push(sb)
|
||||
}
|
||||
let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
|
||||
let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
|
||||
let guidance_in = if cfg.guidance_embed {
|
||||
let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
|
||||
Some(mlp)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let final_layer =
|
||||
LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
|
||||
let pe_dim = cfg.hidden_size / cfg.num_heads;
|
||||
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
|
||||
Ok(Self {
|
||||
img_in,
|
||||
txt_in,
|
||||
time_in,
|
||||
vector_in,
|
||||
guidance_in,
|
||||
pe_embedder,
|
||||
double_blocks,
|
||||
single_blocks,
|
||||
final_layer,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
txt_ids: &Tensor,
|
||||
timesteps: &Tensor,
|
||||
y: &Tensor,
|
||||
guidance: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
if txt.rank() != 3 {
|
||||
candle::bail!("unexpected shape for txt {:?}", txt.shape())
|
||||
}
|
||||
if img.rank() != 3 {
|
||||
candle::bail!("unexpected shape for img {:?}", img.shape())
|
||||
}
|
||||
let dtype = img.dtype();
|
||||
let pe = {
|
||||
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
|
||||
ids.apply(&self.pe_embedder)?
|
||||
};
|
||||
let mut txt = txt.apply(&self.txt_in)?;
|
||||
let mut img = img.apply(&self.img_in)?;
|
||||
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
|
||||
let vec_ = match (self.guidance_in.as_ref(), guidance) {
|
||||
(Some(g_in), Some(guidance)) => {
|
||||
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
|
||||
}
|
||||
_ => vec_,
|
||||
};
|
||||
let vec_ = (vec_ + y.apply(&self.vector_in))?;
|
||||
|
||||
// Double blocks
|
||||
for block in self.double_blocks.iter() {
|
||||
(img, txt) = block.forward(&img, &txt, &vec_, &pe)?
|
||||
}
|
||||
// Single blocks
|
||||
let mut img = Tensor::cat(&[&txt, &img], 1)?;
|
||||
for block in self.single_blocks.iter() {
|
||||
img = block.forward(&img, &vec_, &pe)?;
|
||||
}
|
||||
let img = img.i((.., txt.dim(1)?..))?;
|
||||
self.final_layer.forward(&img, &vec_)
|
||||
}
|
||||
}
|
119
candle-transformers/src/models/flux/sampling.rs
Normal file
119
candle-transformers/src/models/flux/sampling.rs
Normal file
@ -0,0 +1,119 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
||||
pub fn get_noise(
|
||||
num_samples: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
device: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let height = (height + 15) / 16 * 2;
|
||||
let width = (width + 15) / 16 * 2;
|
||||
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct State {
|
||||
pub img: Tensor,
|
||||
pub img_ids: Tensor,
|
||||
pub txt: Tensor,
|
||||
pub txt_ids: Tensor,
|
||||
pub vec: Tensor,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result<Self> {
|
||||
let dtype = img.dtype();
|
||||
let (bs, c, h, w) = img.dims4()?;
|
||||
let dev = img.device();
|
||||
let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; // (b, c, h, ph, w, pw)
|
||||
let img = img.permute((0, 2, 4, 1, 3, 5))?; // (b, h, w, c, ph, pw)
|
||||
let img = img.reshape((bs, h / 2 * w / 2, c * 4))?;
|
||||
let img_ids = Tensor::stack(
|
||||
&[
|
||||
Tensor::full(0u32, (h / 2, w / 2), dev)?,
|
||||
Tensor::arange(0u32, h as u32 / 2, dev)?
|
||||
.reshape(((), 1))?
|
||||
.broadcast_as((h / 2, w / 2))?,
|
||||
Tensor::arange(0u32, w as u32 / 2, dev)?
|
||||
.reshape((1, ()))?
|
||||
.broadcast_as((h / 2, w / 2))?,
|
||||
],
|
||||
2,
|
||||
)?
|
||||
.to_dtype(dtype)?;
|
||||
let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?;
|
||||
let img_ids = img_ids.repeat((bs, 1, 1))?;
|
||||
let txt = t5_emb.repeat(bs)?;
|
||||
let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?;
|
||||
let vec = clip_emb.repeat(bs)?;
|
||||
Ok(Self {
|
||||
img,
|
||||
img_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 {
|
||||
let e = mu.exp();
|
||||
e / (e + (1. / t - 1.).powf(sigma))
|
||||
}
|
||||
|
||||
/// `shift` is a triple `(image_seq_len, base_shift, max_shift)`.
|
||||
pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f64> {
|
||||
let timesteps: Vec<f64> = (0..=num_steps)
|
||||
.map(|v| v as f64 / num_steps as f64)
|
||||
.rev()
|
||||
.collect();
|
||||
match shift {
|
||||
None => timesteps,
|
||||
Some((image_seq_len, y1, y2)) => {
|
||||
let (x1, x2) = (256., 4096.);
|
||||
let m = (y2 - y1) / (x2 - x1);
|
||||
let b = y1 - m * x1;
|
||||
let mu = m * image_seq_len as f64 + b;
|
||||
timesteps
|
||||
.into_iter()
|
||||
.map(|v| time_shift(mu, 1., v))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
||||
let height = (height + 15) / 16;
|
||||
let width = (width + 15) / 16;
|
||||
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
||||
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
||||
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn denoise(
|
||||
model: &super::model::Flux,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
txt_ids: &Tensor,
|
||||
vec_: &Tensor,
|
||||
timesteps: &[f64],
|
||||
guidance: f64,
|
||||
) -> Result<Tensor> {
|
||||
let b_sz = img.dim(0)?;
|
||||
let dev = img.device();
|
||||
let guidance = Tensor::full(guidance as f32, b_sz, dev)?;
|
||||
let mut img = img.clone();
|
||||
for window in timesteps.windows(2) {
|
||||
let (t_curr, t_prev) = match window {
|
||||
[a, b] => (a, b),
|
||||
_ => continue,
|
||||
};
|
||||
let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?;
|
||||
let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?;
|
||||
img = (img + pred * (t_prev - t_curr))?
|
||||
}
|
||||
Ok(img)
|
||||
}
|
@ -17,6 +17,7 @@ pub mod efficientvit;
|
||||
pub mod encodec;
|
||||
pub mod eva2;
|
||||
pub mod falcon;
|
||||
pub mod flux;
|
||||
pub mod gemma;
|
||||
pub mod hiera;
|
||||
pub mod jina_bert;
|
||||
|
Reference in New Issue
Block a user