Improve the error message on shape mismatch for cat. (#897)

* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
This commit is contained in:
Laurent Mazare
2023-09-19 15:09:47 +01:00
committed by GitHub
parent 06e46d7c3b
commit 4f91c8e109
4 changed files with 33 additions and 7 deletions

View File

@ -295,12 +295,10 @@ fn run(args: Args) -> Result<()> {
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
println!("prior denoising");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
@ -358,11 +356,9 @@ fn run(args: Args) -> Result<()> {
println!("diffusion process with prior {image_embeddings:?}");
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
let noise_pred =
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;