mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Preliminary support for SDXL. (#647)
* Preliminary support for SDXL. * More SDXL support. * More SDXL. * Use the proper clip config. * Querying for existing tensors. * More robust test.
This commit is contained in:
@ -8,6 +8,7 @@ pub struct StableDiffusionConfig {
|
||||
pub width: usize,
|
||||
pub height: usize,
|
||||
pub clip: clip::Config,
|
||||
pub clip2: Option<clip::Config>,
|
||||
autoencoder: vae::AutoEncoderKLConfig,
|
||||
unet: unet_2d::UNet2DConditionModelConfig,
|
||||
scheduler: ddim::DDIMSchedulerConfig,
|
||||
@ -51,7 +52,7 @@ impl StableDiffusionConfig {
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
512
|
||||
@ -68,6 +69,7 @@ impl StableDiffusionConfig {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::v1_5(),
|
||||
clip2: None,
|
||||
autoencoder,
|
||||
scheduler: Default::default(),
|
||||
unet,
|
||||
@ -118,7 +120,7 @@ impl StableDiffusionConfig {
|
||||
};
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "heigh has to be divisible by 8");
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
768
|
||||
@ -135,6 +137,7 @@ impl StableDiffusionConfig {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::v2_1(),
|
||||
clip2: None,
|
||||
autoencoder,
|
||||
scheduler,
|
||||
unet,
|
||||
@ -155,6 +158,83 @@ impl StableDiffusionConfig {
|
||||
)
|
||||
}
|
||||
|
||||
fn sdxl_(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
prediction_type: PredictionType,
|
||||
) -> Self {
|
||||
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
|
||||
out_channels,
|
||||
use_cross_attn,
|
||||
attention_head_dim,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
|
||||
let unet = unet_2d::UNet2DConditionModelConfig {
|
||||
blocks: vec![bc(320, false, 5), bc(640, false, 10), bc(1280, true, 20)],
|
||||
center_input_sample: false,
|
||||
cross_attention_dim: 2048,
|
||||
downsample_padding: 1,
|
||||
flip_sin_to_cos: true,
|
||||
freq_shift: 0.,
|
||||
layers_per_block: 2,
|
||||
mid_block_scale_factor: 1.,
|
||||
norm_eps: 1e-5,
|
||||
norm_num_groups: 32,
|
||||
sliced_attention_size,
|
||||
use_linear_projection: true,
|
||||
};
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
|
||||
let autoencoder = vae::AutoEncoderKLConfig {
|
||||
block_out_channels: vec![128, 256, 512, 512],
|
||||
layers_per_block: 2,
|
||||
latent_channels: 4,
|
||||
norm_num_groups: 32,
|
||||
};
|
||||
let scheduler = ddim::DDIMSchedulerConfig {
|
||||
prediction_type,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let height = if let Some(height) = height {
|
||||
assert_eq!(height % 8, 0, "height has to be divisible by 8");
|
||||
height
|
||||
} else {
|
||||
1024
|
||||
};
|
||||
|
||||
let width = if let Some(width) = width {
|
||||
assert_eq!(width % 8, 0, "width has to be divisible by 8");
|
||||
width
|
||||
} else {
|
||||
1024
|
||||
};
|
||||
|
||||
Self {
|
||||
width,
|
||||
height,
|
||||
clip: clip::Config::sdxl(),
|
||||
clip2: Some(clip::Config::sdxl2()),
|
||||
autoencoder,
|
||||
scheduler,
|
||||
unet,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sdxl(
|
||||
sliced_attention_size: Option<usize>,
|
||||
height: Option<usize>,
|
||||
width: Option<usize>,
|
||||
) -> Self {
|
||||
Self::sdxl_(
|
||||
sliced_attention_size,
|
||||
height,
|
||||
width,
|
||||
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/scheduler/scheduler_config.json
|
||||
PredictionType::Epsilon,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn build_vae<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
vae_weights: P,
|
||||
@ -193,17 +273,17 @@ impl StableDiffusionConfig {
|
||||
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
|
||||
ddim::DDIMScheduler::new(n_steps, self.scheduler)
|
||||
}
|
||||
|
||||
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
||||
&self,
|
||||
clip_weights: P,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Result<clip::ClipTextTransformer> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
|
||||
let text_model = clip::ClipTextTransformer::new(vs, &self.clip)?;
|
||||
Ok(text_model)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
|
||||
clip: &clip::Config,
|
||||
clip_weights: P,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Result<clip::ClipTextTransformer> {
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
|
||||
let text_model = clip::ClipTextTransformer::new(vs, clip)?;
|
||||
Ok(text_model)
|
||||
}
|
||||
|
Reference in New Issue
Block a user