mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00

* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
94 lines
2.8 KiB
Rust
use anyhow::{Ok, Result};
|
|
use candle_transformers::models::stable_diffusion::vae;
|
|
|
|
/// Constructs the `AutoEncoderKL` VAE used by the Stable Diffusion 3 example.
///
/// The configuration is fixed to the SD3 checkpoint layout: four stages with
/// 128/256/512/512 output channels, two layers per block, a 16-channel latent
/// space, 32 group-norm groups, and no quant / post-quant convolutions.
///
/// # Errors
/// Propagates any error from `vae::AutoEncoderKL::new`, e.g. when `vb` does
/// not contain the expected weights.
pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result<vae::AutoEncoderKL> {
    let config = vae::AutoEncoderKLConfig {
        latent_channels: 16,
        norm_num_groups: 32,
        layers_per_block: 2,
        block_out_channels: vec![128, 256, 512, 512],
        use_quant_conv: false,
        use_post_quant_conv: false,
    };
    // 3 input and 3 output channels (presumably RGB images).
    let autoencoder = vae::AutoEncoderKL::new(vb, 3, 3, config)?;
    Ok(autoencoder)
}
|
|
|
|
/// Maps a diffusers-style VAE weight name to the naming scheme expected by
/// `candle_transformers::models::stable_diffusion::vae`.
///
/// The name is split on `.`; structural segments are translated
/// (`down_blocks` -> `down`, `group_norm` -> `norm`, ...), index segments
/// following `up_blocks` / `resnets` / `attentions` / `downsamplers` /
/// `upsamplers` are consumed or rewritten, and every other segment is copied
/// through unchanged.
pub fn sd3_vae_vb_rename(name: &str) -> String {
    let parts: Vec<&str> = name.split('.').collect();
    let mut result = Vec::new();
    let mut i = 0;

    while i < parts.len() {
        // One-segment lookahead; `None` when the current segment is the last.
        // Using `get` instead of `parts[i + 1]` avoids a panic on truncated
        // or malformed names that end on a structural segment.
        let next = parts.get(i + 1).copied();
        match parts[i] {
            "down_blocks" => result.push("down"),
            "mid_block" => result.push("mid"),
            "up_blocks" => {
                result.push("up");
                // Reverse the order of up_blocks: the two naming schemes
                // index the up stages from opposite ends.
                match next {
                    Some("0") => result.push("3"),
                    Some("1") => result.push("2"),
                    Some("2") => result.push("1"),
                    Some("3") => result.push("0"),
                    _ => {}
                }
                i += 1; // Skip the number after up_blocks.
            }
            "resnets" => {
                if i > 0 && parts[i - 1] == "mid_block" {
                    // The mid-block resnets get dedicated names; indices
                    // other than 0/1 have no mapping and are dropped.
                    match next {
                        Some("0") => result.push("block_1"),
                        Some("1") => result.push("block_2"),
                        _ => {}
                    }
                    i += 1; // Skip the number after resnets.
                } else {
                    result.push("block");
                }
            }
            "downsamplers" => {
                result.push("downsample");
                i += 1; // Skip the 0 after downsamplers.
            }
            "conv_shortcut" => result.push("nin_shortcut"),
            "attentions" => {
                // Only index 0 has a mapping; other indices are dropped.
                if next == Some("0") {
                    result.push("attn_1");
                }
                i += 1; // Skip the number after attentions.
            }
            "group_norm" => result.push("norm"),
            "query" => result.push("q"),
            "key" => result.push("k"),
            "value" => result.push("v"),
            "proj_attn" => result.push("proj_out"),
            "conv_norm_out" => result.push("norm_out"),
            "upsamplers" => {
                result.push("upsample");
                i += 1; // Skip the 0 after upsamplers.
            }
            part => result.push(part),
        }
        i += 1;
    }
    result.join(".")
}
|