mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Simplify the var-builder layer setup. (#133)
This commit is contained in:
@ -262,22 +262,16 @@ struct EncodecResnetBlock {
|
||||
impl EncodecResnetBlock {
|
||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new("block");
|
||||
let mut layer = Layer::new(vb.pp("block"));
|
||||
if dilations.len() != 2 {
|
||||
anyhow::bail!("expected dilations of size 2")
|
||||
}
|
||||
// TODO: Apply dilations!
|
||||
layer.inc();
|
||||
let block_conv1 = EncodecConv1d::load(
|
||||
dim,
|
||||
h,
|
||||
cfg.residual_kernel_size,
|
||||
1,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
let block_conv1 =
|
||||
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
||||
layer.inc();
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?;
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
||||
let shortcut = if cfg.use_conv_shortcut {
|
||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||
Some(conv)
|
||||
@ -292,28 +286,24 @@ impl EncodecResnetBlock {
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE(review): this span is a rendered diff hunk. Two declarations of `Layer` are
// interleaved here without +/- markers, so the text is not compilable as-is.
// `Layer` is a small helper that hands out consecutively numbered sub-layers of a
// parent module, tracking the running index in `cnt`.
#[derive(Debug)]
|
||||
// Presumably the pre-change declaration: the parent is remembered as a string prefix.
struct Layer {
|
||||
// Owned parent name; combined with the index as "{prefix}.{cnt}" by `next_name`.
prefix: String,
|
||||
// Presumably the post-change declaration: the parent `VarBuilder` is stored directly,
// which is why a lifetime parameter is needed.
struct Layer<'a> {
|
||||
// Parent builder from which numbered children are derived via `pp`.
vb: VarBuilder<'a>,
|
||||
// Index that the next handed-out sub-layer will use; starts at 0 in `new`.
cnt: usize,
|
||||
}
|
||||
|
||||
// NOTE(review): rendered diff hunk. Two versions of the `Layer` impl are interleaved
// below without +/- markers (string-prefix based vs. `VarBuilder` based), so braces do
// not balance when read as plain Rust. Presumably the `prefix` variant is the removed one.
impl Layer {
|
||||
// String-prefix variant: copies `prefix` into an owned `String`, index starts at 0.
fn new(prefix: &str) -> Self {
|
||||
Self {
|
||||
prefix: prefix.to_string(),
|
||||
cnt: 0,
|
||||
}
|
||||
// `VarBuilder` variant of the same helper.
impl<'a> Layer<'a> {
|
||||
// Takes ownership of the parent builder; index starts at 0.
fn new(vb: VarBuilder<'a>) -> Self {
|
||||
Self { vb, cnt: 0 }
|
||||
}
|
||||
|
||||
// Advances the index by one without producing a name or builder. Call sites use it
// to skip a position (one is annotated `// ELU`) — presumably layers with no weights.
fn inc(&mut self) {
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
// String-prefix variant: returns "{prefix}.{cnt}" for the current index, then
// advances the index (the `self.cnt += 1` line below is shared by both variants).
fn next_name(&mut self) -> String {
|
||||
let name = format!("{}.{}", self.prefix, self.cnt);
|
||||
// `VarBuilder` variant: returns `self.vb.pp("<cnt>")` for the current index, then
// advances the index. Presumably `pp` scopes the builder to that child path —
// TODO confirm against the `VarBuilder` definition.
fn next(&mut self) -> VarBuilder<'a> {
|
||||
let vb = self.vb.pp(&self.cnt.to_string());
|
||||
// Post-increment: the value returned was built from the index before this bump.
self.cnt += 1;
|
||||
name
|
||||
vb
|
||||
}
|
||||
}
|
||||
|
||||
@ -327,13 +317,13 @@ struct EncodecEncoder {
|
||||
|
||||
impl EncodecEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new("layers");
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.audio_channels,
|
||||
cfg.num_filters,
|
||||
cfg.kernel_size,
|
||||
1,
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut sampling_layers = vec![];
|
||||
@ -345,7 +335,7 @@ impl EncodecEncoder {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
@ -356,21 +346,20 @@ impl EncodecEncoder {
|
||||
current_scale * 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
sampling_layers.push((resnets, conv1d));
|
||||
scaling *= 2;
|
||||
}
|
||||
let final_lstm =
|
||||
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
||||
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||
layer.inc(); // ELU
|
||||
let final_conv = EncodecConv1d::load(
|
||||
cfg.num_filters * scaling,
|
||||
cfg.hidden_size,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
@ -392,18 +381,17 @@ struct EncodecDecoder {
|
||||
|
||||
impl EncodecDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new("layers");
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.hidden_size,
|
||||
cfg.num_filters * scaling,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let init_lstm =
|
||||
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
||||
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||
let mut sampling_layers = vec![];
|
||||
for &ratio in cfg.upsampling_ratios.iter() {
|
||||
let current_scale = scaling * cfg.num_filters;
|
||||
@ -413,7 +401,7 @@ impl EncodecDecoder {
|
||||
current_scale / 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut resnets = vec![];
|
||||
@ -421,7 +409,7 @@ impl EncodecDecoder {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale / 2,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
@ -435,7 +423,7 @@ impl EncodecDecoder {
|
||||
cfg.audio_channels,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
vb.pp(&layer.next_name()),
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
|
Reference in New Issue
Block a user