Add the flux model for image generation. (#2390)

* Add the flux autoencoder.

* Add the encoder down-blocks.

* Upsampling in the decoder.

* Sketch the flow matching model.

* Add more of the flux model.

* Add some of the positional embeddings.

* Add the rope embeddings.

* Add the sampling functions.

* Add the flux example.

* Fix the T5 bits.

* Proper T5 tokenizer.

* Fix the CLIP encoder path.

* Get the CLIP embeddings.

* No configurable weights in layer norm.

* More weights related fixes.

* Yet another shape fix.

* DType fix.

* Fix a couple more shape issues.

* DType fixes.

* Fix the latent dims.

* Fix more shape issues.

* Autoencoder fixes.

* Get some generations out.

* Bugfix.

* T5 padding.

* Clippy fix.

* Add the decode only mode.

* Fix.

* More fixes.

* Finally get some generations to work.

* Add readme.
Commit 19db6b9723 (parent 0fcb40b229), authored by Laurent Mazare on 2024-08-04 07:14:33 +01:00 and committed via GitHub.
8 changed files with 1346 additions and 0 deletions.

@@ -0,0 +1,19 @@
# candle-flux: image generation with latent rectified flow transformers
![rusty robot holding a candle](./assets/flux-robot.jpg)
Flux is a 12B-parameter rectified flow transformer capable of generating images
from text descriptions. See the
[huggingface model card](https://huggingface.co/black-forest-labs/FLUX.1-schnell),
the [reference implementation on github](https://github.com/black-forest-labs/flux),
and the [announcement blog post](https://blackforestlabs.ai/announcing-black-forest-labs/)
for more details.
## Running the model
```bash
cargo run --features cuda --example flux -r -- \
  --height 1024 --width 1024 \
  --prompt 'a rusty robot walking on a beach holding a small torch, the robot has the word "rust" written on it, high quality, 4k'
```
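
The example also supports running on CPU with `--cpu`, and a decode-only mode
that skips the text encoders and flow model and only runs the autoencoder on a
previously saved latent. A sketch, assuming `latent.safetensors` is a
hypothetical file holding an `img` tensor saved from an earlier run:

```bash
# Decode an existing latent with the autoencoder only (no T5/CLIP/flow model).
cargo run --features cuda --example flux -r -- --decode-only latent.safetensors
```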

(Binary file not shown: new 90 KiB image, the `flux-robot.jpg` asset referenced in the README above.)

@@ -0,0 +1,182 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::{clip, flux, t5};
use anyhow::{Error as E, Result};
use candle::{IndexOp, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use tokenizers::Tokenizer;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The prompt to be used for image generation.
#[arg(long, default_value = "A rusty robot walking on a beach")]
prompt: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
/// The width in pixels of the generated image.
#[arg(long)]
width: Option<usize>,
#[arg(long)]
decode_only: Option<String>,
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
cpu,
height,
width,
tracing,
decode_only,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let api = hf_hub::api::sync::Api::new()?;
let bf_repo = api.repo(hf_hub::Repo::model(
"black-forest-labs/FLUX.1-schnell".to_string(),
));
let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
None => {
let t5_emb = {
let repo = api.repo(hf_hub::Repo::with_revision(
"google/t5-v1_1-xxl".to_string(),
hf_hub::RepoType::Model,
"refs/pr/2".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config)?;
let mut model = t5::T5EncoderModel::load(vb, &config)?;
let tokenizer_filename = api
.model("lmz/mt5-tokenizers".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mut tokens = tokenizer
.encode(prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.resize(256, 0);
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
println!("{input_token_ids}");
model.forward(&input_token_ids)?
};
println!("T5\n{t5_emb}");
let clip_emb = {
let repo = api.repo(hf_hub::Repo::model(
"openai/clip-vit-large-patch14".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
// https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
let config = clip::text_model::ClipTextConfig {
vocab_size: 49408,
projection_dim: 768,
activation: clip::text_model::Activation::QuickGelu,
intermediate_size: 3072,
embed_dim: 768,
max_position_embeddings: 77,
pad_with: None,
num_hidden_layers: 12,
num_attention_heads: 12,
};
let model =
clip::text_model::ClipTextTransformer::new(vb.pp("text_model"), &config)?;
let tokenizer_filename = repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt.as_str(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let input_token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
println!("{input_token_ids}");
model.forward(&input_token_ids)?
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = bf_repo.get("flux1-schnell.sft")?;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::model::Config::schnell();
let model = flux::model::Flux::new(&cfg, vb)?;
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
println!("{state:?}");
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
&timesteps,
4.,
)?
};
flux::sampling::unpack(&img, height, width)?
}
Some(file) => {
let mut st = candle::safetensors::load(file, &device)?;
st.remove("img").unwrap().to_dtype(dtype)?
}
};
println!("latent img\n{img}");
let img = {
let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::autoencoder::Config::schnell();
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)?
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
run(args)
}
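
The `--decode-only` path above loads a safetensors file and reads a tensor
named `img` from it, in the unpacked `(b, 16, h/8, w/8)` layout that the
autoencoder's `decode` expects. A minimal sketch of producing such a file,
assuming candle's `Tensor::save_safetensors` helper and a hypothetical
`latent` tensor from an earlier run:

```rust
// Save an unpacked latent under the "img" key so it can later be re-decoded
// via `--decode-only latent.safetensors` (the file name is just an example).
fn save_latent(latent: &candle::Tensor) -> candle::Result<()> {
    latent.save_safetensors("img", "latent.safetensors")
}
```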

@@ -0,0 +1,440 @@
use candle::{Result, Tensor, D};
use candle_nn::{conv2d, group_norm, Conv2d, GroupNorm, VarBuilder};
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/modules/autoencoder.py#L9
#[derive(Debug, Clone)]
pub struct Config {
pub resolution: usize,
pub in_channels: usize,
pub ch: usize,
pub out_ch: usize,
pub ch_mult: Vec<usize>,
pub num_res_blocks: usize,
pub z_channels: usize,
pub scale_factor: f64,
pub shift_factor: f64,
}
impl Config {
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L47
pub fn dev() -> Self {
Self {
resolution: 256,
in_channels: 3,
ch: 128,
out_ch: 3,
ch_mult: vec![1, 2, 4, 4],
num_res_blocks: 2,
z_channels: 16,
scale_factor: 0.3611,
shift_factor: 0.1159,
}
}
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L79
pub fn schnell() -> Self {
Self {
resolution: 256,
in_channels: 3,
ch: 128,
out_ch: 3,
ch_mult: vec![1, 2, 4, 4],
num_res_blocks: 2,
z_channels: 16,
scale_factor: 0.3611,
shift_factor: 0.1159,
}
}
}
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let dim = q.dim(D::Minus1)?;
let scale_factor = 1.0 / (dim as f64).sqrt();
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}
#[derive(Debug, Clone)]
struct AttnBlock {
q: Conv2d,
k: Conv2d,
v: Conv2d,
proj_out: Conv2d,
norm: GroupNorm,
}
impl AttnBlock {
fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
let norm = group_norm(32, in_c, 1e-6, vb.pp("norm"))?;
Ok(Self {
q,
k,
v,
proj_out,
norm,
})
}
}
impl candle::Module for AttnBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let init_xs = xs;
let xs = xs.apply(&self.norm)?;
let q = xs.apply(&self.q)?;
let k = xs.apply(&self.k)?;
let v = xs.apply(&self.v)?;
let (b, c, h, w) = q.dims4()?;
let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
let xs = scaled_dot_product_attention(&q, &k, &v)?;
let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
xs.apply(&self.proj_out)? + init_xs
}
}
#[derive(Debug, Clone)]
struct ResnetBlock {
norm1: GroupNorm,
conv1: Conv2d,
norm2: GroupNorm,
conv2: Conv2d,
nin_shortcut: Option<Conv2d>,
}
impl ResnetBlock {
fn new(in_c: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
let conv_cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let norm1 = group_norm(32, in_c, 1e-6, vb.pp("norm1"))?;
let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
let norm2 = group_norm(32, out_c, 1e-6, vb.pp("norm2"))?;
let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
let nin_shortcut = if in_c == out_c {
None
} else {
Some(conv2d(
in_c,
out_c,
1,
Default::default(),
vb.pp("nin_shortcut"),
)?)
};
Ok(Self {
norm1,
conv1,
norm2,
conv2,
nin_shortcut,
})
}
}
impl candle::Module for ResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let h = xs
.apply(&self.norm1)?
.apply(&candle_nn::Activation::Swish)?
.apply(&self.conv1)?
.apply(&self.norm2)?
.apply(&candle_nn::Activation::Swish)?
.apply(&self.conv2)?;
match self.nin_shortcut.as_ref() {
None => xs + h,
Some(c) => xs.apply(c)? + h,
}
}
}
#[derive(Debug, Clone)]
struct Downsample {
conv: Conv2d,
}
impl Downsample {
fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
let conv_cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
};
let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
Ok(Self { conv })
}
}
impl candle::Module for Downsample {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
xs.apply(&self.conv)
}
}
#[derive(Debug, Clone)]
struct Upsample {
conv: Conv2d,
}
impl Upsample {
fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
let conv_cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
Ok(Self { conv })
}
}
impl candle::Module for Upsample {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_, _, h, w) = xs.dims4()?;
xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
}
}
#[derive(Debug, Clone)]
struct DownBlock {
block: Vec<ResnetBlock>,
downsample: Option<Downsample>,
}
#[derive(Debug, Clone)]
pub struct Encoder {
conv_in: Conv2d,
mid_block_1: ResnetBlock,
mid_attn_1: AttnBlock,
mid_block_2: ResnetBlock,
norm_out: GroupNorm,
conv_out: Conv2d,
down: Vec<DownBlock>,
}
impl Encoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let conv_cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let mut block_in = cfg.ch;
let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
let mut down = Vec::with_capacity(cfg.ch_mult.len());
let vb_d = vb.pp("down");
for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
let mut block = Vec::with_capacity(cfg.num_res_blocks);
let vb_d = vb_d.pp(i_level);
let vb_b = vb_d.pp("block");
let in_ch_mult = if i_level == 0 {
1
} else {
cfg.ch_mult[i_level - 1]
};
block_in = cfg.ch * in_ch_mult;
let block_out = cfg.ch * ch_mult;
for i_block in 0..cfg.num_res_blocks {
let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
block.push(b);
block_in = block_out;
}
let downsample = if i_level != cfg.ch_mult.len() - 1 {
Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
} else {
None
};
let block = DownBlock { block, downsample };
down.push(block)
}
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
let conv_out = conv2d(block_in, 2 * cfg.z_channels, 3, conv_cfg, vb.pp("conv_out"))?;
let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
Ok(Self {
conv_in,
mid_block_1,
mid_attn_1,
mid_block_2,
norm_out,
conv_out,
down,
})
}
}
impl candle_nn::Module for Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut h = xs.apply(&self.conv_in)?;
for block in self.down.iter() {
for b in block.block.iter() {
h = h.apply(b)?
}
if let Some(ds) = block.downsample.as_ref() {
h = h.apply(ds)?
}
}
h.apply(&self.mid_block_1)?
.apply(&self.mid_attn_1)?
.apply(&self.mid_block_2)?
.apply(&self.norm_out)?
.apply(&candle_nn::Activation::Swish)?
.apply(&self.conv_out)
}
}
#[derive(Debug, Clone)]
struct UpBlock {
block: Vec<ResnetBlock>,
upsample: Option<Upsample>,
}
#[derive(Debug, Clone)]
pub struct Decoder {
conv_in: Conv2d,
mid_block_1: ResnetBlock,
mid_attn_1: AttnBlock,
mid_block_2: ResnetBlock,
norm_out: GroupNorm,
conv_out: Conv2d,
up: Vec<UpBlock>,
}
impl Decoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let conv_cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let mut block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
let mut up = Vec::with_capacity(cfg.ch_mult.len());
let vb_u = vb.pp("up");
for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate().rev() {
let block_out = cfg.ch * ch_mult;
let vb_u = vb_u.pp(i_level);
let vb_b = vb_u.pp("block");
let mut block = Vec::with_capacity(cfg.num_res_blocks + 1);
for i_block in 0..=cfg.num_res_blocks {
let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
block.push(b);
block_in = block_out;
}
let upsample = if i_level != 0 {
Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
} else {
None
};
let block = UpBlock { block, upsample };
up.push(block)
}
up.reverse();
let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
Ok(Self {
conv_in,
mid_block_1,
mid_attn_1,
mid_block_2,
norm_out,
conv_out,
up,
})
}
}
impl candle_nn::Module for Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let h = xs.apply(&self.conv_in)?;
let mut h = h
.apply(&self.mid_block_1)?
.apply(&self.mid_attn_1)?
.apply(&self.mid_block_2)?;
for block in self.up.iter().rev() {
for b in block.block.iter() {
h = h.apply(b)?
}
if let Some(us) = block.upsample.as_ref() {
h = h.apply(us)?
}
}
h.apply(&self.norm_out)?
.apply(&candle_nn::Activation::Swish)?
.apply(&self.conv_out)
}
}
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
sample: bool,
chunk_dim: usize,
}
impl DiagonalGaussian {
pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
Ok(Self { sample, chunk_dim })
}
}
impl candle_nn::Module for DiagonalGaussian {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let chunks = xs.chunk(2, self.chunk_dim)?;
if self.sample {
let std = (&chunks[1] * 0.5)?.exp()?;
&chunks[0] + (std * chunks[0].randn_like(0., 1.))?
} else {
Ok(chunks[0].clone())
}
}
}
#[derive(Debug, Clone)]
pub struct AutoEncoder {
encoder: Encoder,
decoder: Decoder,
reg: DiagonalGaussian,
shift_factor: f64,
scale_factor: f64,
}
impl AutoEncoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
let reg = DiagonalGaussian::new(true, 1)?;
Ok(Self {
encoder,
decoder,
reg,
scale_factor: cfg.scale_factor,
shift_factor: cfg.shift_factor,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
(z - self.shift_factor)? * self.scale_factor
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
xs.apply(&self.decoder)
}
}
impl candle::Module for AutoEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.decode(&self.encode(xs)?)
}
}
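
Taken together, the autoencoder maps `(b, 3, h, w)` images to `(b, 16, h/8, w/8)`
latents: the encoder's `conv_out` emits `2 * z_channels` channels for the
diagonal Gaussian's mean and log-variance, and the shift/scale factors are
applied inside `encode`/`decode`. A minimal round-trip sketch, assuming `vb`
was built from the `ae.sft` weights as in the example above (note that
`encode` samples from the Gaussian, so the round trip is not deterministic):

```rust
use candle::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::flux::autoencoder::{AutoEncoder, Config};

// Encode a dummy image into a 16-channel latent, then decode it back.
fn roundtrip(vb: VarBuilder) -> candle::Result<Tensor> {
    let (dev, dtype) = (vb.device().clone(), vb.dtype());
    let ae = AutoEncoder::new(&Config::schnell(), vb)?;
    let img = Tensor::zeros((1, 3, 256, 256), dtype, &dev)?;
    let latent = ae.encode(&img)?; // (1, 16, 32, 32), shift/scale applied inside
    ae.decode(&latent) // back to (1, 3, 256, 256) pixel space
}
```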

@@ -0,0 +1,3 @@
pub mod autoencoder;
pub mod model;
pub mod sampling;

@@ -0,0 +1,582 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear, RmsNorm, VarBuilder};
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/model.py#L12
#[derive(Debug, Clone)]
pub struct Config {
pub in_channels: usize,
pub vec_in_dim: usize,
pub context_in_dim: usize,
pub hidden_size: usize,
pub mlp_ratio: f64,
pub num_heads: usize,
pub depth: usize,
pub depth_single_blocks: usize,
pub axes_dim: Vec<usize>,
pub theta: usize,
pub qkv_bias: bool,
pub guidance_embed: bool,
}
impl Config {
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L32
pub fn dev() -> Self {
Self {
in_channels: 64,
vec_in_dim: 768,
context_in_dim: 4096,
hidden_size: 3072,
mlp_ratio: 4.0,
num_heads: 24,
depth: 19,
depth_single_blocks: 38,
axes_dim: vec![16, 56, 56],
theta: 10_000,
qkv_bias: true,
guidance_embed: true,
}
}
// https://github.com/black-forest-labs/flux/blob/727e3a71faf37390f318cf9434f0939653302b60/src/flux/util.py#L64
pub fn schnell() -> Self {
Self {
in_channels: 64,
vec_in_dim: 768,
context_in_dim: 4096,
hidden_size: 3072,
mlp_ratio: 4.0,
num_heads: 24,
depth: 19,
depth_single_blocks: 38,
axes_dim: vec![16, 56, 56],
theta: 10_000,
qkv_bias: true,
guidance_embed: false,
}
}
}
fn layer_norm(dim: usize, vb: VarBuilder) -> Result<LayerNorm> {
let ws = Tensor::ones(dim, vb.dtype(), vb.device())?;
Ok(LayerNorm::new_no_bias(ws, 1e-6))
}
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let dim = q.dim(D::Minus1)?;
let scale_factor = 1.0 / (dim as f64).sqrt();
let mut batch_dims = q.dims().to_vec();
batch_dims.pop();
batch_dims.pop();
let q = q.flatten_to(batch_dims.len() - 1)?;
let k = k.flatten_to(batch_dims.len() - 1)?;
let v = v.flatten_to(batch_dims.len() - 1)?;
let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
let attn_scores = candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v)?;
batch_dims.push(attn_scores.dim(D::Minus2)?);
batch_dims.push(attn_scores.dim(D::Minus1)?);
attn_scores.reshape(batch_dims)
}
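// RoPE as used by flux: for each (position, frequency) pair this builds the
// 2x2 rotation matrix [[cos, -sin], [sin, cos]], returned with shape
// (b, n_pos, dim / 2, 2, 2) so that `apply_rope` below can rotate pairs of
// feature channels.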
fn rope(pos: &Tensor, dim: usize, theta: usize) -> Result<Tensor> {
if dim % 2 == 1 {
candle::bail!("dim {dim} is odd")
}
let dev = pos.device();
let theta = theta as f64;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, 1, inv_freq_len), dev)?;
let inv_freq = inv_freq.to_dtype(pos.dtype())?;
let freqs = pos.unsqueeze(2)?.broadcast_mul(&inv_freq)?;
let cos = freqs.cos()?;
let sin = freqs.sin()?;
let out = Tensor::stack(&[&cos, &sin.neg()?, &sin, &cos], 3)?;
let (b, n, d, _ij) = out.dims4()?;
out.reshape((b, n, d, 2, 2))
}
fn apply_rope(x: &Tensor, freq_cis: &Tensor) -> Result<Tensor> {
let dims = x.dims();
let (b_sz, n_head, seq_len, n_embd) = x.dims4()?;
let x = x.reshape((b_sz, n_head, seq_len, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let fr0 = freq_cis.get_on_dim(D::Minus1, 0)?;
let fr1 = freq_cis.get_on_dim(D::Minus1, 1)?;
(fr0.broadcast_mul(&x0)? + fr1.broadcast_mul(&x1)?)?.reshape(dims.to_vec())
}
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, pe: &Tensor) -> Result<Tensor> {
let q = apply_rope(q, pe)?.contiguous()?;
let k = apply_rope(k, pe)?.contiguous()?;
let x = scaled_dot_product_attention(&q, &k, v)?;
x.transpose(1, 2)?.flatten_from(2)
}
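// Sinusoidal timestep embedding, following the reference implementation: the
// timestep is first scaled by 1000, then projected onto dim / 2 frequencies
// log-spaced between 1 and 1/10000, and the cos and sin parts concatenated.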
fn timestep_embedding(t: &Tensor, dim: usize, dtype: DType) -> Result<Tensor> {
const TIME_FACTOR: f64 = 1000.;
const MAX_PERIOD: f64 = 10000.;
if dim % 2 == 1 {
candle::bail!("{dim} is odd")
}
let dev = t.device();
let half = dim / 2;
let t = (t * TIME_FACTOR)?;
let arange = Tensor::arange(0, half as u32, dev)?.to_dtype(candle::DType::F32)?;
let freqs = (arange * (-MAX_PERIOD.ln() / half as f64))?.exp()?;
let args = t
.unsqueeze(1)?
.to_dtype(candle::DType::F32)?
.broadcast_mul(&freqs.unsqueeze(0)?)?;
let emb = Tensor::cat(&[args.cos()?, args.sin()?], D::Minus1)?.to_dtype(dtype)?;
Ok(emb)
}
#[derive(Debug, Clone)]
pub struct EmbedNd {
#[allow(unused)]
dim: usize,
theta: usize,
axes_dim: Vec<usize>,
}
impl EmbedNd {
fn new(dim: usize, theta: usize, axes_dim: Vec<usize>) -> Self {
Self {
dim,
theta,
axes_dim,
}
}
}
impl candle::Module for EmbedNd {
fn forward(&self, ids: &Tensor) -> Result<Tensor> {
let n_axes = ids.dim(D::Minus1)?;
let mut emb = Vec::with_capacity(n_axes);
for idx in 0..n_axes {
let r = rope(
&ids.get_on_dim(D::Minus1, idx)?,
self.axes_dim[idx],
self.theta,
)?;
emb.push(r)
}
let emb = Tensor::cat(&emb, 2)?;
emb.unsqueeze(1)
}
}
#[derive(Debug, Clone)]
pub struct MlpEmbedder {
in_layer: Linear,
out_layer: Linear,
}
impl MlpEmbedder {
fn new(in_sz: usize, h_sz: usize, vb: VarBuilder) -> Result<Self> {
let in_layer = candle_nn::linear(in_sz, h_sz, vb.pp("in_layer"))?;
let out_layer = candle_nn::linear(h_sz, h_sz, vb.pp("out_layer"))?;
Ok(Self {
in_layer,
out_layer,
})
}
}
impl candle::Module for MlpEmbedder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.in_layer)?.silu()?.apply(&self.out_layer)
}
}
#[derive(Debug, Clone)]
pub struct QkNorm {
query_norm: RmsNorm,
key_norm: RmsNorm,
}
impl QkNorm {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let query_norm = vb.get(dim, "query_norm.scale")?;
let query_norm = RmsNorm::new(query_norm, 1e-6);
let key_norm = vb.get(dim, "key_norm.scale")?;
let key_norm = RmsNorm::new(key_norm, 1e-6);
Ok(Self {
query_norm,
key_norm,
})
}
}
#[derive(Debug, Clone)]
pub struct Modulation {
lin: Linear,
multiplier: usize,
}
impl Modulation {
fn new(dim: usize, double: bool, vb: VarBuilder) -> Result<Self> {
let multiplier = if double { 6 } else { 3 };
let lin = candle_nn::linear(dim, multiplier * dim, vb.pp("lin"))?;
Ok(Self { lin, multiplier })
}
fn forward(&self, vec_: &Tensor) -> Result<Vec<Tensor>> {
vec_.silu()?
.apply(&self.lin)?
.unsqueeze(1)?
.chunk(self.multiplier, D::Minus1)
}
}
#[derive(Debug, Clone)]
pub struct SelfAttention {
qkv: Linear,
norm: QkNorm,
proj: Linear,
num_heads: usize,
}
impl SelfAttention {
fn new(dim: usize, num_heads: usize, qkv_bias: bool, vb: VarBuilder) -> Result<Self> {
let head_dim = dim / num_heads;
let qkv = candle_nn::linear_b(dim, dim * 3, qkv_bias, vb.pp("qkv"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?;
Ok(Self {
qkv,
norm,
proj,
num_heads,
})
}
fn qkv(&self, xs: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
let qkv = xs.apply(&self.qkv)?;
let (b, l, _khd) = qkv.dims3()?;
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
let q = q.apply(&self.norm.query_norm)?;
let k = k.apply(&self.norm.key_norm)?;
Ok((q, k, v))
}
#[allow(unused)]
fn forward(&self, xs: &Tensor, pe: &Tensor) -> Result<Tensor> {
let (q, k, v) = self.qkv(xs)?;
attention(&q, &k, &v, pe)?.apply(&self.proj)
}
}
#[derive(Debug, Clone)]
struct Mlp {
lin1: Linear,
lin2: Linear,
}
impl Mlp {
fn new(in_sz: usize, mlp_sz: usize, vb: VarBuilder) -> Result<Self> {
let lin1 = candle_nn::linear(in_sz, mlp_sz, vb.pp("0"))?;
let lin2 = candle_nn::linear(mlp_sz, in_sz, vb.pp("2"))?;
Ok(Self { lin1, lin2 })
}
}
impl candle::Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2)
}
}
#[derive(Debug, Clone)]
pub struct DoubleStreamBlock {
img_mod: Modulation,
img_norm1: LayerNorm,
img_attn: SelfAttention,
img_norm2: LayerNorm,
img_mlp: Mlp,
txt_mod: Modulation,
txt_norm1: LayerNorm,
txt_attn: SelfAttention,
txt_norm2: LayerNorm,
txt_mlp: Mlp,
}
impl DoubleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let img_mod = Modulation::new(h_sz, true, vb.pp("img_mod"))?;
let img_norm1 = layer_norm(h_sz, vb.pp("img_norm1"))?;
let img_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("img_attn"))?;
let img_norm2 = layer_norm(h_sz, vb.pp("img_norm2"))?;
let img_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("img_mlp"))?;
let txt_mod = Modulation::new(h_sz, true, vb.pp("txt_mod"))?;
let txt_norm1 = layer_norm(h_sz, vb.pp("txt_norm1"))?;
let txt_attn = SelfAttention::new(h_sz, cfg.num_heads, cfg.qkv_bias, vb.pp("txt_attn"))?;
let txt_norm2 = layer_norm(h_sz, vb.pp("txt_norm2"))?;
let txt_mlp = Mlp::new(h_sz, mlp_sz, vb.pp("txt_mlp"))?;
Ok(Self {
img_mod,
img_norm1,
img_attn,
img_norm2,
img_mlp,
txt_mod,
txt_norm1,
txt_attn,
txt_norm2,
txt_mlp,
})
}
fn forward(
&self,
img: &Tensor,
txt: &Tensor,
vec_: &Tensor,
pe: &Tensor,
) -> Result<(Tensor, Tensor)> {
let img_mod = self.img_mod.forward(vec_)?; // shift, scale, gate
let txt_mod = self.txt_mod.forward(vec_)?; // shift, scale, gate
let img_modulated = img.apply(&self.img_norm1)?;
let img_modulated = img_modulated
.broadcast_mul(&(&img_mod[1] + 1.)?)?
.broadcast_add(&img_mod[0])?;
let (img_q, img_k, img_v) = self.img_attn.qkv(&img_modulated)?;
let txt_modulated = txt.apply(&self.txt_norm1)?;
let txt_modulated = txt_modulated
.broadcast_mul(&(&txt_mod[1] + 1.)?)?
.broadcast_add(&txt_mod[0])?;
let (txt_q, txt_k, txt_v) = self.txt_attn.qkv(&txt_modulated)?;
let q = Tensor::cat(&[txt_q, img_q], 2)?;
let k = Tensor::cat(&[txt_k, img_k], 2)?;
let v = Tensor::cat(&[txt_v, img_v], 2)?;
let attn = attention(&q, &k, &v, pe)?;
let txt_attn = attn.narrow(1, 0, txt.dim(1)?)?;
let img_attn = attn.narrow(1, txt.dim(1)?, attn.dim(1)? - txt.dim(1)?)?;
let img = (img
+ img_attn
.apply(&self.img_attn.proj)?
.broadcast_mul(&img_mod[2]))?;
let img = (&img
+ &img_mod[5].broadcast_mul(
&img.apply(&self.img_norm2)?
.broadcast_mul(&(&img_mod[4] + 1.0)?)?
.broadcast_add(&img_mod[3])?
.apply(&self.img_mlp)?,
)?)?;
let txt = (txt
+ txt_attn
.apply(&self.txt_attn.proj)?
.broadcast_mul(&txt_mod[2]))?;
let txt = (&txt
+ &txt_mod[5].broadcast_mul(
&txt.apply(&self.txt_norm2)?
.broadcast_mul(&(&txt_mod[4] + 1.0)?)?
.broadcast_add(&txt_mod[3])?
.apply(&self.txt_mlp)?,
)?)?;
Ok((img, txt))
}
}
#[derive(Debug, Clone)]
pub struct SingleStreamBlock {
linear1: Linear,
linear2: Linear,
norm: QkNorm,
pre_norm: LayerNorm,
modulation: Modulation,
h_sz: usize,
mlp_sz: usize,
num_heads: usize,
}
impl SingleStreamBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h_sz = cfg.hidden_size;
let mlp_sz = (h_sz as f64 * cfg.mlp_ratio) as usize;
let head_dim = h_sz / cfg.num_heads;
let linear1 = candle_nn::linear(h_sz, h_sz * 3 + mlp_sz, vb.pp("linear1"))?;
let linear2 = candle_nn::linear(h_sz + mlp_sz, h_sz, vb.pp("linear2"))?;
let norm = QkNorm::new(head_dim, vb.pp("norm"))?;
let pre_norm = layer_norm(h_sz, vb.pp("pre_norm"))?;
let modulation = Modulation::new(h_sz, false, vb.pp("modulation"))?;
Ok(Self {
linear1,
linear2,
norm,
pre_norm,
modulation,
h_sz,
mlp_sz,
num_heads: cfg.num_heads,
})
}
fn forward(&self, xs: &Tensor, vec_: &Tensor, pe: &Tensor) -> Result<Tensor> {
let mod_ = self.modulation.forward(vec_)?;
let (shift, scale, gate) = (&mod_[0], &mod_[1], &mod_[2]);
let x_mod = xs
.apply(&self.pre_norm)?
.broadcast_mul(&(scale + 1.0)?)?
.broadcast_add(shift)?;
let x_mod = x_mod.apply(&self.linear1)?;
let qkv = x_mod.narrow(D::Minus1, 0, 3 * self.h_sz)?;
let (b, l, _khd) = qkv.dims3()?;
let qkv = qkv.reshape((b, l, 3, self.num_heads, ()))?;
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
let mlp = x_mod.narrow(D::Minus1, 3 * self.h_sz, self.mlp_sz)?;
let q = q.apply(&self.norm.query_norm)?;
let k = k.apply(&self.norm.key_norm)?;
let attn = attention(&q, &k, &v, pe)?;
let output = Tensor::cat(&[attn, mlp.gelu()?], 2)?.apply(&self.linear2)?;
xs + gate.broadcast_mul(&output)
}
}
#[derive(Debug, Clone)]
pub struct LastLayer {
norm_final: LayerNorm,
linear: Linear,
ada_ln_modulation: Linear,
}
impl LastLayer {
fn new(h_sz: usize, p_sz: usize, out_c: usize, vb: VarBuilder) -> Result<Self> {
let norm_final = layer_norm(h_sz, vb.pp("norm_final"))?;
let linear = candle_nn::linear(h_sz, p_sz * p_sz * out_c, vb.pp("linear"))?;
let ada_ln_modulation = candle_nn::linear(h_sz, 2 * h_sz, vb.pp("adaLN_modulation.1"))?;
Ok(Self {
norm_final,
linear,
ada_ln_modulation,
})
}
fn forward(&self, xs: &Tensor, vec: &Tensor) -> Result<Tensor> {
let chunks = vec.silu()?.apply(&self.ada_ln_modulation)?.chunk(2, 1)?;
let (shift, scale) = (&chunks[0], &chunks[1]);
let xs = xs
.apply(&self.norm_final)?
.broadcast_mul(&(scale.unsqueeze(1)? + 1.0)?)?
.broadcast_add(&shift.unsqueeze(1)?)?;
xs.apply(&self.linear)
}
}
#[derive(Debug, Clone)]
pub struct Flux {
img_in: Linear,
txt_in: Linear,
time_in: MlpEmbedder,
vector_in: MlpEmbedder,
guidance_in: Option<MlpEmbedder>,
pe_embedder: EmbedNd,
double_blocks: Vec<DoubleStreamBlock>,
single_blocks: Vec<SingleStreamBlock>,
final_layer: LastLayer,
}
impl Flux {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let img_in = candle_nn::linear(cfg.in_channels, cfg.hidden_size, vb.pp("img_in"))?;
let txt_in = candle_nn::linear(cfg.context_in_dim, cfg.hidden_size, vb.pp("txt_in"))?;
let mut double_blocks = Vec::with_capacity(cfg.depth);
let vb_d = vb.pp("double_blocks");
for idx in 0..cfg.depth {
let db = DoubleStreamBlock::new(cfg, vb_d.pp(idx))?;
double_blocks.push(db)
}
let mut single_blocks = Vec::with_capacity(cfg.depth_single_blocks);
let vb_s = vb.pp("single_blocks");
for idx in 0..cfg.depth_single_blocks {
let sb = SingleStreamBlock::new(cfg, vb_s.pp(idx))?;
single_blocks.push(sb)
}
let time_in = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("time_in"))?;
let vector_in = MlpEmbedder::new(cfg.vec_in_dim, cfg.hidden_size, vb.pp("vector_in"))?;
let guidance_in = if cfg.guidance_embed {
let mlp = MlpEmbedder::new(256, cfg.hidden_size, vb.pp("guidance_in"))?;
Some(mlp)
} else {
None
};
let final_layer =
LastLayer::new(cfg.hidden_size, 1, cfg.in_channels, vb.pp("final_layer"))?;
let pe_dim = cfg.hidden_size / cfg.num_heads;
let pe_embedder = EmbedNd::new(pe_dim, cfg.theta, cfg.axes_dim.to_vec());
Ok(Self {
img_in,
txt_in,
time_in,
vector_in,
guidance_in,
pe_embedder,
double_blocks,
single_blocks,
final_layer,
})
}
#[allow(clippy::too_many_arguments)]
pub fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor> {
if txt.rank() != 3 {
candle::bail!("unexpected shape for txt {:?}", txt.shape())
}
if img.rank() != 3 {
candle::bail!("unexpected shape for img {:?}", img.shape())
}
let dtype = img.dtype();
let pe = {
let ids = Tensor::cat(&[txt_ids, img_ids], 1)?;
ids.apply(&self.pe_embedder)?
};
let mut txt = txt.apply(&self.txt_in)?;
let mut img = img.apply(&self.img_in)?;
let vec_ = timestep_embedding(timesteps, 256, dtype)?.apply(&self.time_in)?;
let vec_ = match (self.guidance_in.as_ref(), guidance) {
(Some(g_in), Some(guidance)) => {
(vec_ + timestep_embedding(guidance, 256, dtype)?.apply(g_in))?
}
_ => vec_,
};
let vec_ = (vec_ + y.apply(&self.vector_in))?;
// Double blocks
for block in self.double_blocks.iter() {
(img, txt) = block.forward(&img, &txt, &vec_, &pe)?
}
// Single blocks
let mut img = Tensor::cat(&[&txt, &img], 1)?;
for block in self.single_blocks.iter() {
img = block.forward(&img, &vec_, &pe)?;
}
let img = img.i((.., txt.dim(1)?..))?;
self.final_layer.forward(&img, &vec_)
}
}
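
For reference, the tensor shapes `Flux::forward` expects with the schnell
config, as constructed by `flux::sampling::State` in the next file (for a
1024x1024 image the latent grid is 128x128, so the image sequence length is
64 * 64 = 4096):

```rust
// img:       (b, (h/2) * (w/2), 64)  - 2x2-patchified 16-channel latent
// img_ids:   (b, (h/2) * (w/2), 3)   - (0, row, column) position ids
// txt:       (b, 256, 4096)          - padded T5-XXL embeddings
// txt_ids:   (b, 256, 3)             - all zeros
// timesteps: (b,)
// y:         (b, 768)                - pooled CLIP text embedding
// guidance:  ignored for schnell, which has guidance_embed: false
```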

@@ -0,0 +1,119 @@
use candle::{Device, Result, Tensor};
pub fn get_noise(
num_samples: usize,
height: usize,
width: usize,
device: &Device,
) -> Result<Tensor> {
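    // Latents have 16 channels at 1/8th of the pixel resolution: dimensions
    // are rounded up to a multiple of 16 pixels so that the 2x2
    // patchification in `State::new` yields an integer grid.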
let height = (height + 15) / 16 * 2;
let width = (width + 15) / 16 * 2;
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
}
#[derive(Debug, Clone)]
pub struct State {
pub img: Tensor,
pub img_ids: Tensor,
pub txt: Tensor,
pub txt_ids: Tensor,
pub vec: Tensor,
}
impl State {
pub fn new(t5_emb: &Tensor, clip_emb: &Tensor, img: &Tensor) -> Result<Self> {
let dtype = img.dtype();
let (bs, c, h, w) = img.dims4()?;
let dev = img.device();
let img = img.reshape((bs, c, h / 2, 2, w / 2, 2))?; // (b, c, h, ph, w, pw)
let img = img.permute((0, 2, 4, 1, 3, 5))?; // (b, h, w, c, ph, pw)
let img = img.reshape((bs, h / 2 * w / 2, c * 4))?;
let img_ids = Tensor::stack(
&[
Tensor::full(0u32, (h / 2, w / 2), dev)?,
Tensor::arange(0u32, h as u32 / 2, dev)?
.reshape(((), 1))?
.broadcast_as((h / 2, w / 2))?,
Tensor::arange(0u32, w as u32 / 2, dev)?
.reshape((1, ()))?
.broadcast_as((h / 2, w / 2))?,
],
2,
)?
.to_dtype(dtype)?;
let img_ids = img_ids.reshape((1, h / 2 * w / 2, 3))?;
let img_ids = img_ids.repeat((bs, 1, 1))?;
let txt = t5_emb.repeat(bs)?;
let txt_ids = Tensor::zeros((bs, txt.dim(1)?, 3), dtype, dev)?;
let vec = clip_emb.repeat(bs)?;
Ok(Self {
img,
img_ids,
txt,
txt_ids,
vec,
})
}
}
fn time_shift(mu: f64, sigma: f64, t: f64) -> f64 {
let e = mu.exp();
e / (e + (1. / t - 1.).powf(sigma))
}
/// `shift` is a triple `(image_seq_len, base_shift, max_shift)`.
pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f64> {
let timesteps: Vec<f64> = (0..=num_steps)
.map(|v| v as f64 / num_steps as f64)
.rev()
.collect();
match shift {
None => timesteps,
Some((image_seq_len, y1, y2)) => {
let (x1, x2) = (256., 4096.);
let m = (y2 - y1) / (x2 - x1);
let b = y1 - m * x1;
let mu = m * image_seq_len as f64 + b;
timesteps
.into_iter()
.map(|v| time_shift(mu, 1., v))
.collect()
}
}
}
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
let (b, _h_w, c_ph_pw) = xs.dims3()?;
let height = (height + 15) / 16;
let width = (width + 15) / 16;
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
}
#[allow(clippy::too_many_arguments)]
pub fn denoise(
model: &super::model::Flux,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
vec_: &Tensor,
timesteps: &[f64],
guidance: f64,
) -> Result<Tensor> {
let b_sz = img.dim(0)?;
let dev = img.device();
let guidance = Tensor::full(guidance as f32, b_sz, dev)?;
let mut img = img.clone();
for window in timesteps.windows(2) {
let (t_curr, t_prev) = match window {
[a, b] => (a, b),
_ => continue,
};
let t_vec = Tensor::full(*t_curr as f32, b_sz, dev)?;
let pred = model.forward(&img, img_ids, txt, txt_ids, &t_vec, vec_, Some(&guidance))?;
img = (img + pred * (t_prev - t_curr))?
}
Ok(img)
}
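
`denoise` is a plain Euler integrator over the learned velocity field: each
step moves the latent by `pred * (t_prev - t_curr)`. A small sketch of the
schedule used by the example (4 steps, no shift, so the steps are uniform):

```rust
// With num_steps = 4 and shift = None, get_schedule is uniform on [0, 1].
fn schnell_schedule() {
    use candle_transformers::models::flux::sampling::get_schedule;
    let ts = get_schedule(4, None);
    assert_eq!(ts, vec![1.0, 0.75, 0.5, 0.25, 0.0]);
    // `denoise` then takes four Euler steps of size -0.25 each, pairing
    // (1.0, 0.75), (0.75, 0.5), (0.5, 0.25), (0.25, 0.0).
}
```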

@@ -17,6 +17,7 @@ pub mod efficientvit;
pub mod encodec;
pub mod eva2;
pub mod falcon;
pub mod flux;
pub mod gemma;
pub mod hiera;
pub mod jina_bert;