ViT tracing. (#790)

This commit is contained in:
Laurent Mazare
2023-09-09 17:26:39 +01:00
committed by GitHub
parent 74ad4deb42
commit 31936c08fe

View File

@ -13,18 +13,21 @@ const IN_CHANNELS: usize = 3;
struct Conv2dBN {
c: candle_nn::Conv2d,
bn: candle_nn::BatchNorm,
span: tracing::Span,
}
impl Conv2dBN {
fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
Ok(Self { c, bn })
let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
Ok(Self { c, bn, span })
}
}
impl Module for Conv2dBN {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.c)?.apply(&self.bn)
}
}
@ -33,6 +36,7 @@ impl Module for Conv2dBN {
struct PatchEmbed {
conv1: Conv2dBN,
conv2: Conv2dBN,
span: tracing::Span,
}
impl PatchEmbed {
@ -44,12 +48,14 @@ impl PatchEmbed {
};
let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
Ok(Self { conv1, conv2 })
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
Ok(Self { conv1, conv2, span })
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
}
}
@ -59,6 +65,7 @@ struct MBConv {
conv1: Conv2dBN,
conv2: Conv2dBN,
conv3: Conv2dBN,
span: tracing::Span,
}
impl MBConv {
@ -72,16 +79,19 @@ impl MBConv {
let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
Ok(Self {
conv1,
conv2,
conv3,
span,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let shortcut = xs;
let xs = xs
.apply(&self.conv1)?
@ -99,6 +109,7 @@ struct PatchMerging {
conv2: Conv2dBN,
conv3: Conv2dBN,
input_resolution: (usize, usize),
span: tracing::Span,
}
impl PatchMerging {
@ -118,17 +129,20 @@ impl PatchMerging {
let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
Ok(Self {
conv1,
conv2,
conv3,
input_resolution,
span,
})
}
}
impl Module for PatchMerging {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = if xs.rank() == 3 {
let (h, w) = self.input_resolution;
let b = xs.dim(0)?;
@ -150,6 +164,7 @@ impl Module for PatchMerging {
struct ConvLayer {
blocks: Vec<MBConv>,
downsample: Option<PatchMerging>,
span: tracing::Span,
}
impl ConvLayer {
@ -174,12 +189,18 @@ impl ConvLayer {
} else {
None
};
Ok(Self { blocks, downsample })
let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
Ok(Self {
blocks,
downsample,
span,
})
}
}
impl Module for ConvLayer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for block in self.blocks.iter() {
xs = block.forward(&xs)?
@ -194,21 +215,29 @@ impl Module for ConvLayer {
#[derive(Debug)]
struct Mlp {
norm: candle_nn::LayerNorm,
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
fc1: crate::Linear,
fc2: crate::Linear,
span: tracing::Span,
}
impl Mlp {
fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
let fc1 = candle_nn::linear(in_, hidden, vb.pp("fc1"))?;
let fc2 = candle_nn::linear(hidden, in_, vb.pp("fc2"))?;
Ok(Self { norm, fc1, fc2 })
let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
Ok(Self {
norm,
fc1,
fc2,
span,
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.norm)?
.apply(&self.fc1)?
.gelu()?
@ -219,14 +248,17 @@ impl Module for Mlp {
#[derive(Debug)]
struct Attention {
norm: candle_nn::LayerNorm,
qkv: candle_nn::Linear,
proj: candle_nn::Linear,
qkv: crate::Linear,
proj: crate::Linear,
ab: Tensor,
key_dim: usize,
num_heads: usize,
d: usize,
dh: usize,
scale: f64,
span: tracing::Span,
span_matmul: tracing::Span,
span_softmax: tracing::Span,
}
impl Attention {
@ -243,8 +275,8 @@ impl Attention {
let nh_kd = key_dim * num_heads;
let h = dh + nh_kd * 2;
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
let qkv = candle_nn::linear(dim, h, vb.pp("qkv"))?;
let proj = candle_nn::linear(dh, dim, vb.pp("proj"))?;
let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
let points = (0..resolution.0)
.flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@ -265,6 +297,9 @@ impl Attention {
attention_biases
.index_select(&idxs, 1)?
.reshape(((), points.len(), points.len()))?;
let span = tracing::span!(tracing::Level::TRACE, "attention");
let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
Ok(Self {
norm,
qkv,
@ -275,12 +310,16 @@ impl Attention {
d,
dh,
scale: 1f64 / (key_dim as f64).sqrt(),
span,
span_matmul,
span_softmax,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, n, _) = xs.dims3()?;
let xs = xs.apply(&self.norm)?;
let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
@ -296,11 +335,20 @@ impl Module for Attention {
.narrow(D::Minus1, 2 * self.key_dim, self.d)?
.permute((0, 2, 1, 3))?
.contiguous()?;
let attn = (q.matmul(&k.t()?)? * self.scale)?;
let attn = {
let _enter = self.span_matmul.enter();
(q.matmul(&k.t()?)? * self.scale)?
};
let attn = attn.broadcast_add(&self.ab)?;
let attn = candle_nn::ops::softmax_last_dim(&attn)?;
attn.matmul(&v)?
.transpose(1, 2)?
let attn = {
let _enter = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attn)?
};
let attn = {
let _enter = self.span_matmul.enter();
attn.matmul(&v)?
};
attn.transpose(1, 2)?
.reshape((b, n, self.dh))?
.apply(&self.proj)
}
@ -313,6 +361,7 @@ struct TinyViTBlock {
mlp: Mlp,
window_size: usize,
input_resolution: (usize, usize),
span: tracing::Span,
}
impl TinyViTBlock {
@ -339,18 +388,21 @@ impl TinyViTBlock {
..Default::default()
};
let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
let span = tracing::span!(tracing::Level::TRACE, "attention");
Ok(Self {
attn,
local_conv,
mlp,
window_size,
input_resolution,
span,
})
}
}
impl Module for TinyViTBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (h, w) = self.input_resolution;
let (b, l, c) = xs.dims3()?;
let res_x = xs;
@ -410,6 +462,7 @@ impl Module for TinyViTBlock {
struct BasicLayer {
blocks: Vec<TinyViTBlock>,
downsample: Option<PatchMerging>,
span: tracing::Span,
}
impl BasicLayer {
@ -442,12 +495,18 @@ impl BasicLayer {
} else {
None
};
Ok(Self { blocks, downsample })
let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
Ok(Self {
blocks,
downsample,
span,
})
}
}
impl Module for BasicLayer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = xs.clone();
for block in self.blocks.iter() {
xs = block.forward(&xs)?
@ -470,6 +529,8 @@ pub struct TinyViT {
neck_ln1: crate::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
span: tracing::Span,
span_neck: tracing::Span,
}
impl TinyViT {
@ -525,6 +586,8 @@ impl TinyViT {
let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
Ok(Self {
patch_embed,
layer0,
@ -533,18 +596,22 @@ impl TinyViT {
neck_ln1,
neck_conv2,
neck_ln2,
span,
span_neck,
})
}
}
impl Module for TinyViT {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.patch_embed.forward(xs)?;
let mut xs = self.layer0.forward(&xs)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
let (b, _, c) = xs.dims3()?;
let _enter = self.span_neck.enter();
xs.reshape((b, 64, 64, c))?
.permute((0, 3, 1, 2))?
.apply(&self.neck_conv1)?