ViT tracing. (#790)

Laurent Mazare
2023-09-09 17:26:39 +01:00
committed by GitHub
parent 74ad4deb42
commit 31936c08fe
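The diff below threads tracing spans through the TinyViT image encoder: each module creates a TRACE-level span when it is built and enters it in forward, so a subscriber can attribute time to the patch embedding, the MBConv blocks, the attention matmuls and softmax, and the neck. As a minimal sketch of how these spans might be collected, here is the kind of tracing-chrome subscriber setup used by candle's example binaries; the wiring below is illustrative and not part of this commit:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Record all spans into a trace-<timestamp>.json file that chrome://tracing
    // or Perfetto can open; the guard flushes the file when dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the TinyViT encoder and run a forward pass here; spans such as
    // "tiny-vit", "patch-embed", "conv2d-bn", "attention", "attn-matmul" and
    // "attn-sm" from this diff will then show up in the recorded trace.
}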

@@ -13,18 +13,21 @@ const IN_CHANNELS: usize = 3;
 struct Conv2dBN {
     c: candle_nn::Conv2d,
     bn: candle_nn::BatchNorm,
+    span: tracing::Span,
 }
 
 impl Conv2dBN {
     fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
         let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
         let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
-        Ok(Self { c, bn })
+        let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+        Ok(Self { c, bn, span })
     }
 }
 
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.c)?.apply(&self.bn)
     }
 }
@@ -33,6 +36,7 @@ impl Module for Conv2dBN {
 struct PatchEmbed {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl PatchEmbed {
@@ -44,12 +48,14 @@ impl PatchEmbed {
         };
         let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
         let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
-        Ok(Self { conv1, conv2 })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { conv1, conv2, span })
     }
 }
 
 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
     }
 }
@@ -59,6 +65,7 @@ struct MBConv {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
     conv3: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl MBConv {
@@ -72,16 +79,19 @@ impl MBConv {
         let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
         Ok(Self {
             conv1,
             conv2,
             conv3,
+            span,
         })
     }
 }
 
 impl Module for MBConv {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = xs
             .apply(&self.conv1)?
@@ -99,6 +109,7 @@ struct PatchMerging {
     conv2: Conv2dBN,
     conv3: Conv2dBN,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl PatchMerging {
@@ -118,17 +129,20 @@ impl PatchMerging {
         let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
         Ok(Self {
             conv1,
             conv2,
             conv3,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for PatchMerging {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = if xs.rank() == 3 {
             let (h, w) = self.input_resolution;
             let b = xs.dim(0)?;
@@ -150,6 +164,7 @@ impl Module for PatchMerging {
 struct ConvLayer {
     blocks: Vec<MBConv>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl ConvLayer {
@@ -174,12 +189,18 @@ impl ConvLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for ConvLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -194,21 +215,29 @@ impl Module for ConvLayer {
 #[derive(Debug)]
 struct Mlp {
     norm: candle_nn::LayerNorm,
-    fc1: candle_nn::Linear,
-    fc2: candle_nn::Linear,
+    fc1: crate::Linear,
+    fc2: crate::Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
     fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
         let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
-        let fc1 = candle_nn::linear(in_, hidden, vb.pp("fc1"))?;
-        let fc2 = candle_nn::linear(hidden, in_, vb.pp("fc2"))?;
-        Ok(Self { norm, fc1, fc2 })
+        let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
+        let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
+        Ok(Self {
+            norm,
+            fc1,
+            fc2,
+            span,
+        })
     }
 }
 
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.norm)?
             .apply(&self.fc1)?
             .gelu()?
@@ -219,14 +248,17 @@ impl Module for Mlp {
 #[derive(Debug)]
 struct Attention {
     norm: candle_nn::LayerNorm,
-    qkv: candle_nn::Linear,
-    proj: candle_nn::Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     ab: Tensor,
     key_dim: usize,
     num_heads: usize,
     d: usize,
     dh: usize,
     scale: f64,
+    span: tracing::Span,
+    span_matmul: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl Attention {
@@ -243,8 +275,8 @@ impl Attention {
         let nh_kd = key_dim * num_heads;
         let h = dh + nh_kd * 2;
         let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
-        let qkv = candle_nn::linear(dim, h, vb.pp("qkv"))?;
-        let proj = candle_nn::linear(dh, dim, vb.pp("proj"))?;
+        let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
+        let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
 
         let points = (0..resolution.0)
             .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@@ -265,6 +297,9 @@ impl Attention {
         attention_biases
             .index_select(&idxs, 1)?
             .reshape(((), points.len(), points.len()))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         Ok(Self {
             norm,
             qkv,
@@ -275,12 +310,16 @@ impl Attention {
             d,
             dh,
             scale: 1f64 / (key_dim as f64).sqrt(),
+            span,
+            span_matmul,
+            span_softmax,
         })
     }
 }
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, _) = xs.dims3()?;
         let xs = xs.apply(&self.norm)?;
         let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
@@ -296,11 +335,20 @@ impl Module for Attention {
             .narrow(D::Minus1, 2 * self.key_dim, self.d)?
             .permute((0, 2, 1, 3))?
             .contiguous()?;
-        let attn = (q.matmul(&k.t()?)? * self.scale)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            (q.matmul(&k.t()?)? * self.scale)?
+        };
         let attn = attn.broadcast_add(&self.ab)?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-        attn.matmul(&v)?
-            .transpose(1, 2)?
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            attn.matmul(&v)?
+        };
+        attn.transpose(1, 2)?
             .reshape((b, n, self.dh))?
             .apply(&self.proj)
     }
@@ -313,6 +361,7 @@ struct TinyViTBlock {
     mlp: Mlp,
     window_size: usize,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl TinyViTBlock {
@@ -339,18 +388,21 @@ impl TinyViTBlock {
             ..Default::default()
         };
         let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
         Ok(Self {
            attn,
            local_conv,
            mlp,
            window_size,
            input_resolution,
+           span,
        })
     }
 }
 
 impl Module for TinyViTBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (h, w) = self.input_resolution;
         let (b, l, c) = xs.dims3()?;
         let res_x = xs;
@@ -410,6 +462,7 @@ impl Module for TinyViTBlock {
 struct BasicLayer {
     blocks: Vec<TinyViTBlock>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl BasicLayer {
@@ -442,12 +495,18 @@ impl BasicLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for BasicLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -470,6 +529,8 @@ pub struct TinyViT {
     neck_ln1: crate::LayerNorm2d,
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
+    span: tracing::Span,
+    span_neck: tracing::Span,
 }
 
 impl TinyViT {
@@ -525,6 +586,8 @@ impl TinyViT {
         let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
         let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
+        let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
         Ok(Self {
             patch_embed,
             layer0,
@@ -533,18 +596,22 @@ impl TinyViT {
             neck_ln1,
             neck_conv2,
             neck_ln2,
+            span,
+            span_neck,
         })
     }
 }
 
 impl Module for TinyViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = self.layer0.forward(&xs)?;
         for layer in self.layers.iter() {
             xs = layer.forward(&xs)?
         }
         let (b, _, c) = xs.dims3()?;
+        let _enter = self.span_neck.enter();
         xs.reshape((b, 64, 64, c))?
             .permute((0, 3, 1, 2))?
             .apply(&self.neck_conv1)?