mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Simplify the var-builder layer setup. (#133)
This commit is contained in:
@ -262,22 +262,16 @@ struct EncodecResnetBlock {
|
|||||||
impl EncodecResnetBlock {
|
impl EncodecResnetBlock {
|
||||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let h = dim / cfg.compress;
|
let h = dim / cfg.compress;
|
||||||
let mut layer = Layer::new("block");
|
let mut layer = Layer::new(vb.pp("block"));
|
||||||
if dilations.len() != 2 {
|
if dilations.len() != 2 {
|
||||||
anyhow::bail!("expected dilations of size 2")
|
anyhow::bail!("expected dilations of size 2")
|
||||||
}
|
}
|
||||||
// TODO: Apply dilations!
|
// TODO: Apply dilations!
|
||||||
layer.inc();
|
layer.inc();
|
||||||
let block_conv1 = EncodecConv1d::load(
|
let block_conv1 =
|
||||||
dim,
|
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
||||||
h,
|
|
||||||
cfg.residual_kernel_size,
|
|
||||||
1,
|
|
||||||
vb.pp(&layer.next_name()),
|
|
||||||
cfg,
|
|
||||||
)?;
|
|
||||||
layer.inc();
|
layer.inc();
|
||||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?;
|
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
||||||
let shortcut = if cfg.use_conv_shortcut {
|
let shortcut = if cfg.use_conv_shortcut {
|
||||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||||
Some(conv)
|
Some(conv)
|
||||||
@ -292,28 +286,24 @@ impl EncodecResnetBlock {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
struct Layer<'a> {
|
||||||
struct Layer {
|
vb: VarBuilder<'a>,
|
||||||
prefix: String,
|
|
||||||
cnt: usize,
|
cnt: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Layer {
|
impl<'a> Layer<'a> {
|
||||||
fn new(prefix: &str) -> Self {
|
fn new(vb: VarBuilder<'a>) -> Self {
|
||||||
Self {
|
Self { vb, cnt: 0 }
|
||||||
prefix: prefix.to_string(),
|
|
||||||
cnt: 0,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inc(&mut self) {
|
fn inc(&mut self) {
|
||||||
self.cnt += 1;
|
self.cnt += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn next_name(&mut self) -> String {
|
fn next(&mut self) -> VarBuilder<'a> {
|
||||||
let name = format!("{}.{}", self.prefix, self.cnt);
|
let vb = self.vb.pp(&self.cnt.to_string());
|
||||||
self.cnt += 1;
|
self.cnt += 1;
|
||||||
name
|
vb
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -327,13 +317,13 @@ struct EncodecEncoder {
|
|||||||
|
|
||||||
impl EncodecEncoder {
|
impl EncodecEncoder {
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let mut layer = Layer::new("layers");
|
let mut layer = Layer::new(vb.pp("layers"));
|
||||||
let init_conv = EncodecConv1d::load(
|
let init_conv = EncodecConv1d::load(
|
||||||
cfg.audio_channels,
|
cfg.audio_channels,
|
||||||
cfg.num_filters,
|
cfg.num_filters,
|
||||||
cfg.kernel_size,
|
cfg.kernel_size,
|
||||||
1,
|
1,
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
let mut sampling_layers = vec![];
|
let mut sampling_layers = vec![];
|
||||||
@ -345,7 +335,7 @@ impl EncodecEncoder {
|
|||||||
let resnet = EncodecResnetBlock::load(
|
let resnet = EncodecResnetBlock::load(
|
||||||
current_scale,
|
current_scale,
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
resnets.push(resnet)
|
resnets.push(resnet)
|
||||||
@ -356,21 +346,20 @@ impl EncodecEncoder {
|
|||||||
current_scale * 2,
|
current_scale * 2,
|
||||||
ratio * 2,
|
ratio * 2,
|
||||||
ratio,
|
ratio,
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
sampling_layers.push((resnets, conv1d));
|
sampling_layers.push((resnets, conv1d));
|
||||||
scaling *= 2;
|
scaling *= 2;
|
||||||
}
|
}
|
||||||
let final_lstm =
|
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||||
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
|
||||||
layer.inc(); // ELU
|
layer.inc(); // ELU
|
||||||
let final_conv = EncodecConv1d::load(
|
let final_conv = EncodecConv1d::load(
|
||||||
cfg.num_filters * scaling,
|
cfg.num_filters * scaling,
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -392,18 +381,17 @@ struct EncodecDecoder {
|
|||||||
|
|
||||||
impl EncodecDecoder {
|
impl EncodecDecoder {
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let mut layer = Layer::new("layers");
|
let mut layer = Layer::new(vb.pp("layers"));
|
||||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||||
let init_conv = EncodecConv1d::load(
|
let init_conv = EncodecConv1d::load(
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.num_filters * scaling,
|
cfg.num_filters * scaling,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
let init_lstm =
|
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||||
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
|
||||||
let mut sampling_layers = vec![];
|
let mut sampling_layers = vec![];
|
||||||
for &ratio in cfg.upsampling_ratios.iter() {
|
for &ratio in cfg.upsampling_ratios.iter() {
|
||||||
let current_scale = scaling * cfg.num_filters;
|
let current_scale = scaling * cfg.num_filters;
|
||||||
@ -413,7 +401,7 @@ impl EncodecDecoder {
|
|||||||
current_scale / 2,
|
current_scale / 2,
|
||||||
ratio * 2,
|
ratio * 2,
|
||||||
ratio,
|
ratio,
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
let mut resnets = vec![];
|
let mut resnets = vec![];
|
||||||
@ -421,7 +409,7 @@ impl EncodecDecoder {
|
|||||||
let resnet = EncodecResnetBlock::load(
|
let resnet = EncodecResnetBlock::load(
|
||||||
current_scale / 2,
|
current_scale / 2,
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
resnets.push(resnet)
|
resnets.push(resnet)
|
||||||
@ -435,7 +423,7 @@ impl EncodecDecoder {
|
|||||||
cfg.audio_channels,
|
cfg.audio_channels,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
vb.pp(&layer.next_name()),
|
layer.next(),
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
Reference in New Issue
Block a user