Improve the error message on shape mismatch for cat. (#897)

* Improve the error message on shape mismatch for cat.

* Cosmetic tweak.
This commit is contained in:
Laurent Mazare
2023-09-19 15:09:47 +01:00
committed by GitHub
parent 06e46d7c3b
commit 4f91c8e109
4 changed files with 33 additions and 7 deletions

View File

@ -41,6 +41,7 @@ impl ResBlockStageB {
};
let xs = xs
.permute((0, 2, 3, 1))?
.contiguous()?
.apply(&self.channelwise_lin1)?
.gelu()?
.apply(&self.channelwise_grn)?
@ -374,7 +375,7 @@ impl WDiffNeXt {
.apply(&self.clf_ln)?
.apply(&self.clf_conv)?
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
.chunk(1, 2)?;
.chunk(2, 1)?;
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
(x_in - &ab[0])? / b
}

View File

@ -52,6 +52,7 @@ impl Module for MixingResidualBlock {
.affine(1. + mods[3] as f64, mods[4] as f64)?;
let x_temp = x_temp
.permute((0, 2, 3, 1))?
.contiguous()?
.apply(&self.channelwise_lin1)?
.gelu()?
.apply(&self.channelwise_lin2)?