Track the conv2d operations in stable-diffusion. (#431)

* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
This commit is contained in:
Laurent Mazare
2023-08-13 16:58:26 +02:00
committed by GitHub
parent b1ff78f762
commit 9af438ac1b
7 changed files with 146 additions and 25 deletions

View File

@ -5,13 +5,15 @@ use crate::attention::{
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use crate::utils::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
#[derive(Debug)]
struct Downsample2D {
conv: Option<nn::Conv2d>,
conv: Option<Conv2d>,
padding: usize,
span: tracing::Span,
}
impl Downsample2D {
@ -24,17 +26,23 @@ impl Downsample2D {
) -> Result<Self> {
let conv = if use_conv {
let config = nn::Conv2dConfig { stride: 2, padding };
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Some(conv)
} else {
None
};
Ok(Downsample2D { conv, padding })
let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
Ok(Self {
conv,
padding,
span,
})
}
}
impl Downsample2D {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
match &self.conv {
None => xs.avg_pool2d((2, 2), (2, 2)),
Some(conv) => {
@ -54,7 +62,8 @@ impl Downsample2D {
// This does not support the conv-transpose mode.
#[derive(Debug)]
struct Upsample2D {
conv: nn::Conv2d,
conv: Conv2d,
span: tracing::Span,
}
impl Upsample2D {
@ -63,13 +72,15 @@ impl Upsample2D {
padding: 1,
..Default::default()
};
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
Ok(Self { conv })
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
Ok(Self { conv, span })
}
}
impl Upsample2D {
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = match size {
None => {
let (_bsize, _channels, h, w) = xs.dims4()?;
@ -108,6 +119,7 @@ impl Default for DownEncoderBlock2DConfig {
pub struct DownEncoderBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
span: tracing::Span,
pub config: DownEncoderBlock2DConfig,
}
@ -147,9 +159,11 @@ impl DownEncoderBlock2D {
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
Ok(Self {
resnets,
downsampler,
span,
config,
})
}
@ -157,6 +171,7 @@ impl DownEncoderBlock2D {
impl DownEncoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
@ -193,6 +208,7 @@ impl Default for UpDecoderBlock2DConfig {
pub struct UpDecoderBlock2D {
resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
span: tracing::Span,
pub config: UpDecoderBlock2DConfig,
}
@ -227,9 +243,11 @@ impl UpDecoderBlock2D {
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
Ok(Self {
resnets,
upsampler,
span,
config,
})
}
@ -237,6 +255,7 @@ impl UpDecoderBlock2D {
impl UpDecoderBlock2D {
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for resnet in self.resnets.iter() {
xs = resnet.forward(&xs, None)?
@ -274,6 +293,7 @@ impl Default for UNetMidBlock2DConfig {
pub struct UNetMidBlock2D {
resnet: ResnetBlock2D,
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
span: tracing::Span,
pub config: UNetMidBlock2DConfig,
}
@ -313,14 +333,17 @@ impl UNetMidBlock2D {
)?;
attn_resnets.push((attn, resnet))
}
let span = tracing::span!(tracing::Level::TRACE, "mid2d");
Ok(Self {
resnet,
attn_resnets,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs)?, temb)?
@ -361,6 +384,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
pub struct UNetMidBlock2DCrossAttn {
resnet: ResnetBlock2D,
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
span: tracing::Span,
pub config: UNetMidBlock2DCrossAttnConfig,
}
@ -408,9 +432,11 @@ impl UNetMidBlock2DCrossAttn {
)?;
attn_resnets.push((attn, resnet))
}
let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
Ok(Self {
resnet,
attn_resnets,
span,
config,
})
}
@ -421,6 +447,7 @@ impl UNetMidBlock2DCrossAttn {
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = self.resnet.forward(xs, temb)?;
for (attn, resnet) in self.attn_resnets.iter() {
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
@ -458,6 +485,7 @@ impl Default for DownBlock2DConfig {
pub struct DownBlock2D {
resnets: Vec<ResnetBlock2D>,
downsampler: Option<Downsample2D>,
span: tracing::Span,
pub config: DownBlock2DConfig,
}
@ -495,14 +523,17 @@ impl DownBlock2D {
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "down2d");
Ok(Self {
resnets,
downsampler,
span,
config,
})
}
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
let _enter = self.span.enter();
let mut xs = xs.clone();
let mut output_states = vec![];
for resnet in self.resnets.iter() {
@ -547,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
pub struct CrossAttnDownBlock2D {
downblock: DownBlock2D,
attentions: Vec<SpatialTransformer>,
span: tracing::Span,
pub config: CrossAttnDownBlock2DConfig,
}
@ -585,9 +617,11 @@ impl CrossAttnDownBlock2D {
)
})
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
Ok(Self {
downblock,
attentions,
span,
config,
})
}
@ -598,6 +632,7 @@ impl CrossAttnDownBlock2D {
temb: Option<&Tensor>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<(Tensor, Vec<Tensor>)> {
let _enter = self.span.enter();
let mut output_states = vec![];
let mut xs = xs.clone();
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
@ -644,6 +679,7 @@ impl Default for UpBlock2DConfig {
pub struct UpBlock2D {
pub resnets: Vec<ResnetBlock2D>,
upsampler: Option<Upsample2D>,
span: tracing::Span,
pub config: UpBlock2DConfig,
}
@ -687,9 +723,11 @@ impl UpBlock2D {
} else {
None
};
let span = tracing::span!(tracing::Level::TRACE, "up2d");
Ok(Self {
resnets,
upsampler,
span,
config,
})
}
@ -701,6 +739,7 @@ impl UpBlock2D {
temb: Option<&Tensor>,
upsample_size: Option<(usize, usize)>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (index, resnet) in self.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
@ -739,6 +778,7 @@ impl Default for CrossAttnUpBlock2DConfig {
pub struct CrossAttnUpBlock2D {
pub upblock: UpBlock2D,
pub attentions: Vec<SpatialTransformer>,
span: tracing::Span,
pub config: CrossAttnUpBlock2DConfig,
}
@ -779,9 +819,11 @@ impl CrossAttnUpBlock2D {
)
})
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
Ok(Self {
upblock,
attentions,
span,
config,
})
}
@ -794,6 +836,7 @@ impl CrossAttnUpBlock2D {
upsample_size: Option<(usize, usize)>,
encoder_hidden_states: Option<&Tensor>,
) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;