mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
W decoding. (#893)
* W decoding. * Add the diffusion loop. * Use the appropriate config.
This commit is contained in:
@ -248,6 +248,20 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
println!("{prior_text_embeddings}");
|
println!("{prior_text_embeddings}");
|
||||||
|
|
||||||
|
let text_embeddings = {
|
||||||
|
let tokenizer = ModelFile::Tokenizer.get(tokenizer)?;
|
||||||
|
let weights = ModelFile::Clip.get(clip_weights)?;
|
||||||
|
encode_prompt(
|
||||||
|
&prompt,
|
||||||
|
&uncond_prompt,
|
||||||
|
tokenizer.clone(),
|
||||||
|
weights,
|
||||||
|
stable_diffusion::clip::Config::wuerstchen(),
|
||||||
|
&device,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
println!("{prior_text_embeddings}");
|
||||||
|
|
||||||
println!("Building the prior.");
|
println!("Building the prior.");
|
||||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||||
let prior = {
|
let prior = {
|
||||||
@ -262,7 +276,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
println!("Building the vqgan.");
|
println!("Building the vqgan.");
|
||||||
let _vqgan = {
|
let vqgan = {
|
||||||
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
|
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
@ -273,7 +287,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
println!("Building the decoder.");
|
println!("Building the decoder.");
|
||||||
|
|
||||||
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
|
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
|
||||||
let _decoder = {
|
let decoder = {
|
||||||
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
|
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
@ -314,49 +328,37 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let dt = start_time.elapsed().as_secs_f32();
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||||
}
|
}
|
||||||
let latents = ((latents * 42.)? - 1.)?;
|
let effnet = ((latents * 42.)? - 1.)?;
|
||||||
/*
|
let mut latents = Tensor::randn(
|
||||||
let timesteps = scheduler.timesteps();
|
|
||||||
let latents = Tensor::randn(
|
|
||||||
0f32,
|
0f32,
|
||||||
1f32,
|
1f32,
|
||||||
(bsize, 4, sd_config.height / 8, sd_config.width / 8),
|
(b_size, PRIOR_CIN, latent_height, latent_width),
|
||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
// scale the initial noise by the standard deviation required by the scheduler
|
|
||||||
let mut latents = latents * scheduler.init_noise_sigma()?;
|
|
||||||
|
|
||||||
println!("starting sampling");
|
println!("diffusion process");
|
||||||
for (timestep_index, &timestep) in timesteps.iter().enumerate() {
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
if index == timesteps.len() - 1 {
|
||||||
|
continue;
|
||||||
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
|
}
|
||||||
let noise_pred =
|
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||||
decoder.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
|
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
|
||||||
let noise_pred = noise_pred.chunk(2, 0)?;
|
latents = prior_scheduler.step(&noise_pred, t, &latents)?;
|
||||||
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
|
|
||||||
let noise_pred =
|
|
||||||
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
|
|
||||||
latents = scheduler.step(&noise_pred, timestep, &latents)?;
|
|
||||||
let dt = start_time.elapsed().as_secs_f32();
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
|
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||||
}
|
}
|
||||||
*/
|
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"Generating the final image for sample {}/{}.",
|
"Generating the final image for sample {}/{}.",
|
||||||
idx + 1,
|
idx + 1,
|
||||||
num_samples
|
num_samples
|
||||||
);
|
);
|
||||||
/*
|
let image = vqgan.decode(&(&latents * 0.3764)?)?;
|
||||||
let image = vae.decode(&(&latents / 0.18215)?)?;
|
|
||||||
// TODO: Add the clamping between 0 and 1.
|
// TODO: Add the clamping between 0 and 1.
|
||||||
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
|
||||||
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
|
||||||
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
|
||||||
candle_examples::save_image(&image, image_filename)?
|
candle_examples::save_image(&image, image_filename)?
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user