Track the conv2d operations in stable-diffusion. (#431)

* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
This commit is contained in:
Laurent Mazare
2023-08-13 16:58:26 +02:00
committed by GitHub
parent b1ff78f762
commit 9af438ac1b
7 changed files with 146 additions and 25 deletions

View File

@ -310,12 +310,11 @@ impl<'a> Reduce<'a> {
.iter()
.map(|(u, _)| u)
.product::<usize>();
let mut src_i = 0;
for dst_v in dst.iter_mut() {
for (dst_i, dst_v) in dst.iter_mut().enumerate() {
let src_i = dst_i * reduce_sz;
for &s in src[src_i..src_i + reduce_sz].iter() {
*dst_v = f(*dst_v, s)
}
src_i += reduce_sz
}
return Ok(dst);
};