Track the conv2d operations in stable-diffusion. (#431)

* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
Author: Laurent Mazare
Date: 2023-08-13 16:58:26 +02:00
Committed by: GitHub
Parent: b1ff78f762
Commit: 9af438ac1b
7 changed files with 146 additions and 25 deletions
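
The pattern is the same in every layer touched below: construct a tracing::Span once when the module is built, then enter it at the top of forward so each call shows up as a named slice in the chrome trace. A minimal sketch of that pattern, with MyLayer as an illustrative stand-in rather than a type from this commit:

use candle::{Result, Tensor};

#[derive(Debug)]
struct MyLayer {
    span: tracing::Span,
}

impl MyLayer {
    fn new() -> Self {
        // Build the span once; the per-call cost is just enter/exit.
        let span = tracing::span!(tracing::Level::TRACE, "my-layer");
        Self { span }
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The guard opens the slice now and closes it when dropped.
        let _enter = self.span.enter();
        Ok(xs.clone()) // stand-in for the real computation
    }
}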

View File

@@ -310,12 +310,11 @@ impl<'a> Reduce<'a> {
             .iter()
             .map(|(u, _)| u)
             .product::<usize>();
-        let mut src_i = 0;
-        for dst_v in dst.iter_mut() {
+        for (dst_i, dst_v) in dst.iter_mut().enumerate() {
+            let src_i = dst_i * reduce_sz;
             for &s in src[src_i..src_i + reduce_sz].iter() {
                 *dst_v = f(*dst_v, s)
             }
-            src_i += reduce_sz
         }
         return Ok(dst);
     };
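
Aside from the tracing changes, this cpu-backend tweak makes src_i a pure function of dst_i instead of a mutable counter threaded through the loop, so every iteration is independent. Assuming src.len() == dst.len() * reduce_sz, as the surrounding code guarantees, the loop is equivalent to this chunked sketch:

// Each dst element reduces one contiguous reduce_sz-sized chunk of src.
for (dst_v, chunk) in dst.iter_mut().zip(src.chunks_exact(reduce_sz)) {
    for &s in chunk {
        *dst_v = f(*dst_v, s)
    }
}

Independent iterations are what would later allow this loop to be chunked or parallelized.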

View File

@@ -6,17 +6,20 @@ use candle_nn as nn;
 #[derive(Debug)]
 struct GeGlu {
     proj: nn::Linear,
+    span: tracing::Span,
 }
 
 impl GeGlu {
     fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
         let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "geglu");
+        Ok(Self { proj, span })
     }
 }
 
 impl GeGlu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
         &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
     }
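
For reference, GeGlu here is the GEGLU variant from Shazeer's "GLU Variants Improve Transformer" (2020): proj widens to dim_out * 2, chunk splits the result in half along the last dimension, and the output is the elementwise product of the hidden half with GELU applied to the gate half, which is exactly what the two lines of forward compute.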
@@ -27,6 +30,7 @@ impl GeGlu {
 struct FeedForward {
     project_in: GeGlu,
     linear: nn::Linear,
+    span: tracing::Span,
 }
 
 impl FeedForward {
@@ -40,12 +44,18 @@ impl FeedForward {
         let vs = vs.pp("net");
         let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
         let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
-        Ok(Self { project_in, linear })
+        let span = tracing::span!(tracing::Level::TRACE, "ff");
+        Ok(Self {
+            project_in,
+            linear,
+            span,
+        })
     }
 }
 
 impl FeedForward {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.project_in.forward(xs)?;
         self.linear.forward(&xs)
     }
@@ -60,6 +70,8 @@ struct CrossAttention {
     heads: usize,
     scale: f64,
     slice_size: Option<usize>,
+    span: tracing::Span,
+    span_attn: tracing::Span,
 }
 
 impl CrossAttention {
@@ -79,6 +91,8 @@ impl CrossAttention {
         let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
         let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa");
+        let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
         Ok(Self {
             to_q,
             to_k,
@@ -87,6 +101,8 @@ impl CrossAttention {
             heads,
             scale,
             slice_size,
+            span,
+            span_attn,
         })
     }
@@ -129,12 +145,14 @@ impl CrossAttention {
     }
 
     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_attn.enter();
         let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
         let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
         self.reshape_batch_dim_to_heads(&xs)
     }
 
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs);
         let key = self.to_k.forward(context)?;
@@ -165,6 +183,7 @@ struct BasicTransformerBlock {
     norm1: nn::LayerNorm,
     norm2: nn::LayerNorm,
     norm3: nn::LayerNorm,
+    span: tracing::Span,
 }
 
 impl BasicTransformerBlock {
@@ -196,6 +215,7 @@ impl BasicTransformerBlock {
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
         let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
         Ok(Self {
             attn1,
             ff,
@@ -203,10 +223,12 @@ impl BasicTransformerBlock {
             norm1,
             norm2,
             norm3,
+            span,
         })
     }
 
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
         let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
         self.ff.forward(&self.norm3.forward(&xs)?)? + xs
@@ -247,6 +269,7 @@ pub struct SpatialTransformer {
     proj_in: Proj,
     transformer_blocks: Vec<BasicTransformerBlock>,
     proj_out: Proj,
+    span: tracing::Span,
     pub config: SpatialTransformerConfig,
 }
@@ -295,16 +318,19 @@ impl SpatialTransformer {
                 vs.pp("proj_out"),
             )?)
         };
+        let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
         Ok(Self {
             norm,
             proj_in,
             transformer_blocks,
             proj_out,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (batch, _channel, height, weight) = xs.dims4()?;
         let residual = xs;
         let xs = self.norm.forward(xs)?;
@@ -376,6 +402,7 @@ pub struct AttentionBlock {
     proj_attn: nn::Linear,
     channels: usize,
     num_heads: usize,
+    span: tracing::Span,
     config: AttentionBlockConfig,
 }
@@ -389,6 +416,7 @@ impl AttentionBlock {
         let key = nn::linear(channels, channels, vs.pp("key"))?;
         let value = nn::linear(channels, channels, vs.pp("value"))?;
         let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
             query,
@@ -397,6 +425,7 @@ impl AttentionBlock {
             proj_attn,
             channels,
             num_heads,
+            span,
             config,
         })
     }
@@ -406,10 +435,9 @@ impl AttentionBlock {
         xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
             .transpose(1, 2)
     }
-}
 
-impl AttentionBlock {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self

View File

@@ -40,6 +40,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
     /// The height in pixels of the generated image.
     #[arg(long)]
     height: Option<usize>,
@@ -183,6 +187,9 @@ fn output_filename(
 }
 
 fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let Args {
         prompt,
         uncond_prompt,
@@ -198,8 +205,18 @@ fn run(args: Args) -> Result<()> {
         clip_weights,
         vae_weights,
         unet_weights,
+        tracing,
         ..
     } = args;
+
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
             stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
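
A usage note on the guard: ChromeLayerBuilder writes the trace-timestamp.json file when the guard is dropped at the end of run, and the file loads in chrome://tracing or Perfetto. Binding it to _guard rather than _ matters in Rust: let _ = ... would drop the value immediately and the trace would never cover the run.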

View File

@@ -5,6 +5,7 @@
 //!
 //! Deep Residual Learning for Image Recognition, K. He et al., 2015.
 //! https://arxiv.org/abs/1512.03385
+use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
@@ -45,11 +46,12 @@ impl Default for ResnetBlock2DConfig {
 #[derive(Debug)]
 pub struct ResnetBlock2D {
     norm1: nn::GroupNorm,
-    conv1: nn::Conv2d,
+    conv1: Conv2d,
     norm2: nn::GroupNorm,
-    conv2: nn::Conv2d,
+    conv2: Conv2d,
     time_emb_proj: Option<nn::Linear>,
-    conv_shortcut: Option<nn::Conv2d>,
+    conv_shortcut: Option<Conv2d>,
+    span: tracing::Span,
     config: ResnetBlock2DConfig,
 }
@@ -65,10 +67,10 @@ impl ResnetBlock2D {
             padding: 1,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
-        let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+        let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
         let groups_out = config.groups_out.unwrap_or(config.groups);
         let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
-        let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+        let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
         let use_in_shortcut = config
             .use_in_shortcut
             .unwrap_or(in_channels != out_channels);
@@ -77,7 +79,7 @@ impl ResnetBlock2D {
                 stride: 1,
                 padding: 0,
             };
-            Some(nn::conv2d(
+            Some(conv2d(
                 in_channels,
                 out_channels,
                 1,
@@ -95,18 +97,21 @@ impl ResnetBlock2D {
                 vs.pp("time_emb_proj"),
             )?),
         };
+        let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
         Ok(Self {
             norm1,
             conv1,
             norm2,
             conv2,
             time_emb_proj,
+            span,
             config,
             conv_shortcut,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut_xs = match &self.conv_shortcut {
             Some(conv_shortcut) => conv_shortcut.forward(xs)?,
             None => xs.clone(),

View File

@@ -5,6 +5,7 @@
 //! timestep and return a denoised version of the input.
 use crate::embeddings::{TimestepEmbedding, Timesteps};
 use crate::unet_2d_blocks::*;
+use crate::utils::{conv2d, Conv2d};
 use candle::{DType, Result, Tensor};
 use candle_nn as nn;
@@ -85,14 +86,15 @@ enum UNetUpBlock {
 
 #[derive(Debug)]
 pub struct UNet2DConditionModel {
-    conv_in: nn::Conv2d,
+    conv_in: Conv2d,
     time_proj: Timesteps,
     time_embedding: TimestepEmbedding,
     down_blocks: Vec<UNetDownBlock>,
     mid_block: UNetMidBlock2DCrossAttn,
     up_blocks: Vec<UNetUpBlock>,
     conv_norm_out: nn::GroupNorm,
-    conv_out: nn::Conv2d,
+    conv_out: Conv2d,
+    span: tracing::Span,
     config: UNet2DConditionModelConfig,
 }
@@ -112,7 +114,7 @@ impl UNet2DConditionModel {
             stride: 1,
             padding: 1,
         };
-        let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+        let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
         let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
 
         let time_embedding =
@@ -263,7 +265,8 @@ impl UNet2DConditionModel {
             config.norm_eps,
             vs.pp("conv_norm_out"),
         )?;
-        let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+        let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "unet2d");
         Ok(Self {
             conv_in,
             time_proj,
@@ -273,18 +276,18 @@ impl UNet2DConditionModel {
             up_blocks,
             conv_norm_out,
             conv_out,
+            span,
             config,
         })
     }
-}
 
-impl UNet2DConditionModel {
     pub fn forward(
         &self,
         xs: &Tensor,
         timestep: f64,
         encoder_hidden_states: &Tensor,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
     }

View File

@@ -5,13 +5,15 @@ use crate::attention::{
     AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
 };
 use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug)]
 struct Downsample2D {
-    conv: Option<nn::Conv2d>,
+    conv: Option<Conv2d>,
     padding: usize,
+    span: tracing::Span,
 }
 
 impl Downsample2D {
@@ -24,17 +26,23 @@ impl Downsample2D {
     ) -> Result<Self> {
         let conv = if use_conv {
            let config = nn::Conv2dConfig { stride: 2, padding };
-            let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+            let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
            Some(conv)
         } else {
             None
         };
-        Ok(Downsample2D { conv, padding })
+        let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
+        Ok(Self {
+            conv,
+            padding,
+            span,
+        })
     }
 }
 
 impl Downsample2D {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         match &self.conv {
             None => xs.avg_pool2d((2, 2), (2, 2)),
             Some(conv) => {
@@ -54,7 +62,8 @@ impl Downsample2D {
 // This does not support the conv-transpose mode.
 #[derive(Debug)]
 struct Upsample2D {
-    conv: nn::Conv2d,
+    conv: Conv2d,
+    span: tracing::Span,
 }
 
 impl Upsample2D {
@@ -63,13 +72,15 @@ impl Upsample2D {
             padding: 1,
             ..Default::default()
         };
-        let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
-        Ok(Self { conv })
+        let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
+        Ok(Self { conv, span })
     }
 }
 
 impl Upsample2D {
     fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = match size {
             None => {
                 let (_bsize, _channels, h, w) = xs.dims4()?;
@@ -108,6 +119,7 @@ impl Default for DownEncoderBlock2DConfig {
 pub struct DownEncoderBlock2D {
     resnets: Vec<ResnetBlock2D>,
     downsampler: Option<Downsample2D>,
+    span: tracing::Span,
     pub config: DownEncoderBlock2DConfig,
 }
@@ -147,9 +159,11 @@ impl DownEncoderBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
         Ok(Self {
             resnets,
             downsampler,
+            span,
             config,
         })
     }
@@ -157,6 +171,7 @@ impl DownEncoderBlock2D {
 
 impl DownEncoderBlock2D {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
             xs = resnet.forward(&xs, None)?
@@ -193,6 +208,7 @@ impl Default for UpDecoderBlock2DConfig {
 pub struct UpDecoderBlock2D {
     resnets: Vec<ResnetBlock2D>,
     upsampler: Option<Upsample2D>,
+    span: tracing::Span,
     pub config: UpDecoderBlock2DConfig,
 }
@@ -227,9 +243,11 @@ impl UpDecoderBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
         Ok(Self {
             resnets,
             upsampler,
+            span,
             config,
         })
     }
@@ -237,6 +255,7 @@ impl UpDecoderBlock2D {
 
 impl UpDecoderBlock2D {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
             xs = resnet.forward(&xs, None)?
@@ -274,6 +293,7 @@ impl Default for UNetMidBlock2DConfig {
 pub struct UNetMidBlock2D {
     resnet: ResnetBlock2D,
     attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+    span: tracing::Span,
     pub config: UNetMidBlock2DConfig,
 }
@@ -313,14 +333,17 @@ impl UNetMidBlock2D {
             )?;
             attn_resnets.push((attn, resnet))
         }
+        let span = tracing::span!(tracing::Level::TRACE, "mid2d");
         Ok(Self {
             resnet,
             attn_resnets,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = self.resnet.forward(xs, temb)?;
         for (attn, resnet) in self.attn_resnets.iter() {
             xs = resnet.forward(&attn.forward(&xs)?, temb)?
@@ -361,6 +384,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
 pub struct UNetMidBlock2DCrossAttn {
     resnet: ResnetBlock2D,
     attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+    span: tracing::Span,
     pub config: UNetMidBlock2DCrossAttnConfig,
 }
@@ -408,9 +432,11 @@ impl UNetMidBlock2DCrossAttn {
             )?;
             attn_resnets.push((attn, resnet))
         }
+        let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
         Ok(Self {
             resnet,
             attn_resnets,
+            span,
             config,
         })
     }
@@ -421,6 +447,7 @@ impl UNetMidBlock2DCrossAttn {
         temb: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = self.resnet.forward(xs, temb)?;
         for (attn, resnet) in self.attn_resnets.iter() {
             xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
@@ -458,6 +485,7 @@ impl Default for DownBlock2DConfig {
 pub struct DownBlock2D {
     resnets: Vec<ResnetBlock2D>,
     downsampler: Option<Downsample2D>,
+    span: tracing::Span,
     pub config: DownBlock2DConfig,
 }
@@ -495,14 +523,17 @@ impl DownBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "down2d");
         Ok(Self {
             resnets,
             downsampler,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         let mut output_states = vec![];
         for resnet in self.resnets.iter() {
@@ -547,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
 pub struct CrossAttnDownBlock2D {
     downblock: DownBlock2D,
     attentions: Vec<SpatialTransformer>,
+    span: tracing::Span,
     pub config: CrossAttnDownBlock2DConfig,
 }
@@ -585,9 +617,11 @@ impl CrossAttnDownBlock2D {
                 )
             })
             .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
         Ok(Self {
             downblock,
             attentions,
+            span,
             config,
         })
     }
@@ -598,6 +632,7 @@ impl CrossAttnDownBlock2D {
         temb: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Vec<Tensor>)> {
+        let _enter = self.span.enter();
         let mut output_states = vec![];
         let mut xs = xs.clone();
         for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
@@ -644,6 +679,7 @@ impl Default for UpBlock2DConfig {
 pub struct UpBlock2D {
     pub resnets: Vec<ResnetBlock2D>,
     upsampler: Option<Upsample2D>,
+    span: tracing::Span,
     pub config: UpBlock2DConfig,
 }
@@ -687,9 +723,11 @@ impl UpBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "up2d");
         Ok(Self {
             resnets,
             upsampler,
+            span,
             config,
         })
     }
@@ -701,6 +739,7 @@ impl UpBlock2D {
         temb: Option<&Tensor>,
         upsample_size: Option<(usize, usize)>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (index, resnet) in self.resnets.iter().enumerate() {
             xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
@@ -739,6 +778,7 @@ impl Default for CrossAttnUpBlock2DConfig {
 pub struct CrossAttnUpBlock2D {
     pub upblock: UpBlock2D,
     pub attentions: Vec<SpatialTransformer>,
+    span: tracing::Span,
     pub config: CrossAttnUpBlock2DConfig,
 }
@@ -779,9 +819,11 @@ impl CrossAttnUpBlock2D {
                 )
             })
             .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
         Ok(Self {
             upblock,
             attentions,
+            span,
             config,
         })
     }
@@ -794,6 +836,7 @@ impl CrossAttnUpBlock2D {
         upsample_size: Option<(usize, usize)>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (index, resnet) in self.upblock.resnets.iter().enumerate() {
             xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;

View File

@@ -29,3 +29,29 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+    inner: candle_nn::Conv2d,
+    span: tracing::Span,
+}
+
+impl Conv2d {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
+pub fn conv2d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: candle_nn::Conv2dConfig,
+    vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+    let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+    let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+    Ok(Conv2d { inner, span })
+}