img2img pipeline for stable diffusion. (#752)

* img2img pipeline for stable diffusion.

* Rename the arguments + fix.

* Fix for zero strength.

* Another fix.

* Another fix.

* Revert.

* Include the backtrace.

* Noise scaling.

* Fix the height/width.
This commit is contained in:
Laurent Mazare
2023-09-06 08:06:49 +02:00
committed by GitHub
parent 16bf44f6e9
commit 7299a68353
4 changed files with 91 additions and 17 deletions

View File

@ -96,6 +96,15 @@ struct Args {
#[arg(long)]
use_f16: bool,
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
/// The strength, indicates how much to transform the initial image. The
/// value must be between 0 and 1, a value of 1 discards the initial image
/// information.
#[arg(long, default_value_t = 0.8)]
img2img_strength: f64,
}
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
@ -306,6 +315,26 @@ fn text_embeddings(
Ok(text_embeddings)
}
/// Loads the image at `path` and turns it into a tensor suitable as the
/// img2img starting point.
///
/// Processing steps, in order:
/// 1. Decode the file with the `image` crate.
/// 2. Round the height and width down to the nearest multiple of 32 and
///    resize-to-fill the image to those dimensions (CatmullRom filter).
/// 3. Convert to 8-bit RGB and build a `(height, width, 3)` tensor on the CPU.
/// 4. Permute to channel-first, convert to `F32`, apply `affine(2 / 255, -1)`
///    (which maps the `[0, 255]` byte range onto `[-1, 1]`), and add a leading
///    batch dimension, giving a `(1, 3, height, width)` tensor.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or decoded, if either image
/// dimension is smaller than 32 pixels (it would round down to zero), or if
/// any of the tensor operations fail.
fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> anyhow::Result<Tensor> {
    let img = image::io::Reader::open(path)?.decode()?;

    // Round each dimension down to a multiple of 32. Staying in `u32` avoids
    // a narrowing cast when handing the sizes back to the `image` crate.
    let (orig_height, orig_width) = (img.height(), img.width());
    let height = orig_height - orig_height % 32;
    let width = orig_width - orig_width % 32;

    // Anything under 32 pixels rounds down to 0, which would request an empty
    // resize and produce a zero-sized tensor; fail early with a clear message.
    if height == 0 || width == 0 {
        anyhow::bail!(
            "img2img input image must be at least 32x32 pixels, got {orig_width}x{orig_height}"
        )
    }

    let pixels = img
        .resize_to_fill(width, height, image::imageops::FilterType::CatmullRom)
        .to_rgb8()
        .into_raw();

    // The raw RGB8 buffer is row-major (height, width, channel).
    let img = Tensor::from_vec(pixels, (height as usize, width as usize, 3), &Device::Cpu)?
        .permute((2, 0, 1))?
        .to_dtype(DType::F32)?
        .affine(2. / 255., -1.)?
        .unsqueeze(0)?;
    Ok(img)
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
@ -328,9 +357,15 @@ fn run(args: Args) -> Result<()> {
tracing,
use_f16,
use_flash_attn,
img2img,
img2img_strength,
..
} = args;
if !(0. ..=1.).contains(&img2img_strength) {
anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}")
}
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -382,25 +417,53 @@ fn run(args: Args) -> Result<()> {
println!("Building the autoencoder.");
let vae_weights = ModelFile::Vae.get(vae_weights, sd_version, use_f16)?;
let vae = sd_config.build_vae(&vae_weights, &device, dtype)?;
let init_latent_dist = match &img2img {
None => None,
Some(image) => {
let image = image_preprocess(image)?.to_device(&device)?;
Some(vae.encode(&image)?)
}
};
println!("Building the unet.");
let unet_weights = ModelFile::Unet.get(unet_weights, sd_version, use_f16)?;
let unet = sd_config.build_unet(&unet_weights, &device, 4, use_flash_attn, dtype)?;
let t_start = if img2img.is_some() {
n_steps - (n_steps as f64 * img2img_strength) as usize
} else {
0
};
let bsize = 1;
for idx in 0..num_samples {
let mut latents = Tensor::randn(
0f32,
1f32,
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
&device,
)?
.to_dtype(dtype)?;
// scale the initial noise by the standard deviation required by the scheduler
latents = (latents * scheduler.init_noise_sigma())?;
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
} else {
latents
}
}
None => {
let latents = Tensor::randn(
0f32,
1f32,
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
&device,
)?;
// scale the initial noise by the standard deviation required by the scheduler
(latents * scheduler.init_noise_sigma())?
}
};
let mut latents = latents.to_dtype(dtype)?;
println!("starting sampling");
for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() {
for (timestep_index, &timestep) in timesteps.iter().enumerate() {
if timestep_index < t_start {
continue;
}
let start_time = std::time::Instant::now();
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;