mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Add Stable Diffusion 3 Example (#2558)
* Add stable diffusion 3 example Add get_qkv_linear to handle different dimensionality in linears Add stable diffusion 3 example Add use_quant_conv and use_post_quant_conv for vae in stable diffusion adapt existing AutoEncoderKLConfig to the change add forward_until_encoder_layer to ClipTextTransformer rename sd3 config to sd3_medium in mmdit; minor clean-up Enable flash-attn for mmdit impl when the feature is enabled. Add sd3 example codebase add document crediting references pass the cargo fmt test pass the clippy test * fix typos * expose cfg_scale and time_shift as options * Replace the sample image with JPG version. Change image output format accordingly. * make meaningful error messages * remove the tail-end assignment in sd3_vae_vb_rename * remove the CUDA requirement * use default_value in clap args * add use_flash_attn to turn on/off flash-attn for MMDiT at runtime * resolve clippy errors and warnings * use default_value_t * Pin the web-sys dependency. * Clippy fix. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -467,6 +467,24 @@ pub struct AttentionBlock {
|
||||
config: AttentionBlockConfig,
|
||||
}
|
||||
|
||||
// In the .safetensor weights of official Stable Diffusion 3 Medium Huggingface repo
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-3-medium
|
||||
// Linear layer may use a different dimension for the weight in the linear, which is
|
||||
// incompatible with the current implementation of the nn::linear constructor.
|
||||
// This is a workaround to handle the different dimensions.
|
||||
fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result<nn::Linear> {
|
||||
match vs.get((channels, channels), "weight") {
|
||||
Ok(_) => nn::linear(channels, channels, vs),
|
||||
Err(_) => {
|
||||
let weight = vs
|
||||
.get((channels, channels, 1, 1), "weight")?
|
||||
.reshape((channels, channels))?;
|
||||
let bias = vs.get((channels,), "bias")?;
|
||||
Ok(nn::Linear::new(weight, Some(bias)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AttentionBlock {
|
||||
pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result<Self> {
|
||||
let num_head_channels = config.num_head_channels.unwrap_or(channels);
|
||||
@ -478,10 +496,10 @@ impl AttentionBlock {
|
||||
} else {
|
||||
("query", "key", "value", "proj_attn")
|
||||
};
|
||||
let query = nn::linear(channels, channels, vs.pp(q_path))?;
|
||||
let key = nn::linear(channels, channels, vs.pp(k_path))?;
|
||||
let value = nn::linear(channels, channels, vs.pp(v_path))?;
|
||||
let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?;
|
||||
let query = get_qkv_linear(channels, vs.pp(q_path))?;
|
||||
let key = get_qkv_linear(channels, vs.pp(k_path))?;
|
||||
let value = get_qkv_linear(channels, vs.pp(v_path))?;
|
||||
let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "attn-block");
|
||||
Ok(Self {
|
||||
group_norm,
|
||||
|
Reference in New Issue
Block a user