Track the conv2d operations in stable-diffusion. (#431)

* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
Author: Laurent Mazare
Date: 2023-08-13 16:58:26 +02:00
Committed by: GitHub
Parent: b1ff78f762
Commit: 9af438ac1b
7 changed files with 146 additions and 25 deletions


@@ -310,12 +310,11 @@ impl<'a> Reduce<'a> {
             .iter()
             .map(|(u, _)| u)
             .product::<usize>();
-        let mut src_i = 0;
-        for dst_v in dst.iter_mut() {
+        for (dst_i, dst_v) in dst.iter_mut().enumerate() {
+            let src_i = dst_i * reduce_sz;
             for &s in src[src_i..src_i + reduce_sz].iter() {
                 *dst_v = f(*dst_v, s)
             }
-            src_i += reduce_sz
         }
         return Ok(dst);
     };
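A side note on the tweak above: deriving `src_i` from the destination index, instead of threading a mutable offset through the loop, makes each iteration independent of the previous one, which among other things makes the outer loop straightforward to parallelize later. A minimal sketch of the pattern with made-up names, independent of the candle code:

```rust
// Each dst slot reduces over `reduce_sz` consecutive source elements.
// Computing `src_i` from `dst_i` keeps every iteration self-contained.
fn reduce_sum(src: &[f32], dst: &mut [f32], reduce_sz: usize) {
    for (dst_i, dst_v) in dst.iter_mut().enumerate() {
        let src_i = dst_i * reduce_sz;
        for &s in &src[src_i..src_i + reduce_sz] {
            *dst_v += s;
        }
    }
}
```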


@@ -6,17 +6,20 @@ use candle_nn as nn;
 #[derive(Debug)]
 struct GeGlu {
     proj: nn::Linear,
+    span: tracing::Span,
 }
 impl GeGlu {
     fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
         let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "geglu");
+        Ok(Self { proj, span })
     }
 }
 impl GeGlu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
         &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
     }
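The rest of this file repeats the same instrumentation recipe: each module gains a `tracing::Span` field, built once in the constructor with a short name ("geglu", "ff", "xa", ...), and `forward` enters the span so the time spent in the call is recorded. A self-contained sketch of the pattern (the `Dummy` module is illustrative, not part of the diff):

```rust
use tracing::{span, Level};

// Span-per-module pattern used throughout this commit: the span is created
// once at construction time; entering it in `forward` emits enter/exit
// events that a subscriber (e.g. tracing-chrome) turns into timed slices.
struct Dummy {
    span: tracing::Span,
}

impl Dummy {
    fn new() -> Self {
        let span = span!(Level::TRACE, "dummy");
        Self { span }
    }

    fn forward(&self, x: f32) -> f32 {
        let _enter = self.span.enter(); // span exits when the guard drops
        x * 2.0
    }
}
```

With no subscriber installed the spans are disabled and cost almost nothing, which is why the instrumentation can stay in unconditionally.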
@@ -27,6 +30,7 @@ impl GeGlu {
 struct FeedForward {
     project_in: GeGlu,
     linear: nn::Linear,
+    span: tracing::Span,
 }
 impl FeedForward {
@@ -40,12 +44,18 @@ impl FeedForward {
         let vs = vs.pp("net");
         let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
         let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
-        Ok(Self { project_in, linear })
+        let span = tracing::span!(tracing::Level::TRACE, "ff");
+        Ok(Self {
+            project_in,
+            linear,
+            span,
+        })
     }
 }
 impl FeedForward {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.project_in.forward(xs)?;
         self.linear.forward(&xs)
     }
@@ -60,6 +70,8 @@ struct CrossAttention {
     heads: usize,
     scale: f64,
     slice_size: Option<usize>,
+    span: tracing::Span,
+    span_attn: tracing::Span,
 }
 impl CrossAttention {
@@ -79,6 +91,8 @@ impl CrossAttention {
         let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
         let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa");
+        let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
         Ok(Self {
             to_q,
             to_k,
@@ -87,6 +101,8 @@ impl CrossAttention {
             heads,
             scale,
             slice_size,
+            span,
+            span_attn,
         })
     }
@@ -129,12 +145,14 @@ impl CrossAttention {
     }
     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_attn.enter();
         let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
         let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
         self.reshape_batch_dim_to_heads(&xs)
     }
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs);
         let key = self.to_k.forward(context)?;
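For orientation, the `attention` method instrumented above is plain scaled dot-product attention on head-flattened tensors. A sketch of the shape flow, mirroring the calls in the diff (shapes are illustrative, assuming batch and heads are already merged into the leading dimension by the surrounding reshape helpers):

```rust
use candle::{Result, Tensor, D};

// q, k, v: (batch * heads, seq_len, head_dim)
fn scaled_dot_product(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
    // (b*h, seq, dim) x (b*h, dim, seq) -> (b*h, seq, seq) attention scores
    let attn = q.matmul(&(k.transpose(D::Minus1, D::Minus2)? * scale)?)?;
    // softmax over the key axis, then weight the values:
    // (b*h, seq, seq) x (b*h, seq, dim) -> (b*h, seq, dim)
    candle_nn::ops::softmax(&attn, D::Minus1)?.matmul(v)
}
```

Keeping `span_attn` separate from `span` lets the trace distinguish the inner softmax(QKᵀ)V part from the surrounding projections.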
@@ -165,6 +183,7 @@ struct BasicTransformerBlock {
     norm1: nn::LayerNorm,
     norm2: nn::LayerNorm,
     norm3: nn::LayerNorm,
+    span: tracing::Span,
 }
 impl BasicTransformerBlock {
@@ -196,6 +215,7 @@ impl BasicTransformerBlock {
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
         let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
         Ok(Self {
             attn1,
             ff,
@@ -203,10 +223,12 @@
             norm1,
             norm2,
             norm3,
+            span,
         })
     }
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
         let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
         self.ff.forward(&self.norm3.forward(&xs)?)? + xs
@@ -247,6 +269,7 @@ pub struct SpatialTransformer {
     proj_in: Proj,
     transformer_blocks: Vec<BasicTransformerBlock>,
     proj_out: Proj,
+    span: tracing::Span,
     pub config: SpatialTransformerConfig,
 }
@@ -295,16 +318,19 @@ impl SpatialTransformer {
                 vs.pp("proj_out"),
             )?)
         };
+        let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
         Ok(Self {
             norm,
             proj_in,
             transformer_blocks,
             proj_out,
+            span,
             config,
         })
     }
     pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (batch, _channel, height, weight) = xs.dims4()?;
         let residual = xs;
         let xs = self.norm.forward(xs)?;
@@ -376,6 +402,7 @@ pub struct AttentionBlock {
     proj_attn: nn::Linear,
     channels: usize,
     num_heads: usize,
+    span: tracing::Span,
     config: AttentionBlockConfig,
 }
@@ -389,6 +416,7 @@ impl AttentionBlock {
         let key = nn::linear(channels, channels, vs.pp("key"))?;
         let value = nn::linear(channels, channels, vs.pp("value"))?;
         let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
             query,
@@ -397,6 +425,7 @@
             proj_attn,
             channels,
             num_heads,
+            span,
             config,
         })
     }
@@ -406,10 +435,9 @@
         xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
             .transpose(1, 2)
     }
-}
-impl AttentionBlock {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self
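Because these forward calls nest (SpatialTransformer drives BasicTransformerBlock, which drives CrossAttention and FeedForward, down to the wrapped conv2d), the entered spans nest the same way and the resulting trace reads like a flame graph. A tiny illustrative snippet of that nesting, reusing two of the span names above:

```rust
fn main() {
    // "xa" is recorded as a child of "spatial-transformer" because it is
    // entered while the outer guard is still alive.
    let outer = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
    let _outer = outer.enter();
    let inner = tracing::span!(tracing::Level::TRACE, "xa");
    let _inner = inner.enter();
}
```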


@@ -40,6 +40,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
     /// The height in pixels of the generated image.
     #[arg(long)]
     height: Option<usize>,
@@ -183,6 +187,9 @@ fn output_filename(
 }
 fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
     let Args {
         prompt,
         uncond_prompt,
@@ -198,8 +205,18 @@ fn run(args: Args) -> Result<()> {
         clip_weights,
         vae_weights,
         unet_weights,
+        tracing,
         ..
     } = args;
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
     let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
             stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
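Two details in the tracing setup above are easy to miss. First, `ChromeLayerBuilder::build()` returns the layer together with a flush guard, and the trace file is only written out when that guard is dropped; binding it to `_guard` rather than `_` keeps it alive until the end of `run`. Second, the generated file is in Chrome trace format, so it can be opened in chrome://tracing or https://ui.perfetto.dev. A stand-alone sketch of the same wiring (`heavy_work` is made up for illustration):

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn heavy_work() {
    // Spans recorded here end up as slices in the trace file.
    let span = tracing::span!(tracing::Level::TRACE, "heavy-work");
    let _enter = span.enter();
    std::thread::sleep(std::time::Duration::from_millis(50));
}

fn main() {
    // The guard must outlive the traced work: dropping it flushes the
    // trace-<timestamp>.json file in the current directory.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    heavy_work();
    // _guard dropped here -> trace file written.
}
```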


@@ -5,6 +5,7 @@
 //!
 //! Deep Residual Learning for Image Recognition, K. He et al., 2015.
 //! https://arxiv.org/abs/1512.03385
+use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
@@ -45,11 +46,12 @@ impl Default for ResnetBlock2DConfig {
 #[derive(Debug)]
 pub struct ResnetBlock2D {
     norm1: nn::GroupNorm,
-    conv1: nn::Conv2d,
+    conv1: Conv2d,
     norm2: nn::GroupNorm,
-    conv2: nn::Conv2d,
+    conv2: Conv2d,
     time_emb_proj: Option<nn::Linear>,
-    conv_shortcut: Option<nn::Conv2d>,
+    conv_shortcut: Option<Conv2d>,
+    span: tracing::Span,
     config: ResnetBlock2DConfig,
 }
@@ -65,10 +67,10 @@ impl ResnetBlock2D {
             padding: 1,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
-        let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+        let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
         let groups_out = config.groups_out.unwrap_or(config.groups);
         let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
-        let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+        let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
         let use_in_shortcut = config
             .use_in_shortcut
             .unwrap_or(in_channels != out_channels);
@@ -77,7 +79,7 @@
             stride: 1,
             padding: 0,
         };
-        Some(nn::conv2d(
+        Some(conv2d(
             in_channels,
             out_channels,
             1,
@@ -95,18 +97,21 @@ impl ResnetBlock2D {
                 vs.pp("time_emb_proj"),
             )?),
         };
+        let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
         Ok(Self {
             norm1,
             conv1,
             norm2,
             conv2,
             time_emb_proj,
+            span,
             config,
             conv_shortcut,
         })
     }
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut_xs = match &self.conv_shortcut {
             Some(conv_shortcut) => conv_shortcut.forward(xs)?,
             None => xs.clone(),


@@ -5,6 +5,7 @@
 //! timestep and return a denoised version of the input.
 use crate::embeddings::{TimestepEmbedding, Timesteps};
 use crate::unet_2d_blocks::*;
+use crate::utils::{conv2d, Conv2d};
 use candle::{DType, Result, Tensor};
 use candle_nn as nn;
@@ -85,14 +86,15 @@ enum UNetUpBlock {
 #[derive(Debug)]
 pub struct UNet2DConditionModel {
-    conv_in: nn::Conv2d,
+    conv_in: Conv2d,
     time_proj: Timesteps,
     time_embedding: TimestepEmbedding,
     down_blocks: Vec<UNetDownBlock>,
     mid_block: UNetMidBlock2DCrossAttn,
     up_blocks: Vec<UNetUpBlock>,
     conv_norm_out: nn::GroupNorm,
-    conv_out: nn::Conv2d,
+    conv_out: Conv2d,
+    span: tracing::Span,
     config: UNet2DConditionModelConfig,
 }
@@ -112,7 +114,7 @@ impl UNet2DConditionModel {
             stride: 1,
             padding: 1,
         };
-        let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+        let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
         let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
         let time_embedding =
@@ -263,7 +265,8 @@ impl UNet2DConditionModel {
             config.norm_eps,
             vs.pp("conv_norm_out"),
         )?;
-        let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+        let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "unet2d");
         Ok(Self {
             conv_in,
             time_proj,
@@ -273,18 +276,18 @@
             up_blocks,
             conv_norm_out,
             conv_out,
+            span,
             config,
         })
     }
-}
-impl UNet2DConditionModel {
     pub fn forward(
         &self,
         xs: &Tensor,
         timestep: f64,
         encoder_hidden_states: &Tensor,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
     }


@@ -5,13 +5,15 @@ use crate::attention::{
     AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
 };
 use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
 #[derive(Debug)]
 struct Downsample2D {
-    conv: Option<nn::Conv2d>,
+    conv: Option<Conv2d>,
     padding: usize,
+    span: tracing::Span,
 }
 impl Downsample2D {
@@ -24,17 +26,23 @@
     ) -> Result<Self> {
         let conv = if use_conv {
             let config = nn::Conv2dConfig { stride: 2, padding };
-            let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+            let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
             Some(conv)
         } else {
             None
         };
-        Ok(Downsample2D { conv, padding })
+        let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
+        Ok(Self {
+            conv,
+            padding,
+            span,
+        })
     }
 }
 impl Downsample2D {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         match &self.conv {
             None => xs.avg_pool2d((2, 2), (2, 2)),
             Some(conv) => {
@@ -54,7 +62,8 @@ impl Downsample2D {
 // This does not support the conv-transpose mode.
 #[derive(Debug)]
 struct Upsample2D {
-    conv: nn::Conv2d,
+    conv: Conv2d,
+    span: tracing::Span,
 }
 impl Upsample2D {
@@ -63,13 +72,15 @@
             padding: 1,
             ..Default::default()
         };
-        let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
-        Ok(Self { conv })
+        let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
+        Ok(Self { conv, span })
     }
 }
 impl Upsample2D {
     fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = match size {
             None => {
                 let (_bsize, _channels, h, w) = xs.dims4()?;
@@ -108,6 +119,7 @@ impl Default for DownEncoderBlock2DConfig {
 pub struct DownEncoderBlock2D {
     resnets: Vec<ResnetBlock2D>,
     downsampler: Option<Downsample2D>,
+    span: tracing::Span,
     pub config: DownEncoderBlock2DConfig,
 }
@@ -147,9 +159,11 @@ impl DownEncoderBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
         Ok(Self {
             resnets,
             downsampler,
+            span,
             config,
         })
     }
@@ -157,6 +171,7 @@
 impl DownEncoderBlock2D {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
             xs = resnet.forward(&xs, None)?
@@ -193,6 +208,7 @@ impl Default for UpDecoderBlock2DConfig {
 pub struct UpDecoderBlock2D {
     resnets: Vec<ResnetBlock2D>,
     upsampler: Option<Upsample2D>,
+    span: tracing::Span,
     pub config: UpDecoderBlock2DConfig,
 }
@@ -227,9 +243,11 @@ impl UpDecoderBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
         Ok(Self {
             resnets,
             upsampler,
+            span,
             config,
         })
     }
@@ -237,6 +255,7 @@
 impl UpDecoderBlock2D {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
             xs = resnet.forward(&xs, None)?
@@ -274,6 +293,7 @@ impl Default for UNetMidBlock2DConfig {
 pub struct UNetMidBlock2D {
     resnet: ResnetBlock2D,
     attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+    span: tracing::Span,
     pub config: UNetMidBlock2DConfig,
 }
@@ -313,14 +333,17 @@ impl UNetMidBlock2D {
             )?;
             attn_resnets.push((attn, resnet))
         }
+        let span = tracing::span!(tracing::Level::TRACE, "mid2d");
         Ok(Self {
             resnet,
             attn_resnets,
+            span,
             config,
         })
     }
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = self.resnet.forward(xs, temb)?;
         for (attn, resnet) in self.attn_resnets.iter() {
             xs = resnet.forward(&attn.forward(&xs)?, temb)?
@@ -361,6 +384,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
 pub struct UNetMidBlock2DCrossAttn {
     resnet: ResnetBlock2D,
     attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+    span: tracing::Span,
     pub config: UNetMidBlock2DCrossAttnConfig,
 }
@@ -408,9 +432,11 @@ impl UNetMidBlock2DCrossAttn {
             )?;
             attn_resnets.push((attn, resnet))
         }
+        let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
         Ok(Self {
             resnet,
             attn_resnets,
+            span,
             config,
         })
     }
@@ -421,6 +447,7 @@ impl UNetMidBlock2DCrossAttn {
         temb: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = self.resnet.forward(xs, temb)?;
         for (attn, resnet) in self.attn_resnets.iter() {
             xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
@@ -458,6 +485,7 @@ impl Default for DownBlock2DConfig {
 pub struct DownBlock2D {
     resnets: Vec<ResnetBlock2D>,
     downsampler: Option<Downsample2D>,
+    span: tracing::Span,
     pub config: DownBlock2DConfig,
 }
@@ -495,14 +523,17 @@ impl DownBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "down2d");
         Ok(Self {
             resnets,
             downsampler,
+            span,
             config,
         })
     }
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         let mut output_states = vec![];
         for resnet in self.resnets.iter() {
@@ -547,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
 pub struct CrossAttnDownBlock2D {
     downblock: DownBlock2D,
     attentions: Vec<SpatialTransformer>,
+    span: tracing::Span,
     pub config: CrossAttnDownBlock2DConfig,
 }
@@ -585,9 +617,11 @@ impl CrossAttnDownBlock2D {
             )
         })
         .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
         Ok(Self {
             downblock,
             attentions,
+            span,
             config,
         })
     }
@@ -598,6 +632,7 @@ impl CrossAttnDownBlock2D {
         temb: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Vec<Tensor>)> {
+        let _enter = self.span.enter();
         let mut output_states = vec![];
         let mut xs = xs.clone();
         for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
@@ -644,6 +679,7 @@ impl Default for UpBlock2DConfig {
 pub struct UpBlock2D {
     pub resnets: Vec<ResnetBlock2D>,
     upsampler: Option<Upsample2D>,
+    span: tracing::Span,
     pub config: UpBlock2DConfig,
 }
@@ -687,9 +723,11 @@ impl UpBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "up2d");
         Ok(Self {
             resnets,
             upsampler,
+            span,
             config,
         })
     }
@@ -701,6 +739,7 @@ impl UpBlock2D {
         temb: Option<&Tensor>,
         upsample_size: Option<(usize, usize)>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (index, resnet) in self.resnets.iter().enumerate() {
             xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
@@ -739,6 +778,7 @@ impl Default for CrossAttnUpBlock2DConfig {
 pub struct CrossAttnUpBlock2D {
     pub upblock: UpBlock2D,
     pub attentions: Vec<SpatialTransformer>,
+    span: tracing::Span,
     pub config: CrossAttnUpBlock2DConfig,
 }
@@ -779,9 +819,11 @@ impl CrossAttnUpBlock2D {
             )
         })
         .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
         Ok(Self {
             upblock,
             attentions,
+            span,
             config,
         })
     }
@@ -794,6 +836,7 @@ impl CrossAttnUpBlock2D {
         upsample_size: Option<(usize, usize)>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (index, resnet) in self.upblock.resnets.iter().enumerate() {
             xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;


@@ -29,3 +29,29 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+    inner: candle_nn::Conv2d,
+    span: tracing::Span,
+}
+
+impl Conv2d {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
+pub fn conv2d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: candle_nn::Conv2dConfig,
+    vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+    let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+    let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+    Ok(Conv2d { inner, span })
+}
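This wrapper is the piece that makes individual conv2d calls visible: the modules above store this `Conv2d` newtype instead of `candle_nn::Conv2d`, so every convolution shows up as its own "conv2d" slice nested inside the module spans. The same idea generalizes to any layer; a hedged sketch of a generic version (the `Traced` type is illustrative, assumes candle_nn's `Module` trait, and is not part of the commit):

```rust
use candle::{Result, Tensor};
use candle_nn::Module;

// Illustrative generalization of the Conv2d wrapper: any Module can be
// wrapped with a span so its forward calls show up as named trace slices.
pub struct Traced<M: Module> {
    inner: M,
    span: tracing::Span,
}

impl<M: Module> Traced<M> {
    pub fn new(inner: M, name: &'static str) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "traced", op = name);
        Self { inner, span }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}
```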