mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion. * Add more tracing to stable-diffusion. * Also trace the resnet bits. * Trace the attention blocks. * Also trace the attention inner part. * Small tweak.
This commit is contained in:
@ -5,6 +5,7 @@
|
||||
//!
|
||||
//! Denoising Diffusion Implicit Models, K. He and al, 2015.
|
||||
//! https://arxiv.org/abs/1512.03385
|
||||
use crate::utils::{conv2d, Conv2d};
|
||||
use candle::{Result, Tensor, D};
|
||||
use candle_nn as nn;
|
||||
|
||||
@ -45,11 +46,12 @@ impl Default for ResnetBlock2DConfig {
|
||||
#[derive(Debug)]
|
||||
pub struct ResnetBlock2D {
|
||||
norm1: nn::GroupNorm,
|
||||
conv1: nn::Conv2d,
|
||||
conv1: Conv2d,
|
||||
norm2: nn::GroupNorm,
|
||||
conv2: nn::Conv2d,
|
||||
conv2: Conv2d,
|
||||
time_emb_proj: Option<nn::Linear>,
|
||||
conv_shortcut: Option<nn::Conv2d>,
|
||||
conv_shortcut: Option<Conv2d>,
|
||||
span: tracing::Span,
|
||||
config: ResnetBlock2DConfig,
|
||||
}
|
||||
|
||||
@ -65,10 +67,10 @@ impl ResnetBlock2D {
|
||||
padding: 1,
|
||||
};
|
||||
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
|
||||
let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
||||
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
|
||||
let groups_out = config.groups_out.unwrap_or(config.groups);
|
||||
let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
|
||||
let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
|
||||
let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
|
||||
let use_in_shortcut = config
|
||||
.use_in_shortcut
|
||||
.unwrap_or(in_channels != out_channels);
|
||||
@ -77,7 +79,7 @@ impl ResnetBlock2D {
|
||||
stride: 1,
|
||||
padding: 0,
|
||||
};
|
||||
Some(nn::conv2d(
|
||||
Some(conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
1,
|
||||
@ -95,18 +97,21 @@ impl ResnetBlock2D {
|
||||
vs.pp("time_emb_proj"),
|
||||
)?),
|
||||
};
|
||||
let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
|
||||
Ok(Self {
|
||||
norm1,
|
||||
conv1,
|
||||
norm2,
|
||||
conv2,
|
||||
time_emb_proj,
|
||||
span,
|
||||
config,
|
||||
conv_shortcut,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let shortcut_xs = match &self.conv_shortcut {
|
||||
Some(conv_shortcut) => conv_shortcut.forward(xs)?,
|
||||
None => xs.clone(),
|
||||
|
Reference in New Issue
Block a user