Fixes for the stable diffusion example. (#342)

* Fixes for the stable diffusion example.

* Bugfix.

* Another fix.

* Fix for group-norm.

* More fixes to get SD to work.
This commit is contained in:
Laurent Mazare
2023-08-08 15:57:09 +02:00
committed by GitHub
parent ab35684326
commit 89d3926c9b
6 changed files with 27 additions and 12 deletions

View File

@ -30,8 +30,8 @@ use test_utils::to_vec3_round;
#[test]
fn group_norm() -> Result<()> {
let device = &Device::Cpu;
let w = Tensor::new(&[1f32], device)?;
let b = Tensor::new(&[0f32], device)?;
let w = Tensor::from_vec(vec![1f32; 6], 6, device)?;
let b = Tensor::from_vec(vec![0f32; 6], 6, device)?;
let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?;
let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?;