mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Fixes for the stable diffusion example. (#342)
* Fixes for the stable diffusion example. * Bugfix. * Another fix. * Fix for group-norm. * More fixes to get SD to work.
This commit is contained in:
@ -59,17 +59,21 @@ impl GroupNorm {
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let mut w_dims = vec![1; x_shape.len()];
|
||||
w_dims[1] = n_channels;
|
||||
let weight = self.weight.reshape(w_dims.clone())?;
|
||||
let bias = self.bias.reshape(w_dims)?;
|
||||
x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?
|
||||
.reshape(x_shape)
|
||||
.reshape(x_shape)?
|
||||
.broadcast_mul(&weight)?
|
||||
.broadcast_add(&bias)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn group_norm(
|
||||
num_channels: usize,
|
||||
num_groups: usize,
|
||||
num_channels: usize,
|
||||
eps: f64,
|
||||
vb: crate::VarBuilder,
|
||||
) -> Result<GroupNorm> {
|
||||
|
Reference in New Issue
Block a user