mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat. * Cosmetic tweak.
This commit is contained in:
@ -41,6 +41,7 @@ impl ResBlockStageB {
|
||||
};
|
||||
let xs = xs
|
||||
.permute((0, 2, 3, 1))?
|
||||
.contiguous()?
|
||||
.apply(&self.channelwise_lin1)?
|
||||
.gelu()?
|
||||
.apply(&self.channelwise_grn)?
|
||||
@ -374,7 +375,7 @@ impl WDiffNeXt {
|
||||
.apply(&self.clf_ln)?
|
||||
.apply(&self.clf_conv)?
|
||||
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
|
||||
.chunk(1, 2)?;
|
||||
.chunk(2, 1)?;
|
||||
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
|
||||
(x_in - &ab[0])? / b
|
||||
}
|
||||
|
@ -52,6 +52,7 @@ impl Module for MixingResidualBlock {
|
||||
.affine(1. + mods[3] as f64, mods[4] as f64)?;
|
||||
let x_temp = x_temp
|
||||
.permute((0, 2, 3, 1))?
|
||||
.contiguous()?
|
||||
.apply(&self.channelwise_lin1)?
|
||||
.gelu()?
|
||||
.apply(&self.channelwise_lin2)?
|
||||
|
Reference in New Issue
Block a user