Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 02:38:10 +00:00
img2img pipeline for stable diffusion. (#752)
* img2img pipeline for stable diffusion.
* Rename the arguments + fix.
* Fix for zero strength.
* Another fix.
* Another fix.
* Revert.
* Include the backtrace.
* Noise scaling.
* Fix the height/width.
@@ -96,6 +96,15 @@ struct Args {
     #[arg(long)]
     use_f16: bool,
+
+    #[arg(long, value_name = "FILE")]
+    img2img: Option<String>,
+
+    /// The strength, indicates how much to transform the initial image. The
+    /// value must be between 0 and 1, a value of 1 discards the initial image
+    /// information.
+    #[arg(long, default_value_t = 0.8)]
+    img2img_strength: f64,
 }
 
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
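Not part of the diff: a minimal clap sketch of how the two new fields surface on the command line. DemoArgs and its doc comments are made up for illustration, but clap's derive macro kebab-cases field names by default, so the real example should gain --img2img and --img2img-strength flags.

// Hypothetical stand-alone demo, not the example's actual Args struct.
use clap::Parser;

#[derive(Parser, Debug)]
struct DemoArgs {
    /// Initial image to start the diffusion from (img2img mode).
    #[arg(long, value_name = "FILE")]
    img2img: Option<String>,

    /// How strongly to transform the initial image: 0 keeps it, 1 discards it.
    #[arg(long, default_value_t = 0.8)]
    img2img_strength: f64,
}

fn main() {
    // e.g. demo --img2img input.png --img2img-strength 0.6
    let args = DemoArgs::parse();
    println!("{args:?}");
}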
@@ -306,6 +315,26 @@ fn text_embeddings(
     Ok(text_embeddings)
 }
 
+fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
+    let img = image::io::Reader::open(path)?.decode()?;
+    let (height, width) = (img.height() as usize, img.width() as usize);
+    let height = height - height % 32;
+    let width = width - width % 32;
+    let img = img.resize_to_fill(
+        width as u32,
+        height as u32,
+        image::imageops::FilterType::CatmullRom,
+    );
+    let img = img.to_rgb8();
+    let img = img.into_raw();
+    let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
+        .permute((2, 0, 1))?
+        .to_dtype(DType::F32)?
+        .affine(2. / 255., -1.)?
+        .unsqueeze(0)?;
+    Ok(img)
+}
+
 fn run(args: Args) -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
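A side note on image_preprocess, sketched below with hypothetical scalar helpers rather than candle tensors: the width and height are rounded down to multiples of 32, presumably so the 8x-downscaled latent dimensions keep dividing evenly inside the model, and the affine(2. / 255., -1.) call maps u8 pixel values from [0, 255] into [-1, 1].

// Scalar versions of the two transformations above (hypothetical helpers).
fn round_down_to_32(dim: usize) -> usize {
    // e.g. a 517-pixel side becomes 512
    dim - dim % 32
}

fn to_unit_range(p: u8) -> f32 {
    // same scaling as affine(2. / 255., -1.): 0 -> -1.0, 255 -> 1.0
    p as f32 * (2.0 / 255.0) - 1.0
}

fn main() {
    assert_eq!(round_down_to_32(517), 512);
    assert!((to_unit_range(0) + 1.0).abs() < 1e-5);
    assert!((to_unit_range(255) - 1.0).abs() < 1e-5);
}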
@@ -328,9 +357,15 @@ fn run(args: Args) -> Result<()> {
         tracing,
         use_f16,
         use_flash_attn,
+        img2img,
+        img2img_strength,
         ..
     } = args;
 
+    if !(0. ..=1.).contains(&img2img_strength) {
+        anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
+    }
+
     let _guard = if tracing {
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
@@ -382,25 +417,53 @@ fn run(args: Args) -> Result<()> {
     println!("Building the autoencoder.");
     let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
     let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
+    let init_latent_dist = match &img2img {
+        None => None,
+        Some(image) => {
+            let image = image_preprocess(image)?.to_device(&device)?;
+            Some(vae.encode(&image)?)
+        }
+    };
     println!("Building the unet.");
     let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
     let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
 
+    let t_start = if img2img.is_some() {
+        n_steps - (n_steps as f64 * img2img_strength) as usize
+    } else {
+        0
+    };
     let bsize = 1;
     for idx in 0..num_samples {
-        let mut latents = Tensor::randn(
-            0f32,
-            1f32,
-            (bsize, 4, sd_config.height / 8, sd_config.width / 8),
-            &device,
-        )?
-        .to_dtype(dtype)?;
-
-        // scale the initial noise by the standard deviation required by the scheduler
-        latents = (latents * scheduler.init_noise_sigma())?;
+        let timesteps = scheduler.timesteps();
+        let latents = match &init_latent_dist {
+            Some(init_latent_dist) => {
+                let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
+                if t_start < timesteps.len() {
+                    let noise = latents.randn_like(0f64, 1f64)?;
+                    scheduler.add_noise(&latents, noise, timesteps[t_start])?
+                } else {
+                    latents
+                }
+            }
+            None => {
+                let latents = Tensor::randn(
+                    0f32,
+                    1f32,
+                    (bsize, 4, sd_config.height / 8, sd_config.width / 8),
+                    &device,
+                )?;
+                // scale the initial noise by the standard deviation required by the scheduler
+                (latents * scheduler.init_noise_sigma())?
+            }
+        };
+        let mut latents = latents.to_dtype(dtype)?;
 
         println!("starting sampling");
-        for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
+        for (timestep_index, &timestep) in timesteps.iter().enumerate() {
+            if timestep_index < t_start {
+                continue;
+            }
             let start_time = std::time::Instant::now();
             let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
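Also not part of the diff, a small sketch of the skip-ahead arithmetic introduced above: with strength s, the first t_start of the n_steps scheduler timesteps are skipped, so s = 1 runs the full schedule (the initial image is effectively discarded, as the new doc comment says) while s = 0 skips every step, which is presumably what the t_start < timesteps.len() guard and the loop-level continue address (the "Fix for zero strength" bullet in the commit message).

// Hypothetical helper mirroring the inline computation of t_start above.
fn t_start(n_steps: usize, strength: f64) -> usize {
    n_steps - (n_steps as f64 * strength) as usize
}

fn main() {
    assert_eq!(t_start(30, 0.8), 6);  // default strength: denoise over the last 24 steps
    assert_eq!(t_start(30, 1.0), 0);  // full schedule: noise is added at the very first timestep
    assert_eq!(t_start(30, 0.0), 30); // equals timesteps.len(): nothing is denoised
}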