diff --git a/candle-transformers/src/models/wuerstchen/diffnext.rs b/candle-transformers/src/models/wuerstchen/diffnext.rs index 60b799ae..6ea36027 100644 --- a/candle-transformers/src/models/wuerstchen/diffnext.rs +++ b/candle-transformers/src/models/wuerstchen/diffnext.rs @@ -37,7 +37,7 @@ impl ResBlockStageB { let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?; let xs = match x_skip { None => xs.clone(), - Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?, + Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?.contiguous()?, }; let xs = xs .permute((0, 2, 3, 1))? @@ -352,7 +352,9 @@ impl WDiffNeXt { None }; let skip = match (skip, effnet_c.as_ref()) { - (Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?), + (Some(skip), Some(effnet_c)) => { + Some(Tensor::cat(&[skip, effnet_c], 1)?.contiguous()?) + } (None, Some(skip)) | (Some(skip), None) => Some(skip.clone()), (None, None) => None, };