mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Fix for cudnn to work with img2img. (#753)
This commit is contained in:
@ -754,6 +754,7 @@ impl UpBlock2D {
|
|||||||
let mut xs = xs.clone();
|
let mut xs = xs.clone();
|
||||||
for (index, resnet) in self.resnets.iter().enumerate() {
|
for (index, resnet) in self.resnets.iter().enumerate() {
|
||||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||||
|
xs = xs.contiguous()?;
|
||||||
xs = resnet.forward(&xs, temb)?;
|
xs = resnet.forward(&xs, temb)?;
|
||||||
}
|
}
|
||||||
match &self.upsampler {
|
match &self.upsampler {
|
||||||
@ -855,6 +856,7 @@ impl CrossAttnUpBlock2D {
|
|||||||
let mut xs = xs.clone();
|
let mut xs = xs.clone();
|
||||||
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
|
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
|
||||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||||
|
xs = xs.contiguous()?;
|
||||||
xs = resnet.forward(&xs, temb)?;
|
xs = resnet.forward(&xs, temb)?;
|
||||||
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
|
xs = self.attentions[index].forward(&xs, encoder_hidden_states)?;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user