Improve the error message on shape mismatch for cat. (#897)

* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
This commit is contained in:
Laurent Mazare
2023-09-19 15:09:47 +01:00
committed by GitHub
parent 06e46d7c3b
commit 4f91c8e109
4 changed files with 33 additions and 7 deletions

View File

@ -1907,6 +1907,34 @@ impl Tensor {
for arg in args {
arg.as_ref().check_dim(dim, "cat")?;
}
for (arg_idx, arg) in args.iter().enumerate() {
let arg = arg.as_ref();
if arg0.rank() != arg.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: arg0.rank(),
got: arg.rank(),
shape: arg.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in arg0
.shape()
.dims()
.iter()
.zip(arg.shape().dims().iter())
.enumerate()
{
if dim_idx != dim && v1 != v2 {
Err(Error::ShapeMismatchCat {
dim: dim_idx,
first_shape: arg0.shape().clone(),
n: arg_idx + 1,
nth_shape: arg.shape().clone(),
}
.bt())?
}
}
}
if dim == 0 {
Self::cat0(args)
} else {

View File

@ -295,12 +295,10 @@ fn run(args: Args) -> Result<()> {
};
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = prior_scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
println!("prior denoising");
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
@ -358,11 +356,9 @@ fn run(args: Args) -> Result<()> {
println!("diffusion process with prior {image_embeddings:?}");
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
let timesteps = scheduler.timesteps();
let timesteps = &timesteps[..timesteps.len() - 1];
for (index, &t) in timesteps.iter().enumerate() {
let start_time = std::time::Instant::now();
if index == timesteps.len() - 1 {
continue;
}
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
let noise_pred =
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;

View File

@ -41,6 +41,7 @@ impl ResBlockStageB {
};
let xs = xs
.permute((0, 2, 3, 1))?
.contiguous()?
.apply(&self.channelwise_lin1)?
.gelu()?
.apply(&self.channelwise_grn)?
@ -374,7 +375,7 @@ impl WDiffNeXt {
.apply(&self.clf_ln)?
.apply(&self.clf_conv)?
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
.chunk(1, 2)?;
.chunk(2, 1)?;
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
(x_in - &ab[0])? / b
}

View File

@ -52,6 +52,7 @@ impl Module for MixingResidualBlock {
.affine(1. + mods[3] as f64, mods[4] as f64)?;
let x_temp = x_temp
.permute((0, 2, 3, 1))?
.contiguous()?
.apply(&self.channelwise_lin1)?
.gelu()?
.apply(&self.channelwise_lin2)?