mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion.
* Add more tracing to stable-diffusion.
* Also trace the resnet bits.
* Trace the attention blocks.
* Also trace the attention inner part.
* Small tweak.
This commit is contained in:
@ -5,13 +5,15 @@ use crate::attention::{
|
||||
AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
|
||||
};
|
||||
use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
|
||||
use crate::utils::{conv2d, Conv2d};
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Downsample2D {
|
||||
conv: Option<nn::Conv2d>,
|
||||
conv: Option<Conv2d>,
|
||||
padding: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Downsample2D {
|
||||
@ -24,17 +26,23 @@ impl Downsample2D {
|
||||
) -> Result<Self> {
|
||||
let conv = if use_conv {
|
||||
let config = nn::Conv2dConfig { stride: 2, padding };
|
||||
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Downsample2D { conv, padding })
|
||||
let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
|
||||
Ok(Self {
|
||||
conv,
|
||||
padding,
|
||||
span,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Downsample2D {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match &self.conv {
|
||||
None => xs.avg_pool2d((2, 2), (2, 2)),
|
||||
Some(conv) => {
|
||||
@ -54,7 +62,8 @@ impl Downsample2D {
|
||||
// This does not support the conv-transpose mode.
|
||||
#[derive(Debug)]
|
||||
struct Upsample2D {
|
||||
conv: nn::Conv2d,
|
||||
conv: Conv2d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Upsample2D {
|
||||
@ -63,13 +72,15 @@ impl Upsample2D {
|
||||
padding: 1,
|
||||
..Default::default()
|
||||
};
|
||||
let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
Ok(Self { conv })
|
||||
let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
|
||||
Ok(Self { conv, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Upsample2D {
|
||||
fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = match size {
|
||||
None => {
|
||||
let (_bsize, _channels, h, w) = xs.dims4()?;
|
||||
@ -108,6 +119,7 @@ impl Default for DownEncoderBlock2DConfig {
|
||||
pub struct DownEncoderBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
downsampler: Option<Downsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: DownEncoderBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -147,9 +159,11 @@ impl DownEncoderBlock2D {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
downsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -157,6 +171,7 @@ impl DownEncoderBlock2D {
|
||||
|
||||
impl DownEncoderBlock2D {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, None)?
|
||||
@ -193,6 +208,7 @@ impl Default for UpDecoderBlock2DConfig {
|
||||
pub struct UpDecoderBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
upsampler: Option<Upsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: UpDecoderBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -227,9 +243,11 @@ impl UpDecoderBlock2D {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
upsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -237,6 +255,7 @@ impl UpDecoderBlock2D {
|
||||
|
||||
impl UpDecoderBlock2D {
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for resnet in self.resnets.iter() {
|
||||
xs = resnet.forward(&xs, None)?
|
||||
@ -274,6 +293,7 @@ impl Default for UNetMidBlock2DConfig {
|
||||
pub struct UNetMidBlock2D {
|
||||
resnet: ResnetBlock2D,
|
||||
attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
|
||||
span: tracing::Span,
|
||||
pub config: UNetMidBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -313,14 +333,17 @@ impl UNetMidBlock2D {
|
||||
)?;
|
||||
attn_resnets.push((attn, resnet))
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "mid2d");
|
||||
Ok(Self {
|
||||
resnet,
|
||||
attn_resnets,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.resnet.forward(xs, temb)?;
|
||||
for (attn, resnet) in self.attn_resnets.iter() {
|
||||
xs = resnet.forward(&attn.forward(&xs)?, temb)?
|
||||
@ -361,6 +384,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
|
||||
pub struct UNetMidBlock2DCrossAttn {
|
||||
resnet: ResnetBlock2D,
|
||||
attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
|
||||
span: tracing::Span,
|
||||
pub config: UNetMidBlock2DCrossAttnConfig,
|
||||
}
|
||||
|
||||
@ -408,9 +432,11 @@ impl UNetMidBlock2DCrossAttn {
|
||||
)?;
|
||||
attn_resnets.push((attn, resnet))
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
|
||||
Ok(Self {
|
||||
resnet,
|
||||
attn_resnets,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -421,6 +447,7 @@ impl UNetMidBlock2DCrossAttn {
|
||||
temb: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.resnet.forward(xs, temb)?;
|
||||
for (attn, resnet) in self.attn_resnets.iter() {
|
||||
xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
|
||||
@ -458,6 +485,7 @@ impl Default for DownBlock2DConfig {
|
||||
pub struct DownBlock2D {
|
||||
resnets: Vec<ResnetBlock2D>,
|
||||
downsampler: Option<Downsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: DownBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -495,14 +523,17 @@ impl DownBlock2D {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "down2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
downsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
let mut output_states = vec![];
|
||||
for resnet in self.resnets.iter() {
|
||||
@ -547,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
|
||||
pub struct CrossAttnDownBlock2D {
|
||||
downblock: DownBlock2D,
|
||||
attentions: Vec<SpatialTransformer>,
|
||||
span: tracing::Span,
|
||||
pub config: CrossAttnDownBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -585,9 +617,11 @@ impl CrossAttnDownBlock2D {
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
|
||||
Ok(Self {
|
||||
downblock,
|
||||
attentions,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -598,6 +632,7 @@ impl CrossAttnDownBlock2D {
|
||||
temb: Option<&Tensor>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<(Tensor, Vec<Tensor>)> {
|
||||
let _enter = self.span.enter();
|
||||
let mut output_states = vec![];
|
||||
let mut xs = xs.clone();
|
||||
for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
|
||||
@ -644,6 +679,7 @@ impl Default for UpBlock2DConfig {
|
||||
pub struct UpBlock2D {
|
||||
pub resnets: Vec<ResnetBlock2D>,
|
||||
upsampler: Option<Upsample2D>,
|
||||
span: tracing::Span,
|
||||
pub config: UpBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -687,9 +723,11 @@ impl UpBlock2D {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "up2d");
|
||||
Ok(Self {
|
||||
resnets,
|
||||
upsampler,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -701,6 +739,7 @@ impl UpBlock2D {
|
||||
temb: Option<&Tensor>,
|
||||
upsample_size: Option<(usize, usize)>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for (index, resnet) in self.resnets.iter().enumerate() {
|
||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||
@ -739,6 +778,7 @@ impl Default for CrossAttnUpBlock2DConfig {
|
||||
pub struct CrossAttnUpBlock2D {
|
||||
pub upblock: UpBlock2D,
|
||||
pub attentions: Vec<SpatialTransformer>,
|
||||
span: tracing::Span,
|
||||
pub config: CrossAttnUpBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -779,9 +819,11 @@ impl CrossAttnUpBlock2D {
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
|
||||
Ok(Self {
|
||||
upblock,
|
||||
attentions,
|
||||
span,
|
||||
config,
|
||||
})
|
||||
}
|
||||
@ -794,6 +836,7 @@ impl CrossAttnUpBlock2D {
|
||||
upsample_size: Option<(usize, usize)>,
|
||||
encoder_hidden_states: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = xs.clone();
|
||||
for (index, resnet) in self.upblock.resnets.iter().enumerate() {
|
||||
xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
|
||||
|
Reference in New Issue
Block a user