Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
ViT tracing. (#790)
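This change threads tracing spans through the TinyViT image encoder used by the segment-anything example: each module gains a `tracing::Span` field, created once in its constructor and entered at the top of `forward`, so per-module timings show up once a tracing subscriber is installed. The diff below repeats that idiom for every block. A generic sketch of the same pattern, not how the commit structures it (the commit stores the span inline in each struct) and assuming the `candle` and `tracing` crates:

    use candle::{Module, Result, Tensor};

    // Hypothetical generic wrapper illustrating the span-per-module idiom.
    struct Traced<M: Module> {
        inner: M,
        span: tracing::Span,
    }

    impl<M: Module> Traced<M> {
        fn new(inner: M, name: &'static str) -> Self {
            // The span is created once; entering it on each forward call is cheap.
            let span = tracing::span!(tracing::Level::TRACE, "module", name);
            Self { inner, span }
        }
    }

    impl<M: Module> Module for Traced<M> {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            // The guard keeps the span entered for the duration of this call.
            let _enter = self.span.enter();
            self.inner.forward(xs)
        }
    }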
@@ -13,18 +13,21 @@ const IN_CHANNELS: usize = 3;
 struct Conv2dBN {
     c: candle_nn::Conv2d,
     bn: candle_nn::BatchNorm,
+    span: tracing::Span,
 }
 
 impl Conv2dBN {
     fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
         let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
         let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
-        Ok(Self { c, bn })
+        let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+        Ok(Self { c, bn, span })
     }
 }
 
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.c)?.apply(&self.bn)
     }
 }
@@ -33,6 +36,7 @@ impl Module for Conv2dBN {
 struct PatchEmbed {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl PatchEmbed {
@@ -44,12 +48,14 @@ impl PatchEmbed {
         };
         let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
         let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
-        Ok(Self { conv1, conv2 })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { conv1, conv2, span })
     }
 }
 
 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
     }
 }
@@ -59,6 +65,7 @@ struct MBConv {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
     conv3: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl MBConv {
@@ -72,16 +79,19 @@ impl MBConv {
         let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
         Ok(Self {
             conv1,
             conv2,
             conv3,
+            span,
         })
     }
 }
 
 impl Module for MBConv {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = xs
             .apply(&self.conv1)?
@@ -99,6 +109,7 @@ struct PatchMerging {
     conv2: Conv2dBN,
     conv3: Conv2dBN,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl PatchMerging {
@@ -118,17 +129,20 @@ impl PatchMerging {
         let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
         Ok(Self {
             conv1,
             conv2,
             conv3,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for PatchMerging {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = if xs.rank() == 3 {
             let (h, w) = self.input_resolution;
             let b = xs.dim(0)?;
@@ -150,6 +164,7 @@ impl Module for PatchMerging {
 struct ConvLayer {
     blocks: Vec<MBConv>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl ConvLayer {
@@ -174,12 +189,18 @@ impl ConvLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for ConvLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -194,21 +215,29 @@ impl Module for ConvLayer {
 #[derive(Debug)]
 struct Mlp {
     norm: candle_nn::LayerNorm,
-    fc1: candle_nn::Linear,
-    fc2: candle_nn::Linear,
+    fc1: crate::Linear,
+    fc2: crate::Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
     fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
         let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
-        let fc1 = candle_nn::linear(in_, hidden, vb.pp("fc1"))?;
-        let fc2 = candle_nn::linear(hidden, in_, vb.pp("fc2"))?;
-        Ok(Self { norm, fc1, fc2 })
+        let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
+        let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
+        Ok(Self {
+            norm,
+            fc1,
+            fc2,
+            span,
+        })
     }
 }
 
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.norm)?
             .apply(&self.fc1)?
             .gelu()?
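Alongside the spans, `Mlp` (and `Attention` below) switches from `candle_nn::Linear` to a crate-local `crate::Linear`. The wrapper's definition is not part of this diff; presumably it wraps `candle_nn::Linear` with its own span so the individual matmuls appear in the trace. A minimal sketch consistent with the `crate::linear(vb, in, out, bias)` argument order used above (the bias handling and span name are assumptions):

    use candle::{Module, Result, Tensor};
    use candle_nn::VarBuilder;

    #[derive(Debug)]
    pub struct Linear {
        inner: candle_nn::Linear,
        span: tracing::Span,
    }

    pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
        // Delegate weight creation to candle_nn, with or without a bias term.
        let inner = if bias {
            candle_nn::linear(in_dim, out_dim, vb)?
        } else {
            candle_nn::linear_no_bias(in_dim, out_dim, vb)?
        };
        let span = tracing::span!(tracing::Level::TRACE, "linear");
        Ok(Linear { inner, span })
    }

    impl Module for Linear {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let _enter = self.span.enter();
            self.inner.forward(xs)
        }
    }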
@@ -219,14 +248,17 @@ impl Module for Mlp {
 #[derive(Debug)]
 struct Attention {
     norm: candle_nn::LayerNorm,
-    qkv: candle_nn::Linear,
-    proj: candle_nn::Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     ab: Tensor,
     key_dim: usize,
     num_heads: usize,
     d: usize,
     dh: usize,
     scale: f64,
+    span: tracing::Span,
+    span_matmul: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl Attention {
@@ -243,8 +275,8 @@ impl Attention {
         let nh_kd = key_dim * num_heads;
         let h = dh + nh_kd * 2;
         let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
-        let qkv = candle_nn::linear(dim, h, vb.pp("qkv"))?;
-        let proj = candle_nn::linear(dh, dim, vb.pp("proj"))?;
+        let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
+        let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
 
         let points = (0..resolution.0)
             .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@@ -265,6 +297,9 @@ impl Attention {
         attention_biases
             .index_select(&idxs, 1)?
             .reshape(((), points.len(), points.len()))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         Ok(Self {
             norm,
             qkv,
@@ -275,12 +310,16 @@ impl Attention {
             d,
             dh,
             scale: 1f64 / (key_dim as f64).sqrt(),
+            span,
+            span_matmul,
+            span_softmax,
         })
     }
 }
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, _) = xs.dims3()?;
         let xs = xs.apply(&self.norm)?;
         let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
@@ -296,11 +335,20 @@ impl Module for Attention {
             .narrow(D::Minus1, 2 * self.key_dim, self.d)?
             .permute((0, 2, 1, 3))?
             .contiguous()?;
-        let attn = (q.matmul(&k.t()?)? * self.scale)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            (q.matmul(&k.t()?)? * self.scale)?
+        };
         let attn = attn.broadcast_add(&self.ab)?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-        attn.matmul(&v)?
-            .transpose(1, 2)?
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            attn.matmul(&v)?
+        };
+        attn.transpose(1, 2)?
             .reshape((b, n, self.dh))?
             .apply(&self.proj)
     }
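Inside `Attention::forward`, the matmul and softmax spans are entered in their own blocks, so each `_enter` guard is dropped as soon as the wrapped expression finishes and only that expression is attributed to the span. The same idiom could be factored into a small helper (hypothetical, not part of this commit):

    // Runs `f` with `span` entered; the guard is dropped when `f` returns,
    // so only the closure's execution time is attributed to the span.
    fn timed<T>(span: &tracing::Span, f: impl FnOnce() -> T) -> T {
        let _enter = span.enter();
        f()
    }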
@@ -313,6 +361,7 @@ struct TinyViTBlock {
     mlp: Mlp,
     window_size: usize,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl TinyViTBlock {
@@ -339,18 +388,21 @@ impl TinyViTBlock {
             ..Default::default()
         };
         let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
         Ok(Self {
             attn,
             local_conv,
             mlp,
             window_size,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for TinyViTBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (h, w) = self.input_resolution;
         let (b, l, c) = xs.dims3()?;
         let res_x = xs;
@@ -410,6 +462,7 @@ impl Module for TinyViTBlock {
 struct BasicLayer {
     blocks: Vec<TinyViTBlock>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl BasicLayer {
@@ -442,12 +495,18 @@ impl BasicLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for BasicLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -470,6 +529,8 @@ pub struct TinyViT {
     neck_ln1: crate::LayerNorm2d,
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
+    span: tracing::Span,
+    span_neck: tracing::Span,
 }
 
 impl TinyViT {
@@ -525,6 +586,8 @@ impl TinyViT {
         let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
         let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
 
+        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
+        let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
         Ok(Self {
             patch_embed,
             layer0,
@@ -533,18 +596,22 @@ impl TinyViT {
             neck_ln1,
             neck_conv2,
             neck_ln2,
+            span,
+            span_neck,
         })
     }
 }
 
 impl Module for TinyViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = self.layer0.forward(&xs)?;
         for layer in self.layers.iter() {
             xs = layer.forward(&xs)?
         }
         let (b, _, c) = xs.dims3()?;
+        let _enter = self.span_neck.enter();
         xs.reshape((b, 64, 64, c))?
             .permute((0, 3, 1, 2))?
             .apply(&self.neck_conv1)?
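The spans only produce output once a tracing subscriber is installed in the example binary. A typical setup, assuming the `tracing-subscriber` and `tracing-chrome` crates (the pattern other candle examples use behind a `--tracing` flag), writes a Chrome trace file that can be inspected in `chrome://tracing` or Perfetto:

    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    fn main() {
        // Keep the guard alive for the whole run; the trace JSON file is
        // flushed to disk when the guard is dropped.
        let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();

        // ... load the TinyViT weights and run forward passes here; every
        // span entered above becomes a slice in the resulting trace.
    }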