ViT tracing. (#790)
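This commit threads `tracing` spans through the TinyViT vision transformer (the image encoder of the segment-anything port), so each block's forward pass can be timed by a span-aware subscriber. The spans are inert until a subscriber is installed; the sketch below shows one way a caller might collect them into a Chrome trace file. It assumes the `tracing-chrome` and `tracing-subscriber` crates and follows the pattern candle's example binaries use behind their `--tracing` flag; it is not part of this commit.

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The guard flushes the trace file when dropped, so keep it alive for
    // the whole run; a trace-*.json file is written by default.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the TinyViT weights and run a forward pass here; every
    // TRACE-level span entered during the pass is recorded to the file.
}
```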
@@ -13,18 +13,21 @@ const IN_CHANNELS: usize = 3;
 struct Conv2dBN {
     c: candle_nn::Conv2d,
     bn: candle_nn::BatchNorm,
+    span: tracing::Span,
 }
 
 impl Conv2dBN {
     fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
         let c = candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?;
         let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
-        Ok(Self { c, bn })
+        let span = tracing::span!(tracing::Level::TRACE, "conv2d-bn");
+        Ok(Self { c, bn, span })
     }
 }
 
 impl Module for Conv2dBN {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.c)?.apply(&self.bn)
     }
 }
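The same pattern repeats for every block in this file: the span is created once in the constructor and entered at the top of `forward`, so the (small) cost of constructing the span is paid at load time rather than on the hot path. A minimal, self-contained sketch of the idea, with illustrative names (`Traced`, `inner`) that are not from this commit:

```rust
use candle::{Result, Tensor};
use candle_nn::Module;

// Generic wrapper showing the construct-once / enter-per-call pattern.
struct Traced<M> {
    inner: M,
    span: tracing::Span,
}

impl<M> Traced<M> {
    fn new(inner: M, name: &'static str) -> Self {
        // Constructed once; entering it on each call is cheap.
        let span = tracing::span!(tracing::Level::TRACE, "traced", name);
        Self { inner, span }
    }
}

impl<M: Module> Module for Traced<M> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The guard returned by enter() keeps the span open until it is
        // dropped at the end of this call, so each forward pass shows up
        // as one span in the trace.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}
```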
@@ -33,6 +36,7 @@ impl Module for Conv2dBN {
 struct PatchEmbed {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl PatchEmbed {
@@ -44,12 +48,14 @@ impl PatchEmbed {
         };
         let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
         let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
-        Ok(Self { conv1, conv2 })
+        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
+        Ok(Self { conv1, conv2, span })
     }
 }
 
 impl Module for PatchEmbed {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
     }
 }
@@ -59,6 +65,7 @@ struct MBConv {
     conv1: Conv2dBN,
     conv2: Conv2dBN,
     conv3: Conv2dBN,
+    span: tracing::Span,
 }
 
 impl MBConv {
@@ -72,16 +79,19 @@ impl MBConv {
         let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
         Ok(Self {
             conv1,
             conv2,
             conv3,
+            span,
         })
     }
 }
 
 impl Module for MBConv {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let shortcut = xs;
         let xs = xs
             .apply(&self.conv1)?
@@ -99,6 +109,7 @@ struct PatchMerging {
     conv2: Conv2dBN,
     conv3: Conv2dBN,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl PatchMerging {
@@ -118,17 +129,20 @@ impl PatchMerging {
         let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
         let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
         let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
         Ok(Self {
             conv1,
             conv2,
             conv3,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for PatchMerging {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = if xs.rank() == 3 {
             let (h, w) = self.input_resolution;
             let b = xs.dim(0)?;
@@ -150,6 +164,7 @@ impl Module for PatchMerging {
 struct ConvLayer {
     blocks: Vec<MBConv>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl ConvLayer {
@@ -174,12 +189,18 @@ impl ConvLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "conv-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for ConvLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -194,21 +215,29 @@ impl Module for ConvLayer {
 #[derive(Debug)]
 struct Mlp {
     norm: candle_nn::LayerNorm,
-    fc1: candle_nn::Linear,
-    fc2: candle_nn::Linear,
+    fc1: crate::Linear,
+    fc2: crate::Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
     fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
         let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
-        let fc1 = candle_nn::linear(in_, hidden, vb.pp("fc1"))?;
-        let fc2 = candle_nn::linear(hidden, in_, vb.pp("fc2"))?;
-        Ok(Self { norm, fc1, fc2 })
+        let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
+        let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
+        Ok(Self {
+            norm,
+            fc1,
+            fc2,
+            span,
+        })
     }
 }
 
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.norm)?
             .apply(&self.fc1)?
             .gelu()?
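Note that `Mlp` (and `Attention` below) also swap `candle_nn::Linear` for a crate-local `crate::Linear` built by `crate::linear(vb, in, out, bias)`. That wrapper is defined elsewhere in the segment-anything code and is not shown in this diff; a plausible sketch of what it does, with the argument order taken from the call sites above, is:

```rust
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

// A linear layer that carries its own tracing span, so time spent in
// matmul-heavy projections is attributed to a "linear" span.
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
    let inner = if bias {
        candle_nn::linear(in_dim, out_dim, vb)?
    } else {
        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
    };
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { inner, span })
}

impl Module for Linear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}
```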
@@ -219,14 +248,17 @@ impl Module for Mlp {
 #[derive(Debug)]
 struct Attention {
     norm: candle_nn::LayerNorm,
-    qkv: candle_nn::Linear,
-    proj: candle_nn::Linear,
+    qkv: crate::Linear,
+    proj: crate::Linear,
     ab: Tensor,
     key_dim: usize,
     num_heads: usize,
     d: usize,
     dh: usize,
     scale: f64,
+    span: tracing::Span,
+    span_matmul: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl Attention {
@@ -243,8 +275,8 @@ impl Attention {
         let nh_kd = key_dim * num_heads;
         let h = dh + nh_kd * 2;
         let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
-        let qkv = candle_nn::linear(dim, h, vb.pp("qkv"))?;
-        let proj = candle_nn::linear(dh, dim, vb.pp("proj"))?;
+        let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
+        let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
 
         let points = (0..resolution.0)
             .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@@ -265,6 +297,9 @@ impl Attention {
         attention_biases
             .index_select(&idxs, 1)?
             .reshape(((), points.len(), points.len()))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
+        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         Ok(Self {
             norm,
             qkv,
@@ -275,12 +310,16 @@ impl Attention {
             d,
             dh,
             scale: 1f64 / (key_dim as f64).sqrt(),
+            span,
+            span_matmul,
+            span_softmax,
         })
     }
 }
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, _) = xs.dims3()?;
         let xs = xs.apply(&self.norm)?;
         let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
@@ -296,11 +335,20 @@ impl Module for Attention {
             .narrow(D::Minus1, 2 * self.key_dim, self.d)?
             .permute((0, 2, 1, 3))?
             .contiguous()?;
-        let attn = (q.matmul(&k.t()?)? * self.scale)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            (q.matmul(&k.t()?)? * self.scale)?
+        };
         let attn = attn.broadcast_add(&self.ab)?;
-        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
-        attn.matmul(&v)?
-            .transpose(1, 2)?
+        let attn = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn)?
+        };
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            attn.matmul(&v)?
+        };
+        attn.transpose(1, 2)?
             .reshape((b, n, self.dh))?
             .apply(&self.proj)
     }
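Inside the attention forward pass the spans are scoped more tightly: each matmul and the softmax are wrapped in a block expression, so the guard drops at the closing brace and only that operation is attributed to `attn-matmul` or `attn-sm`. A self-contained illustration of the idiom (the span name and the work are made up for the example):

```rust
fn main() {
    let span = tracing::span!(tracing::Level::TRACE, "step");
    let x = {
        let _enter = span.enter(); // span opens here
        2 + 2                      // only this work is attributed to "step"
    };                             // guard drops here, span closes
    assert_eq!(x, 4);
}
```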
@@ -313,6 +361,7 @@ struct TinyViTBlock {
     mlp: Mlp,
     window_size: usize,
     input_resolution: (usize, usize),
+    span: tracing::Span,
 }
 
 impl TinyViTBlock {
@@ -339,18 +388,21 @@ impl TinyViTBlock {
             ..Default::default()
         };
         let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
+        let span = tracing::span!(tracing::Level::TRACE, "attention");
         Ok(Self {
             attn,
             local_conv,
             mlp,
             window_size,
             input_resolution,
+            span,
         })
     }
 }
 
 impl Module for TinyViTBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (h, w) = self.input_resolution;
         let (b, l, c) = xs.dims3()?;
         let res_x = xs;
@@ -410,6 +462,7 @@ impl Module for TinyViTBlock {
 struct BasicLayer {
     blocks: Vec<TinyViTBlock>,
     downsample: Option<PatchMerging>,
+    span: tracing::Span,
 }
 
 impl BasicLayer {
@@ -442,12 +495,18 @@ impl BasicLayer {
         } else {
             None
         };
-        Ok(Self { blocks, downsample })
+        let span = tracing::span!(tracing::Level::TRACE, "basic-layer");
+        Ok(Self {
+            blocks,
+            downsample,
+            span,
+        })
     }
 }
 
 impl Module for BasicLayer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = xs.clone();
         for block in self.blocks.iter() {
             xs = block.forward(&xs)?
@@ -470,6 +529,8 @@ pub struct TinyViT {
     neck_ln1: crate::LayerNorm2d,
     neck_conv2: candle_nn::Conv2d,
     neck_ln2: crate::LayerNorm2d,
+    span: tracing::Span,
+    span_neck: tracing::Span,
 }
 
 impl TinyViT {
@@ -525,6 +586,8 @@ impl TinyViT {
         let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
         let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
 
+        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
+        let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
         Ok(Self {
             patch_embed,
             layer0,
@@ -533,18 +596,22 @@ impl TinyViT {
             neck_ln1,
             neck_conv2,
             neck_ln2,
+            span,
+            span_neck,
         })
     }
 }
 
 impl Module for TinyViT {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let xs = self.patch_embed.forward(xs)?;
         let mut xs = self.layer0.forward(&xs)?;
         for layer in self.layers.iter() {
             xs = layer.forward(&xs)?
         }
         let (b, _, c) = xs.dims3()?;
+        let _enter = self.span_neck.enter();
         xs.reshape((b, 64, 64, c))?
             .permute((0, 3, 1, 2))?
             .apply(&self.neck_conv1)?
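With a subscriber like the one sketched at the top installed, the resulting trace can be opened in chrome://tracing or https://ui.perfetto.dev; the spans added here (`patch-embed`, `mb-conv`, `attention`, `attn-matmul`, `attn-sm`, `neck`, ...) then appear nested under the top-level `tiny-vit` span, one tree per forward pass.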