Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00.
Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion.
* Add more tracing to stable-diffusion.
* Also trace the resnet bits.
* Trace the attention blocks.
* Also trace the attention inner part.
* Small tweak.
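The diff below applies one pattern throughout: each module gains a `tracing::Span` field that is created once at construction time and entered at the top of its `forward` method. A minimal sketch of that pattern, assuming only the `tracing` crate (the `Layer` name is illustrative, not from this commit):

    #[derive(Debug)]
    struct Layer {
        span: tracing::Span,
    }

    impl Layer {
        fn new() -> Self {
            // The span is created once, up front.
            let span = tracing::span!(tracing::Level::TRACE, "layer");
            Self { span }
        }

        fn forward(&self) {
            // Entering the span records an enter/exit pair per call; with no
            // subscriber installed this is close to a no-op.
            let _enter = self.span.enter();
            // ... the actual computation goes here ...
        }
    }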
@@ -310,12 +310,11 @@ impl<'a> Reduce<'a> {
                 .iter()
                 .map(|(u, _)| u)
                 .product::<usize>();
-            let mut src_i = 0;
-            for dst_v in dst.iter_mut() {
+            for (dst_i, dst_v) in dst.iter_mut().enumerate() {
+                let src_i = dst_i * reduce_sz;
                 for &s in src[src_i..src_i + reduce_sz].iter() {
                     *dst_v = f(*dst_v, s)
                 }
-                src_i += reduce_sz
             }
             return Ok(dst);
         };
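(The hunk above also reworks the reduction loop: instead of threading a mutable `src_i` counter through the iterations, the source offset is now derived directly from the destination index as `dst_i * reduce_sz`, which is equivalent and keeps the loop body free of extra state.)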
@@ -6,17 +6,20 @@ use candle_nn as nn;
 #[derive(Debug)]
 struct GeGlu {
     proj: nn::Linear,
+    span: tracing::Span,
 }
 
 impl GeGlu {
     fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
         let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "geglu");
+        Ok(Self { proj, span })
     }
 }
 
 impl GeGlu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
         &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
     }
@@ -27,6 +30,7 @@ impl GeGlu {
 struct FeedForward {
     project_in: GeGlu,
     linear: nn::Linear,
+    span: tracing::Span,
 }
 
 impl FeedForward {
@@ -40,12 +44,18 @@ impl FeedForward {
         let vs = vs.pp("net");
         let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
         let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
-        Ok(Self { project_in, linear })
+        let span = tracing::span!(tracing::Level::TRACE, "ff");
+        Ok(Self {
+            project_in,
+            linear,
+            span,
+        })
     }
 }
 
 impl FeedForward {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.project_in.forward(xs)?;
         self.linear.forward(&xs)
     }
@@ -60,6 +70,8 @@ struct CrossAttention {
     heads: usize,
     scale: f64,
     slice_size: Option<usize>,
+    span: tracing::Span,
+    span_attn: tracing::Span,
 }
 
 impl CrossAttention {
@@ -79,6 +91,8 @@ impl CrossAttention {
         let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
         let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa");
+        let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
         Ok(Self {
             to_q,
             to_k,
@@ -87,6 +101,8 @@ impl CrossAttention {
             heads,
             scale,
             slice_size,
+            span,
+            span_attn,
         })
     }
 
@@ -129,12 +145,14 @@ impl CrossAttention {
     }
 
     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_attn.enter();
         let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
         let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
         self.reshape_batch_dim_to_heads(&xs)
     }
 
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs);
         let key = self.to_k.forward(context)?;
@@ -165,6 +183,7 @@ struct BasicTransformerBlock {
     norm1: nn::LayerNorm,
     norm2: nn::LayerNorm,
     norm3: nn::LayerNorm,
+    span: tracing::Span,
 }
 
 impl BasicTransformerBlock {
@@ -196,6 +215,7 @@ impl BasicTransformerBlock {
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
         let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
         Ok(Self {
             attn1,
             ff,
@@ -203,10 +223,12 @@ impl BasicTransformerBlock {
             norm1,
             norm2,
             norm3,
+            span,
         })
     }
 
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
         let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
         self.ff.forward(&self.norm3.forward(&xs)?)? + xs
@@ -247,6 +269,7 @@ pub struct SpatialTransformer {
     proj_in: Proj,
     transformer_blocks: Vec<BasicTransformerBlock>,
     proj_out: Proj,
+    span: tracing::Span,
     pub config: SpatialTransformerConfig,
 }
 
@@ -295,16 +318,19 @@ impl SpatialTransformer {
                 vs.pp("proj_out"),
             )?)
         };
+        let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
         Ok(Self {
             norm,
             proj_in,
             transformer_blocks,
             proj_out,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (batch, _channel, height, weight) = xs.dims4()?;
         let residual = xs;
         let xs = self.norm.forward(xs)?;
@@ -376,6 +402,7 @@ pub struct AttentionBlock {
     proj_attn: nn::Linear,
     channels: usize,
     num_heads: usize,
+    span: tracing::Span,
     config: AttentionBlockConfig,
 }
 
@@ -389,6 +416,7 @@ impl AttentionBlock {
         let key = nn::linear(channels, channels, vs.pp("key"))?;
         let value = nn::linear(channels, channels, vs.pp("value"))?;
         let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
             query,
@@ -397,6 +425,7 @@ impl AttentionBlock {
             proj_attn,
             channels,
             num_heads,
+            span,
             config,
         })
     }
@@ -406,10 +435,9 @@ impl AttentionBlock {
         xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
             .transpose(1, 2)
     }
-}
 
-impl AttentionBlock {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self
@@ -40,6 +40,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
     /// The height in pixels of the generated image.
     #[arg(long)]
     height: Option<usize>,
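As the doc comment on the new flag says, running the example with `--tracing` makes it emit a `trace-timestamp.json` file in the Chrome trace format; that file can be loaded into chrome://tracing or https://ui.perfetto.dev for a timeline view of the spans.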
@@ -183,6 +187,9 @@ fn output_filename(
 }
 
 fn run(args: Args) -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let Args {
         prompt,
         uncond_prompt,
@@ -198,8 +205,18 @@ fn run(args: Args) -> Result<()> {
         clip_weights,
         vae_weights,
         unet_weights,
+        tracing,
         ..
     } = args;
+
+    let _guard = if tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let sd_config = match sd_version {
         StableDiffusionVersion::V1_5 => {
             stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
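Note that the value returned by `ChromeLayerBuilder::new().build()` alongside the layer is a flush guard, bound to `_guard` rather than `_`: it has to stay alive for the duration of the run, since the trace file is flushed when the guard is dropped.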
@@ -5,6 +5,7 @@
 //!
 //! Denoising Diffusion Implicit Models, K. He and al, 2015.
 //! https://arxiv.org/abs/1512.03385
+use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
 
@@ -45,11 +46,12 @@ impl Default for ResnetBlock2DConfig {
 #[derive(Debug)]
 pub struct ResnetBlock2D {
     norm1: nn::GroupNorm,
-    conv1: nn::Conv2d,
+    conv1: Conv2d,
     norm2: nn::GroupNorm,
-    conv2: nn::Conv2d,
+    conv2: Conv2d,
     time_emb_proj: Option<nn::Linear>,
-    conv_shortcut: Option<nn::Conv2d>,
+    conv_shortcut: Option<Conv2d>,
+    span: tracing::Span,
     config: ResnetBlock2DConfig,
 }
 
@@ -65,10 +67,10 @@ impl ResnetBlock2D {
             padding: 1,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
-        let conv1 = nn::conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
+        let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
         let groups_out = config.groups_out.unwrap_or(config.groups);
         let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
-        let conv2 = nn::conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
+        let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg, vs.pp("conv2"))?;
         let use_in_shortcut = config
             .use_in_shortcut
             .unwrap_or(in_channels != out_channels);
@@ -77,7 +79,7 @@ impl ResnetBlock2D {
             stride: 1,
             padding: 0,
         };
-        Some(nn::conv2d(
+        Some(conv2d(
             in_channels,
             out_channels,
             1,
@@ -95,18 +97,21 @@ impl ResnetBlock2D {
                 vs.pp("time_emb_proj"),
             )?),
         };
+        let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
         Ok(Self {
             norm1,
             conv1,
             norm2,
             conv2,
             time_emb_proj,
+            span,
             config,
             conv_shortcut,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut_xs = match &self.conv_shortcut {
             Some(conv_shortcut) => conv_shortcut.forward(xs)?,
             None => xs.clone(),
@@ -5,6 +5,7 @@
 //! timestep and return a denoised version of the input.
 use crate::embeddings::{TimestepEmbedding, Timesteps};
 use crate::unet_2d_blocks::*;
+use crate::utils::{conv2d, Conv2d};
 use candle::{DType, Result, Tensor};
 use candle_nn as nn;
 
@@ -85,14 +86,15 @@ enum UNetUpBlock {
 
 #[derive(Debug)]
 pub struct UNet2DConditionModel {
-    conv_in: nn::Conv2d,
+    conv_in: Conv2d,
     time_proj: Timesteps,
     time_embedding: TimestepEmbedding,
     down_blocks: Vec<UNetDownBlock>,
     mid_block: UNetMidBlock2DCrossAttn,
     up_blocks: Vec<UNetUpBlock>,
     conv_norm_out: nn::GroupNorm,
-    conv_out: nn::Conv2d,
+    conv_out: Conv2d,
+    span: tracing::Span,
     config: UNet2DConditionModelConfig,
 }
 
@@ -112,7 +114,7 @@ impl UNet2DConditionModel {
             stride: 1,
             padding: 1,
         };
-        let conv_in = nn::conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
+        let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
 
         let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
         let time_embedding =
@@ -263,7 +265,8 @@ impl UNet2DConditionModel {
             config.norm_eps,
             vs.pp("conv_norm_out"),
         )?;
-        let conv_out = nn::conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+        let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "unet2d");
         Ok(Self {
             conv_in,
             time_proj,
@@ -273,18 +276,18 @@ impl UNet2DConditionModel {
             up_blocks,
             conv_norm_out,
             conv_out,
+            span,
             config,
         })
     }
-}
 
-impl UNet2DConditionModel {
     pub fn forward(
         &self,
         xs: &Tensor,
         timestep: f64,
         encoder_hidden_states: &Tensor,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None)
     }
 
@@ -5,13 +5,15 @@ use crate::attention::{
     AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
 };
 use crate::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
+use crate::utils::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
 
 #[derive(Debug)]
 struct Downsample2D {
-    conv: Option<nn::Conv2d>,
+    conv: Option<Conv2d>,
     padding: usize,
+    span: tracing::Span,
 }
 
 impl Downsample2D {
@@ -24,17 +26,23 @@ impl Downsample2D {
     ) -> Result<Self> {
         let conv = if use_conv {
             let config = nn::Conv2dConfig { stride: 2, padding };
-            let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+            let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
             Some(conv)
         } else {
             None
         };
-        Ok(Downsample2D { conv, padding })
+        let span = tracing::span!(tracing::Level::TRACE, "downsample2d");
+        Ok(Self {
+            conv,
+            padding,
+            span,
+        })
     }
 }
 
 impl Downsample2D {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         match &self.conv {
             None => xs.avg_pool2d((2, 2), (2, 2)),
             Some(conv) => {
@@ -54,7 +62,8 @@ impl Downsample2D {
 // This does not support the conv-transpose mode.
 #[derive(Debug)]
 struct Upsample2D {
-    conv: nn::Conv2d,
+    conv: Conv2d,
+    span: tracing::Span,
 }
 
 impl Upsample2D {
@@ -63,13 +72,15 @@ impl Upsample2D {
             padding: 1,
             ..Default::default()
         };
-        let conv = nn::conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
-        Ok(Self { conv })
+        let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "upsample2d");
+        Ok(Self { conv, span })
     }
 }
 
 impl Upsample2D {
     fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = match size {
             None => {
                 let (_bsize, _channels, h, w) = xs.dims4()?;
@@ -108,6 +119,7 @@ impl Default for DownEncoderBlock2DConfig {
 pub struct DownEncoderBlock2D {
     resnets: Vec<ResnetBlock2D>,
     downsampler: Option<Downsample2D>,
+    span: tracing::Span,
     pub config: DownEncoderBlock2DConfig,
 }
 
@@ -147,9 +159,11 @@ impl DownEncoderBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "down-enc2d");
         Ok(Self {
             resnets,
             downsampler,
+            span,
             config,
         })
     }
@@ -157,6 +171,7 @@ impl DownEncoderBlock2D {
 
 impl DownEncoderBlock2D {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
             xs = resnet.forward(&xs, None)?
@@ -193,6 +208,7 @@ impl Default for UpDecoderBlock2DConfig {
 pub struct UpDecoderBlock2D {
     resnets: Vec<ResnetBlock2D>,
     upsampler: Option<Upsample2D>,
+    span: tracing::Span,
     pub config: UpDecoderBlock2DConfig,
 }
 
@@ -227,9 +243,11 @@ impl UpDecoderBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "up-dec2d");
         Ok(Self {
             resnets,
             upsampler,
+            span,
             config,
         })
     }
@@ -237,6 +255,7 @@ impl UpDecoderBlock2D {
 
 impl UpDecoderBlock2D {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for resnet in self.resnets.iter() {
             xs = resnet.forward(&xs, None)?
@@ -274,6 +293,7 @@ impl Default for UNetMidBlock2DConfig {
 pub struct UNetMidBlock2D {
     resnet: ResnetBlock2D,
     attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
+    span: tracing::Span,
     pub config: UNetMidBlock2DConfig,
 }
 
@@ -313,14 +333,17 @@ impl UNetMidBlock2D {
             )?;
             attn_resnets.push((attn, resnet))
         }
+        let span = tracing::span!(tracing::Level::TRACE, "mid2d");
         Ok(Self {
             resnet,
             attn_resnets,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = self.resnet.forward(xs, temb)?;
         for (attn, resnet) in self.attn_resnets.iter() {
             xs = resnet.forward(&attn.forward(&xs)?, temb)?
@@ -361,6 +384,7 @@ impl Default for UNetMidBlock2DCrossAttnConfig {
 pub struct UNetMidBlock2DCrossAttn {
     resnet: ResnetBlock2D,
     attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
+    span: tracing::Span,
     pub config: UNetMidBlock2DCrossAttnConfig,
 }
 
@@ -408,9 +432,11 @@ impl UNetMidBlock2DCrossAttn {
             )?;
             attn_resnets.push((attn, resnet))
         }
+        let span = tracing::span!(tracing::Level::TRACE, "xa-mid2d");
         Ok(Self {
             resnet,
             attn_resnets,
+            span,
             config,
         })
     }
@@ -421,6 +447,7 @@ impl UNetMidBlock2DCrossAttn {
         temb: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = self.resnet.forward(xs, temb)?;
         for (attn, resnet) in self.attn_resnets.iter() {
             xs = resnet.forward(&attn.forward(&xs, encoder_hidden_states)?, temb)?
@@ -458,6 +485,7 @@ impl Default for DownBlock2DConfig {
 pub struct DownBlock2D {
     resnets: Vec<ResnetBlock2D>,
     downsampler: Option<Downsample2D>,
+    span: tracing::Span,
     pub config: DownBlock2DConfig,
 }
 
@@ -495,14 +523,17 @@ impl DownBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "down2d");
         Ok(Self {
             resnets,
             downsampler,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         let mut output_states = vec![];
         for resnet in self.resnets.iter() {
@@ -547,6 +578,7 @@ impl Default for CrossAttnDownBlock2DConfig {
 pub struct CrossAttnDownBlock2D {
     downblock: DownBlock2D,
     attentions: Vec<SpatialTransformer>,
+    span: tracing::Span,
     pub config: CrossAttnDownBlock2DConfig,
 }
 
@@ -585,9 +617,11 @@ impl CrossAttnDownBlock2D {
                 )
             })
             .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa-down2d");
         Ok(Self {
             downblock,
             attentions,
+            span,
             config,
         })
     }
@@ -598,6 +632,7 @@ impl CrossAttnDownBlock2D {
         temb: Option<&Tensor>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<(Tensor, Vec<Tensor>)> {
+        let _enter = self.span.enter();
         let mut output_states = vec![];
         let mut xs = xs.clone();
         for (resnet, attn) in self.downblock.resnets.iter().zip(self.attentions.iter()) {
@@ -644,6 +679,7 @@ impl Default for UpBlock2DConfig {
 pub struct UpBlock2D {
     pub resnets: Vec<ResnetBlock2D>,
     upsampler: Option<Upsample2D>,
+    span: tracing::Span,
     pub config: UpBlock2DConfig,
 }
 
@@ -687,9 +723,11 @@ impl UpBlock2D {
         } else {
             None
         };
+        let span = tracing::span!(tracing::Level::TRACE, "up2d");
         Ok(Self {
             resnets,
             upsampler,
+            span,
             config,
         })
     }
@@ -701,6 +739,7 @@ impl UpBlock2D {
         temb: Option<&Tensor>,
         upsample_size: Option<(usize, usize)>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (index, resnet) in self.resnets.iter().enumerate() {
             xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
@@ -739,6 +778,7 @@ impl Default for CrossAttnUpBlock2DConfig {
 pub struct CrossAttnUpBlock2D {
     pub upblock: UpBlock2D,
     pub attentions: Vec<SpatialTransformer>,
+    span: tracing::Span,
     pub config: CrossAttnUpBlock2DConfig,
 }
 
@@ -779,9 +819,11 @@ impl CrossAttnUpBlock2D {
                 )
             })
             .collect::<Result<Vec<_>>>()?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa-up2d");
         Ok(Self {
             upblock,
             attentions,
+            span,
             config,
         })
     }
@@ -794,6 +836,7 @@ impl CrossAttnUpBlock2D {
         upsample_size: Option<(usize, usize)>,
         encoder_hidden_states: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for (index, resnet) in self.upblock.resnets.iter().enumerate() {
             xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
@@ -29,3 +29,29 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
     image.save(p).map_err(candle::Error::wrap)?;
     Ok(())
 }
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+    inner: candle_nn::Conv2d,
+    span: tracing::Span,
+}
+
+impl Conv2d {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
+pub fn conv2d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: candle_nn::Conv2dConfig,
+    vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+    let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+    let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+    Ok(Conv2d { inner, span })
+}
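The wrapper keeps call sites unchanged apart from dropping the `nn::` prefix: `conv2d(...)` takes the same arguments as `candle_nn::conv2d` but returns the traced `Conv2d`, whose `forward` enters the "conv2d" span around the inner op. A minimal usage sketch (the channel sizes and variable-builder path here are illustrative, not from this commit):

    // `vs` is a candle_nn::VarBuilder; the numbers are made up for illustration.
    let conv = conv2d(320, 320, 3, Default::default(), vs.pp("conv"))?;
    let ys = conv.forward(&xs)?; // shows up as a "conv2d" span in the trace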