Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
TinyViT backbone for segment-anything. (#787)
* TinyViT.
* More TinyViT.
* Add more to the tinyvit backbone.
* Proper padding.
* Plus ViT.
* Add the tiniest vit spec.
candle-examples/examples/segment-anything/main.rs
@@ -11,6 +11,7 @@ pub mod model_image_encoder;
 pub mod model_mask_decoder;
 pub mod model_prompt_encoder;
 pub mod model_sam;
+pub mod model_tiny_vit;
 pub mod model_transformer;
 
 use candle::{DType, Result, Tensor};
563 candle-examples/examples/segment-anything/model_tiny_vit.rs Normal file
@@ -0,0 +1,563 @@
// Adapted from:
// https://github.com/ChaoningZhang/MobileSAM/blob/master/mobile_sam/modeling/tiny_vit_sam.py
#![allow(unused)]
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2dConfig, Module, VarBuilder};

const MBCONV_EXPAND_RATIO: usize = 4;
const MLP_RATIO: usize = 4;
const LOCAL_CONV_SIZE: usize = 3;
// SAM's encoder runs on 1024x1024 inputs and `TinyViT::forward` below hardcodes
// the resulting 64x64 feature map, so 224 (the value in the original commit)
// would make the per-stage resolutions inconsistent with that final reshape.
const IMG_SIZE: usize = 1024;
const IN_CHANNELS: usize = 3;

#[derive(Debug)]
struct Conv2dBN {
    c: candle_nn::Conv2d,
    bn: candle_nn::BatchNorm,
}

impl Conv2dBN {
    fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
        let c = candle_nn::conv2d(in_, out, ks, cfg, vb.pp("c"))?;
        let bn = candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?;
        Ok(Self { c, bn })
    }
}

impl Module for Conv2dBN {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.c)?.apply(&self.bn)
    }
}

#[derive(Debug)]
struct PatchEmbed {
    conv1: Conv2dBN,
    conv2: Conv2dBN,
}

impl PatchEmbed {
    fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride: 2,
            padding: 1,
            ..Default::default()
        };
        let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
        let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
        Ok(Self { conv1, conv2 })
    }
}

impl Module for PatchEmbed {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.conv1)?.gelu()?.apply(&self.conv2)
    }
}
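
// The two stride-2 convolutions above downsample by 4 overall, turning a
// 1024x1024 input into a 256x256 grid of embed_dim-channel tokens (this is the
// `patches_resolution = IMG_SIZE / 4` used in `TinyViT::new` below).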

// MBConv is the MobileNet-style inverted residual block: 1x1 expansion (by
// MBCONV_EXPAND_RATIO), 3x3 depthwise convolution (groups = hidden), 1x1
// projection, with GELU activations and a residual connection.
#[derive(Debug)]
struct MBConv {
    conv1: Conv2dBN,
    conv2: Conv2dBN,
    conv3: Conv2dBN,
}

impl MBConv {
    fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
        let hidden = in_ * expand_ratio;
        let cfg2 = candle_nn::Conv2dConfig {
            padding: 1,
            groups: hidden,
            ..Default::default()
        };
        let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
        let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
        let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
        Ok(Self {
            conv1,
            conv2,
            conv3,
        })
    }
}

impl Module for MBConv {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let shortcut = xs;
        let xs = xs
            .apply(&self.conv1)?
            .gelu()?
            .apply(&self.conv2)?
            .gelu()?
            .apply(&self.conv3)?;
        (xs + shortcut)?.gelu()
    }
}

#[derive(Debug)]
struct PatchMerging {
    conv1: Conv2dBN,
    conv2: Conv2dBN,
    conv3: Conv2dBN,
    input_resolution: (usize, usize),
}

impl PatchMerging {
    fn new(
        input_resolution: (usize, usize),
        dim: usize,
        out: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        // 320/448/576 are the last-stage widths of the upstream TinyViT
        // variants, which keep stride 1; every other merge halves the grid.
        let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
        let cfg2 = candle_nn::Conv2dConfig {
            padding: 1,
            stride,
            groups: out,
            ..Default::default()
        };
        let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
        let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
        let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
        Ok(Self {
            conv1,
            conv2,
            conv3,
            input_resolution,
        })
    }
}

impl Module for PatchMerging {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Accept either a (B, L, C) token sequence or a (B, C, H, W) feature map.
        let xs = if xs.rank() == 3 {
            let (h, w) = self.input_resolution;
            let b = xs.dim(0)?;
            xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
        } else {
            xs.clone()
        };
        xs.apply(&self.conv1)?
            .gelu()?
            .apply(&self.conv2)?
            .gelu()?
            .apply(&self.conv3)?
            .flatten_from(2)?
            .transpose(1, 2)
    }
}

#[derive(Debug)]
struct ConvLayer {
    blocks: Vec<MBConv>,
    downsample: Option<PatchMerging>,
}

impl ConvLayer {
    fn new(
        dim: usize,
        out: usize,
        input_resolution: (usize, usize),
        depth: usize,
        downsample: bool,
        conv_expand_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_b = vb.pp("blocks");
        let mut blocks = Vec::with_capacity(depth);
        for index in 0..depth {
            let block = MBConv::new(dim, dim, conv_expand_ratio, vb_b.pp(index))?;
            blocks.push(block)
        }
        let downsample = if downsample {
            let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
            Some(downsample)
        } else {
            None
        };
        Ok(Self { blocks, downsample })
    }
}

impl Module for ConvLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        match &self.downsample {
            None => Ok(xs),
            Some(downsample) => downsample.forward(&xs),
        }
    }
}

#[derive(Debug)]
struct Mlp {
    norm: candle_nn::LayerNorm,
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
}

impl Mlp {
    fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
        let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
        let fc1 = candle_nn::linear(in_, hidden, vb.pp("fc1"))?;
        let fc2 = candle_nn::linear(hidden, in_, vb.pp("fc2"))?;
        Ok(Self { norm, fc1, fc2 })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.norm)?
            .apply(&self.fc1)?
            .gelu()?
            .apply(&self.fc2)
    }
}

#[derive(Debug)]
struct Attention {
    norm: candle_nn::LayerNorm,
    qkv: candle_nn::Linear,
    proj: candle_nn::Linear,
    attention_biases: Tensor,
    ab: Tensor,
    key_dim: usize,
    num_heads: usize,
    d: usize,
    dh: usize,
    scale: f64,
}

impl Attention {
    fn new(
        dim: usize,
        key_dim: usize,
        num_heads: usize,
        attn_ratio: usize,
        resolution: (usize, usize),
        vb: VarBuilder,
    ) -> Result<Self> {
        let d = attn_ratio * key_dim;
        let dh = d * num_heads;
        let nh_kd = key_dim * num_heads;
        let h = dh + nh_kd * 2;
        let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
        let qkv = candle_nn::linear(dim, h, vb.pp("qkv"))?;
        let proj = candle_nn::linear(dh, dim, vb.pp("proj"))?;

        // Deduplicate pairwise offsets: positions at the same absolute
        // (dx, dy) distance share a single learned bias per head.
        let points = (0..resolution.0)
            .flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
            .collect::<Vec<_>>();
        let mut idxs = Vec::with_capacity(points.len() * points.len());
        let mut attention_offsets = std::collections::HashMap::new();
        for &(x1, y1) in points.iter() {
            for &(x2, y2) in points.iter() {
                let offset = ((x2 - x1).abs(), (y2 - y1).abs());
                let l = attention_offsets.len();
                let idx = attention_offsets.entry(offset).or_insert(l);
                idxs.push(*idx as u32)
            }
        }
        let attention_biases = vb.get((num_heads, attention_offsets.len()), "attention_biases")?;
        let idxs = Tensor::new(idxs, attention_biases.device())?;
        // Gather the per-pair biases and reshape to (num_heads, N, N) so they
        // broadcast over the batch dimension in `forward`.
        let ab = attention_biases
            .index_select(&idxs, 1)?
            .reshape(((), points.len(), points.len()))?;
        Ok(Self {
            norm,
            qkv,
            proj,
            attention_biases,
            ab,
            key_dim,
            num_heads,
            d,
            dh,
            scale: 1f64 / (key_dim as f64).sqrt(),
        })
    }
}

impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, n, _) = xs.dims3()?;
        let xs = xs.apply(&self.norm)?;
        let qkv = xs.apply(&self.qkv)?.reshape((b, n, self.num_heads, ()))?;
        let q = qkv
            .narrow(D::Minus1, 0, self.key_dim)?
            .permute((0, 2, 1, 3))?;
        let k = qkv
            .narrow(D::Minus1, self.key_dim, self.key_dim)?
            .permute((0, 2, 1, 3))?;
        let v = qkv
            .narrow(D::Minus1, 2 * self.key_dim, self.d)?
            .permute((0, 2, 1, 3))?;
        let attn = (q.matmul(&k.t()?)? * self.scale)?;
        // `ab` is (num_heads, n, n); broadcast the bias over the batch dim.
        let attn = attn.broadcast_add(&self.ab)?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        attn.matmul(&v)?
            .transpose(1, 2)?
            .reshape((b, n, self.dh))?
            .apply(&self.proj)
    }
}
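
// Worked example of the bias indexing above: on a 2x2 grid the 16 (query, key)
// pairs collapse onto 4 distinct absolute offsets (0,0), (0,1), (1,0), (1,1),
// so `attention_biases` has shape (num_heads, 4) and `idxs` is
//   [0, 1, 2, 3, 1, 0, 3, 2, 2, 3, 0, 1, 3, 2, 1, 0]
// i.e. the full N*N bias matrix is stored with only num_heads * 4 parameters.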

#[derive(Debug)]
struct TinyViTBlock {
    attn: Attention,
    local_conv: Conv2dBN,
    mlp: Mlp,
    window_size: usize,
    input_resolution: (usize, usize),
}

impl TinyViTBlock {
    fn new(
        dim: usize,
        input_resolution: (usize, usize),
        num_heads: usize,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let head_dim = dim / num_heads;
        let attn = Attention::new(
            dim,
            head_dim,
            num_heads,
            1,
            (window_size, window_size),
            vb.pp("attn"),
        )?;
        let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: LOCAL_CONV_SIZE / 2,
            ..Default::default()
        };
        let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
        Ok(Self {
            attn,
            local_conv,
            mlp,
            window_size,
            input_resolution,
        })
    }
}

impl Module for TinyViTBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (h, w) = self.input_resolution;
        let (b, l, c) = xs.dims3()?;
        let res_x = xs;
        let xs = if h == self.window_size && w == self.window_size {
            self.attn.forward(xs)?
        } else {
            // Pad the token grid up to a multiple of the window size, run
            // windowed attention, then crop the padding away again.
            let xs = xs.reshape((b, h, w, c))?;
            let pad_b = (self.window_size - h % self.window_size) % self.window_size;
            let pad_r = (self.window_size - w % self.window_size) % self.window_size;

            let xs = if pad_b > 0 {
                // Pad along the height axis (dim 1 of (b, h, w, c)).
                xs.pad_with_zeros(1, 0, pad_b)?
            } else {
                xs
            };
            let xs = if pad_r > 0 {
                // Pad along the width axis (dim 2 of (b, h, w, c)).
                xs.pad_with_zeros(2, 0, pad_r)?
            } else {
                xs
            };
            let (p_h, p_w) = (h + pad_b, w + pad_r);
            let n_h = p_h / self.window_size;
            let n_w = p_w / self.window_size;
            let xs = xs
                .reshape((b, n_h, self.window_size, n_w, self.window_size, c))?
                .transpose(2, 3)?
                .reshape((b * n_h * n_w, self.window_size * self.window_size, c))?;
            let xs = self.attn.forward(&xs)?;
            let xs = xs
                .reshape((b, n_h, n_w, self.window_size, self.window_size, c))?
                .transpose(2, 3)?
                .reshape((b, p_h, p_w, c))?;
            let xs = if pad_r > 0 {
                xs.i((.., .., ..w))?.contiguous()?
            } else {
                xs
            };
            let xs = if pad_b > 0 {
                xs.i((.., ..h, ..))?.contiguous()?
            } else {
                xs
            };
            xs.reshape((b, l, c))?
        };
        let xs = (xs + res_x)?;
        let xs = xs
            .transpose(1, 2)?
            .reshape((b, c, h, w))?
            .apply(&self.local_conv)?
            .reshape((b, c, l))?
            .transpose(1, 2)?;
        &xs + self.mlp.forward(&xs)?
    }
}
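
// A sanity check of the window-partition arithmetic above (a hypothetical test,
// not part of the original commit): tokens 0..16 on a 4x4 grid with window
// size 2 should split into the four 2x2 spatial windows.
#[cfg(test)]
mod window_partition_tests {
    use candle::{Device, IndexOp, Tensor};

    #[test]
    fn partition_4x4_grid_into_2x2_windows() -> candle::Result<()> {
        let dev = Device::Cpu;
        // (b, h, w, c) = (1, 4, 4, 1), tokens numbered row-major.
        let xs = Tensor::arange(0f32, 16., &dev)?.reshape((1, 4, 4, 1))?;
        let (b, ws, n_h, n_w, c) = (1, 2, 2, 2, 1);
        let windows = xs
            .reshape((b, n_h, ws, n_w, ws, c))?
            .transpose(2, 3)?
            .reshape((b * n_h * n_w, ws * ws, c))?;
        // The first window is the top-left 2x2 patch: tokens 0, 1, 4, 5.
        let first = windows.i(0)?.flatten_all()?.to_vec1::<f32>()?;
        assert_eq!(first, vec![0., 1., 4., 5.]);
        Ok(())
    }
}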

#[derive(Debug)]
struct BasicLayer {
    blocks: Vec<TinyViTBlock>,
    downsample: Option<PatchMerging>,
}

impl BasicLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        dim: usize,
        input_resolution: (usize, usize),
        depth: usize,
        num_heads: usize,
        window_size: usize,
        downsample: bool,
        out: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb_b = vb.pp("blocks");
        let mut blocks = Vec::with_capacity(depth);
        for index in 0..depth {
            let block = TinyViTBlock::new(
                dim,
                input_resolution,
                num_heads,
                window_size,
                vb_b.pp(index),
            )?;
            blocks.push(block)
        }
        let downsample = if downsample {
            let downsample = PatchMerging::new(input_resolution, dim, out, vb.pp("downsample"))?;
            Some(downsample)
        } else {
            None
        };
        Ok(Self { blocks, downsample })
    }
}

impl Module for BasicLayer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
        match &self.downsample {
            None => Ok(xs),
            Some(downsample) => downsample.forward(&xs),
        }
    }
}

#[derive(Debug)]
pub struct TinyViT {
    patch_embed: PatchEmbed,
    layer0: ConvLayer,
    layers: Vec<BasicLayer>,
    norm_head: candle_nn::LayerNorm,
    head: candle_nn::Linear,
    neck_conv1: candle_nn::Conv2d,
    neck_ln1: crate::LayerNorm2d,
    neck_conv2: candle_nn::Conv2d,
    neck_ln2: crate::LayerNorm2d,
}

impl TinyViT {
    pub fn new(
        embed_dims: &[usize],
        depths: &[usize],
        num_heads: &[usize],
        window_sizes: &[usize],
        num_classes: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
        let patches_resolution = IMG_SIZE / 4;

        let vb_l = vb.pp("layers");
        let layer0 = ConvLayer::new(
            /* dim */ embed_dims[0],
            /* out */ embed_dims[1],
            /* input_resolution */ (patches_resolution, patches_resolution),
            /* depth */ depths[0],
            /* downsample */ true,
            /* conv_expand_ratio */ MBCONV_EXPAND_RATIO,
            vb_l.pp(0),
        )?;

        let num_layers = embed_dims.len();
        let mut layers = Vec::with_capacity(num_layers - 1);
        for i_layer in 1..num_layers {
            let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
            let layer = BasicLayer::new(
                /* dim */ embed_dims[i_layer],
                /* input_resolution */ (patches_resolution, patches_resolution),
                /* depth */ depths[i_layer],
                /* num_heads */ num_heads[i_layer],
                /* window_size */ window_sizes[i_layer],
                /* downsample */ i_layer < num_layers - 1,
                /* out */ embed_dims[usize::min(i_layer + 1, num_layers - 1)],
                vb_l.pp(i_layer),
            )?;
            layers.push(layer)
        }

        let last_embed_dim = embed_dims[embed_dims.len() - 1];
        let norm_head = candle_nn::layer_norm(last_embed_dim, 1e-5, vb.pp("norm_head"))?;
        let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
        let neck_conv1 =
            candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
        let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
        let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;

        Ok(Self {
            patch_embed,
            layer0,
            layers,
            norm_head,
            head,
            neck_conv1,
            neck_ln1,
            neck_conv2,
            neck_ln2,
        })
    }
}

impl Module for TinyViT {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.patch_embed.forward(xs)?;
        // The convolutional stem `layer0` must run before the attention stages;
        // without it the channel count does not match what `layers[0]` expects.
        let mut xs = self.layer0.forward(&xs)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs)?
        }
        let (b, _, c) = xs.dims3()?;
        xs.reshape((b, 64, 64, c))?
            .permute((0, 3, 1, 2))?
            .apply(&self.neck_conv1)?
            .apply(&self.neck_ln1)?
            .apply(&self.neck_conv2)?
            .apply(&self.neck_ln2)
    }
}
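
// `norm_head` and `head` are the ImageNet classifier of the original TinyViT;
// the SAM image encoder only consumes the neck output, which is presumably why
// they never appear in `forward` and why the file carries `allow(unused)`.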

pub fn tiny_vit_5m_224(vb: VarBuilder) -> Result<TinyViT> {
    TinyViT::new(
        /* embed_dims */ &[64, 128, 160, 320],
        /* depths */ &[2, 2, 6, 2],
        /* num_heads */ &[2, 4, 5, 10],
        /* window_sizes */ &[7, 7, 14, 7],
        /* num_classes */ 1000,
        vb,
    )
}
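
A minimal smoke test for the new backbone (not part of this commit; it assumes the example's main.rs declares `pub mod model_tiny_vit` as in the hunk above, and uses a zero-initialized `VarBuilder` where the real example would load SAM checkpoint weights):

use candle::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};

fn smoke_test() -> candle::Result<()> {
    let device = Device::Cpu;
    // Zero weights stand in for a real checkpoint here.
    let vb = VarBuilder::zeros(DType::F32, &device);
    let model = model_tiny_vit::tiny_vit_5m_224(vb)?;
    // SAM-sized input: one 3-channel 1024x1024 image.
    let xs = Tensor::zeros((1, 3, 1024, 1024), DType::F32, &device)?;
    let embeddings = model.forward(&xs)?;
    // The neck projects the final 64x64 token grid to a (1, 256, 64, 64) embedding.
    assert_eq!(embeddings.dims(), &[1, 256, 64, 64]);
    Ok(())
}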
candle-nn/src/conv.rs
@@ -187,10 +187,10 @@ pub fn conv1d(
     out_channels: usize,
     kernel_size: usize,
     cfg: Conv1dConfig,
-    vs: crate::VarBuilder,
+    vb: crate::VarBuilder,
 ) -> Result<Conv1d> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints(
+    let ws = vb.get_with_hints(
         (out_channels, in_channels / cfg.groups, kernel_size),
         "weight",
         init_ws,
@@ -200,7 +200,7 @@ pub fn conv1d(
         lo: -bound,
         up: bound,
     };
-    let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
+    let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
     Ok(Conv1d::new(ws, Some(bs), cfg))
 }
 
@@ -209,10 +209,10 @@ pub fn conv2d(
     out_channels: usize,
     kernel_size: usize,
     cfg: Conv2dConfig,
-    vs: crate::VarBuilder,
+    vb: crate::VarBuilder,
 ) -> Result<Conv2d> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints(
+    let ws = vb.get_with_hints(
         (
             out_channels,
             in_channels / cfg.groups,
@@ -227,7 +227,7 @@ pub fn conv2d(
         lo: -bound,
         up: bound,
     };
-    let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
+    let bs = vb.get_with_hints(out_channels, "bias", init_bs)?;
     Ok(Conv2d::new(ws, Some(bs), cfg))
 }
 
@@ -236,10 +236,10 @@ pub fn conv2d_no_bias(
     out_channels: usize,
     kernel_size: usize,
     cfg: Conv2dConfig,
-    vs: crate::VarBuilder,
+    vb: crate::VarBuilder,
 ) -> Result<Conv2d> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
-    let ws = vs.get_with_hints(
+    let ws = vb.get_with_hints(
         (
             out_channels,
             in_channels / cfg.groups,
@@ -257,19 +257,19 @@ pub fn conv_transpose2d(
     out_channels: usize,
     kernel_size: usize,
     cfg: ConvTranspose2dConfig,
-    vs: crate::VarBuilder,
+    vb: crate::VarBuilder,
 ) -> Result<ConvTranspose2d> {
     let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
     let init = crate::Init::Uniform {
         lo: -bound,
         up: bound,
     };
-    let ws = vs.get_with_hints(
+    let ws = vb.get_with_hints(
         (in_channels, out_channels, kernel_size, kernel_size),
         "weight",
         init,
     )?;
-    let bs = vs.get_with_hints(out_channels, "bias", init)?;
+    let bs = vb.get_with_hints(out_channels, "bias", init)?;
     Ok(ConvTranspose2d::new(ws, Some(bs), cfg))
 }
 
@@ -278,14 +278,14 @@ pub fn conv_transpose2d_no_bias(
     out_channels: usize,
     kernel_size: usize,
     cfg: ConvTranspose2dConfig,
-    vs: crate::VarBuilder,
+    vb: crate::VarBuilder,
 ) -> Result<ConvTranspose2d> {
     let bound = 1. / (out_channels as f64).sqrt() / kernel_size as f64;
     let init = crate::Init::Uniform {
         lo: -bound,
         up: bound,
     };
-    let ws = vs.get_with_hints(
+    let ws = vb.get_with_hints(
         (out_channels, in_channels, kernel_size, kernel_size),
         "weight",
         init,
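
The `vs` to `vb` rename in candle-nn is purely mechanical and matches the `VarBuilder` naming used elsewhere in the crate; call sites are unaffected since the parameter is positional. A minimal sketch of the call shape, assuming a zero-initialized `VarBuilder` in place of real weights:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Conv2dConfig, Module, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // The last argument is the renamed `vb: crate::VarBuilder` parameter.
    let conv = candle_nn::conv2d(3, 8, 3, Conv2dConfig::default(), vb.pp("conv"))?;
    let ys = conv.forward(&Tensor::zeros((1, 3, 16, 16), DType::F32, &dev)?)?;
    assert_eq!(ys.dims(), &[1, 8, 14, 14]); // no padding: 16 - 3 + 1 = 14
    Ok(())
}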