mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Support sd3.5 medium and MMDiT-X (#2587)
* extract attn out of joint_attn * further adjust attn and joint_attn * add mmdit-x support * support sd3.5-medium in the example * update README.md
This commit is contained in:
@ -19,13 +19,15 @@ enum Which {
|
||||
V3_5Large,
|
||||
#[value(name = "3.5-large-turbo")]
|
||||
V3_5LargeTurbo,
|
||||
#[value(name = "3.5-medium")]
|
||||
V3_5Medium,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn is_3_5(&self) -> bool {
|
||||
match self {
|
||||
Self::V3Medium => false,
|
||||
Self::V3_5Large | Self::V3_5LargeTurbo => true,
|
||||
Self::V3_5Large | Self::V3_5LargeTurbo | Self::V3_5Medium => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -117,36 +119,59 @@ fn main() -> Result<()> {
|
||||
let default_inference_steps = match which {
|
||||
Which::V3_5Large => 28,
|
||||
Which::V3_5LargeTurbo => 4,
|
||||
Which::V3_5Medium => 28,
|
||||
Which::V3Medium => 28,
|
||||
};
|
||||
let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
|
||||
let default_cfg_scale = match which {
|
||||
Which::V3_5Large => 4.0,
|
||||
Which::V3_5LargeTurbo => 1.0,
|
||||
Which::V3_5Medium => 4.0,
|
||||
Which::V3Medium => 4.0,
|
||||
};
|
||||
let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);
|
||||
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let (mmdit_config, mut triple, vb) = if which.is_3_5() {
|
||||
let sai_repo = {
|
||||
let sai_repo_for_text_encoders = {
|
||||
let name = match which {
|
||||
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
|
||||
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
|
||||
// Unfortunately, stabilityai/stable-diffusion-3.5-medium doesn't have the monolithic text encoders that's usually
|
||||
// placed under the text_encoders directory, like the case in stabilityai/stable-diffusion-3.5-large and -large-turbo.
|
||||
// To make things worse, it currently only has partitioned model.fp16-00001-of-00002.safetensors and model.fp16-00002-of-00002.safetensors
|
||||
// under the text_encoder_3 directory, for the t5xxl_fp16.safetensors model. This means that we need to merge the two partitions
|
||||
// to get the monolithic text encoders. This is not a trivial task.
|
||||
// Since the situation can change, we do not want to spend efforts to handle the uniqueness of stabilityai/stable-diffusion-3.5-medium,
|
||||
// which involves different paths and merging the two partitions files for t5xxl_fp16.safetensors.
|
||||
// so for now, we'll use the text encoder models from the stabilityai/stable-diffusion-3.5-large repository.
|
||||
// TODO: Change to "stabilityai/stable-diffusion-3.5-medium" once the maintainers of the repository add back the monolithic text encoders.
|
||||
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-large",
|
||||
Which::V3Medium => unreachable!(),
|
||||
};
|
||||
api.repo(hf_hub::Repo::model(name.to_string()))
|
||||
};
|
||||
let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
|
||||
let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
|
||||
let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
|
||||
let sai_repo_for_mmdit = {
|
||||
let name = match which {
|
||||
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
|
||||
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
|
||||
Which::V3_5Medium => "stabilityai/stable-diffusion-3.5-medium",
|
||||
Which::V3Medium => unreachable!(),
|
||||
};
|
||||
api.repo(hf_hub::Repo::model(name.to_string()))
|
||||
};
|
||||
let clip_g_file = sai_repo_for_text_encoders.get("text_encoders/clip_g.safetensors")?;
|
||||
let clip_l_file = sai_repo_for_text_encoders.get("text_encoders/clip_l.safetensors")?;
|
||||
let t5xxl_file = sai_repo_for_text_encoders.get("text_encoders/t5xxl_fp16.safetensors")?;
|
||||
let model_file = {
|
||||
let model_file = match which {
|
||||
Which::V3_5Large => "sd3.5_large.safetensors",
|
||||
Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
|
||||
Which::V3_5Medium => "sd3.5_medium.safetensors",
|
||||
Which::V3Medium => unreachable!(),
|
||||
};
|
||||
sai_repo.get(model_file)?
|
||||
sai_repo_for_mmdit.get(model_file)?
|
||||
};
|
||||
let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
|
||||
&clip_g_file,
|
||||
@ -157,7 +182,12 @@ fn main() -> Result<()> {
|
||||
let vb = unsafe {
|
||||
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
|
||||
};
|
||||
(MMDiTConfig::sd3_5_large(), triple, vb)
|
||||
match which {
|
||||
Which::V3_5Large => (MMDiTConfig::sd3_5_large(), triple, vb),
|
||||
Which::V3_5LargeTurbo => (MMDiTConfig::sd3_5_large(), triple, vb),
|
||||
Which::V3_5Medium => (MMDiTConfig::sd3_5_medium(), triple, vb),
|
||||
Which::V3Medium => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
let sai_repo = {
|
||||
let name = "stabilityai/stable-diffusion-3-medium";
|
||||
|
Reference in New Issue
Block a user