Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Track the conv2d operations in stable-diffusion. (#431)
* Track the conv2d operations in stable-diffusion.
* Add more tracing to stable-diffusion.
* Also trace the resnet bits.
* Trace the attention blocks.
* Also trace the attention inner part.
* Small tweak.
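The spans added in the diff below are only recorded when a tracing subscriber is installed before the pipeline runs. As a rough sketch of how they might be consumed (assuming the tracing-chrome and tracing-subscriber crates as dependencies; the exact setup in the stable-diffusion example may differ), one could write:

// Sketch only: install a Chrome-trace subscriber so that the TRACE-level spans
// defined in this diff ("geglu", "ff", "xa", "xa-attn", "basic-transformer",
// "spatial-transformer", "attn-block") are recorded as timed scopes.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run; dropping it at the end of main
    // flushes the trace JSON file, which can be inspected in chrome://tracing
    // or Perfetto.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build and run the stable-diffusion pipeline here ...
}

Candle's examples typically gate a setup like this behind a --tracing command-line flag.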
@@ -6,17 +6,20 @@ use candle_nn as nn;
 #[derive(Debug)]
 struct GeGlu {
     proj: nn::Linear,
+    span: tracing::Span,
 }
 
 impl GeGlu {
     fn new(vs: nn::VarBuilder, dim_in: usize, dim_out: usize) -> Result<Self> {
         let proj = nn::linear(dim_in, dim_out * 2, vs.pp("proj"))?;
-        Ok(Self { proj })
+        let span = tracing::span!(tracing::Level::TRACE, "geglu");
+        Ok(Self { proj, span })
     }
 }
 
 impl GeGlu {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
         &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()?
     }
@@ -27,6 +30,7 @@ impl GeGlu {
 struct FeedForward {
     project_in: GeGlu,
     linear: nn::Linear,
+    span: tracing::Span,
 }
 
 impl FeedForward {
@@ -40,12 +44,18 @@ impl FeedForward {
         let vs = vs.pp("net");
         let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?;
         let linear = nn::linear(inner_dim, dim_out, vs.pp("2"))?;
-        Ok(Self { project_in, linear })
+        let span = tracing::span!(tracing::Level::TRACE, "ff");
+        Ok(Self {
+            project_in,
+            linear,
+            span,
+        })
     }
 }
 
 impl FeedForward {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.project_in.forward(xs)?;
         self.linear.forward(&xs)
     }
@@ -60,6 +70,8 @@ struct CrossAttention {
     heads: usize,
     scale: f64,
     slice_size: Option<usize>,
+    span: tracing::Span,
+    span_attn: tracing::Span,
 }
 
 impl CrossAttention {
@@ -79,6 +91,8 @@ impl CrossAttention {
         let to_k = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_k"))?;
         let to_v = nn::linear_no_bias(context_dim, inner_dim, vs.pp("to_v"))?;
         let to_out = nn::linear(inner_dim, query_dim, vs.pp("to_out.0"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "xa");
+        let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn");
         Ok(Self {
             to_q,
             to_k,
@@ -87,6 +101,8 @@ impl CrossAttention {
             heads,
             scale,
             slice_size,
+            span,
+            span_attn,
         })
     }
 
@@ -129,12 +145,14 @@ impl CrossAttention {
     }
 
     fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<Tensor> {
+        let _enter = self.span_attn.enter();
         let xs = query.matmul(&(key.transpose(D::Minus1, D::Minus2)? * self.scale)?)?;
         let xs = nn::ops::softmax(&xs, D::Minus1)?.matmul(value)?;
         self.reshape_batch_dim_to_heads(&xs)
     }
 
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs);
         let key = self.to_k.forward(context)?;
@@ -165,6 +183,7 @@ struct BasicTransformerBlock {
     norm1: nn::LayerNorm,
     norm2: nn::LayerNorm,
     norm3: nn::LayerNorm,
+    span: tracing::Span,
 }
 
 impl BasicTransformerBlock {
@@ -196,6 +215,7 @@ impl BasicTransformerBlock {
         let norm1 = nn::layer_norm(dim, 1e-5, vs.pp("norm1"))?;
         let norm2 = nn::layer_norm(dim, 1e-5, vs.pp("norm2"))?;
         let norm3 = nn::layer_norm(dim, 1e-5, vs.pp("norm3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "basic-transformer");
         Ok(Self {
             attn1,
             ff,
@@ -203,10 +223,12 @@ impl BasicTransformerBlock {
             norm1,
             norm2,
             norm3,
+            span,
         })
     }
 
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = (self.attn1.forward(&self.norm1.forward(xs)?, None)? + xs)?;
         let xs = (self.attn2.forward(&self.norm2.forward(&xs)?, context)? + xs)?;
         self.ff.forward(&self.norm3.forward(&xs)?)? + xs
@@ -247,6 +269,7 @@ pub struct SpatialTransformer {
     proj_in: Proj,
     transformer_blocks: Vec<BasicTransformerBlock>,
     proj_out: Proj,
+    span: tracing::Span,
     pub config: SpatialTransformerConfig,
 }
 
@@ -295,16 +318,19 @@ impl SpatialTransformer {
                 vs.pp("proj_out"),
             )?)
         };
+        let span = tracing::span!(tracing::Level::TRACE, "spatial-transformer");
         Ok(Self {
             norm,
             proj_in,
             transformer_blocks,
             proj_out,
+            span,
             config,
         })
     }
 
     pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (batch, _channel, height, weight) = xs.dims4()?;
         let residual = xs;
         let xs = self.norm.forward(xs)?;
@@ -376,6 +402,7 @@ pub struct AttentionBlock {
     proj_attn: nn::Linear,
     channels: usize,
     num_heads: usize,
+    span: tracing::Span,
     config: AttentionBlockConfig,
 }
 
@@ -389,6 +416,7 @@ impl AttentionBlock {
         let key = nn::linear(channels, channels, vs.pp("key"))?;
         let value = nn::linear(channels, channels, vs.pp("value"))?;
         let proj_attn = nn::linear(channels, channels, vs.pp("proj_attn"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attn-block");
         Ok(Self {
             group_norm,
             query,
@@ -397,6 +425,7 @@ impl AttentionBlock {
             proj_attn,
             channels,
             num_heads,
+            span,
             config,
         })
     }
@@ -406,10 +435,9 @@ impl AttentionBlock {
         xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
             .transpose(1, 2)
     }
-}
 
-impl AttentionBlock {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let (batch, channel, height, width) = xs.dims4()?;
         let xs = self