From 4f91c8e1097fcbcd38aaf4a8ebf4f619d7598473 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 19 Sep 2023 15:09:47 +0100
Subject: [PATCH] Improve the error message on shape mismatch for cat. (#897)

* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
---
 candle-core/src/tensor.rs                   | 28 +++++++++++++++++++
 candle-examples/examples/wuerstchen/main.rs |  8 ++----
 .../src/models/wuerstchen/diffnext.rs       |  3 +-
 .../src/models/wuerstchen/paella_vq.rs      |  1 +
 4 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 61f576cf..756fedb2 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1907,6 +1907,34 @@ impl Tensor {
         for arg in args {
             arg.as_ref().check_dim(dim, "cat")?;
         }
+        for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
+            if arg0.rank() != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: arg0.rank(),
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
+            for (dim_idx, (v1, v2)) in arg0
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim_idx != dim && v1 != v2 {
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
+                }
+            }
+        }
         if dim == 0 {
             Self::cat0(args)
         } else {
diff --git a/candle-examples/examples/wuerstchen/main.rs b/candle-examples/examples/wuerstchen/main.rs
index 4e4bce0b..8064f87f 100644
--- a/candle-examples/examples/wuerstchen/main.rs
+++ b/candle-examples/examples/wuerstchen/main.rs
@@ -295,12 +295,10 @@ fn run(args: Args) -> Result<()> {
     };
     let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
    let timesteps = prior_scheduler.timesteps();
+    let timesteps = &timesteps[..timesteps.len() - 1];
     println!("prior denoising");
     for (index, &t) in timesteps.iter().enumerate() {
         let start_time = std::time::Instant::now();
-        if index == timesteps.len() - 1 {
-            continue;
-        }
         let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
         let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
         let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
@@ -358,11 +356,9 @@
     println!("diffusion process with prior {image_embeddings:?}");
     let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
     let timesteps = scheduler.timesteps();
+    let timesteps = &timesteps[..timesteps.len() - 1];
     for (index, &t) in timesteps.iter().enumerate() {
         let start_time = std::time::Instant::now();
-        if index == timesteps.len() - 1 {
-            continue;
-        }
         let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
         let noise_pred =
             decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs
index 60b799ae..501a2776 100644
--- a/candle-transformers/src/models/wuerstchen/diffnext.rs
+++ b/candle-transformers/src/models/wuerstchen/diffnext.rs
@@ -41,6 +41,7 @@ impl ResBlockStageB {
         };
         let xs = xs
             .permute((0, 2, 3, 1))?
+            .contiguous()?
             .apply(&self.channelwise_lin1)?
             .gelu()?
             .apply(&self.channelwise_grn)?
@@ -374,7 +375,7 @@ impl WDiffNeXt {
             .apply(&self.clf_ln)?
             .apply(&self.clf_conv)?
             .apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
-            .chunk(1, 2)?;
+            .chunk(2, 1)?;
         let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
         (x_in - &ab[0])? / b
     }
diff --git a/candle-transformers/src/models/wuerstchen/paella_vq.rs b/candle-transformers/src/models/wuerstchen/paella_vq.rs
index 8cf33505..6da7362c 100644
--- a/candle-transformers/src/models/wuerstchen/paella_vq.rs
+++ b/candle-transformers/src/models/wuerstchen/paella_vq.rs
@@ -52,6 +52,7 @@ impl Module for MixingResidualBlock {
             .affine(1. + mods[3] as f64, mods[4] as f64)?;
         let x_temp = x_temp
             .permute((0, 2, 3, 1))?
+            .contiguous()?
             .apply(&self.channelwise_lin1)?
             .gelu()?
             .apply(&self.channelwise_lin2)?
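
For illustration, a minimal sketch of how the improved message surfaces through the public candle_core API; the shapes (2, 3) and (2, 4) below are arbitrary values chosen so that a non-cat dimension mismatches and are not taken from the patch itself:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
    // Concatenating along dim 0 requires every other dimension to match;
    // with this change the error names the offending input and dimension
    // instead of failing later with a less specific message.
    match Tensor::cat(&[&a, &b], 0) {
        Ok(_) => println!("unexpected success"),
        Err(e) => println!("{e}"),
    }
    Ok(())
}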