Fixes for the stable diffusion example. (#342)

* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
This commit is contained in:
Laurent Mazare
2023-08-08 15:57:09 +02:00
committed by GitHub
parent ab35684326
commit 89d3926c9b
6 changed files with 27 additions and 12 deletions

View File

@ -59,17 +59,21 @@ impl GroupNorm {
let x = x.broadcast_sub(&mean_x)?;
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
let mut w_dims = vec![1; x_shape.len()];
w_dims[1] = n_channels;
let weight = self.weight.reshape(w_dims.clone())?;
let bias = self.bias.reshape(w_dims)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?
.reshape(x_shape)
.reshape(x_shape)?
.broadcast_mul(&weight)?
.broadcast_add(&bias)
}
}
pub fn group_norm(
num_channels: usize,
num_groups: usize,
num_channels: usize,
eps: f64,
vb: crate::VarBuilder,
) -> Result<GroupNorm> {