mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat. * Cosmetic tweak.
This commit is contained in:
@ -295,12 +295,10 @@ fn run(args: Args) -> Result<()> {
|
||||
};
|
||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = prior_scheduler.timesteps();
|
||||
let timesteps = ×teps[..timesteps.len() - 1];
|
||||
println!("prior denoising");
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||
@ -358,11 +356,9 @@ fn run(args: Args) -> Result<()> {
|
||||
println!("diffusion process with prior {image_embeddings:?}");
|
||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||
let timesteps = scheduler.timesteps();
|
||||
let timesteps = ×teps[..timesteps.len() - 1];
|
||||
for (index, &t) in timesteps.iter().enumerate() {
|
||||
let start_time = std::time::Instant::now();
|
||||
if index == timesteps.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
||||
let noise_pred =
|
||||
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
||||
|
Reference in New Issue
Block a user