mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
MusicGen var-store path cleanup. (#132)
This commit is contained in:
@ -127,12 +127,12 @@ struct EncodecEuclideanCodebook {
|
||||
}
|
||||
|
||||
impl EncodecEuclideanCodebook {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inited = vb.get(1, &format!("{p}.inited"))?;
|
||||
let cluster_size = vb.get(cfg.codebook_size, &format!("{p}.cluster_size"))?;
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inited = vb.get(1, "inited")?;
|
||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||
let embed = vb.get(e_shape, &format!("{p}.embed"))?;
|
||||
let embed_avg = vb.get(e_shape, &format!("{p}.embed_avg"))?;
|
||||
let embed = vb.get(e_shape, "embed")?;
|
||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||
Ok(Self {
|
||||
inited,
|
||||
cluster_size,
|
||||
@ -148,8 +148,8 @@ struct EncodecVectorQuantization {
|
||||
}
|
||||
|
||||
impl EncodecVectorQuantization {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let codebook = EncodecEuclideanCodebook::load(&format!("{p}.codebook"), vb, cfg)?;
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
||||
Ok(Self { codebook })
|
||||
}
|
||||
}
|
||||
@ -160,10 +160,10 @@ struct EncodecResidualVectorQuantizer {
|
||||
}
|
||||
|
||||
impl EncodecResidualVectorQuantizer {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let p = format!("{p}.layers");
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("layers");
|
||||
let layers = (0..cfg.num_quantizers())
|
||||
.map(|i| EncodecVectorQuantization::load(&format!("{p}.{i}"), vb, cfg))
|
||||
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self { layers })
|
||||
}
|
||||
@ -176,14 +176,14 @@ struct EncodecLSTM {
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let p = format!("{p}.lstm");
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for i in 0..cfg.num_lstm_layers {
|
||||
let w_hh = vb.get((4 * dim, dim), &format!("{p}.weight_hh_l{i}"))?;
|
||||
let w_ih = vb.get((4 * dim, dim), &format!("{p}.weight_ih_l{i}"))?;
|
||||
let b_hh = vb.get(4 * dim, &format!("{p}.bias_hh_l{i}"))?;
|
||||
let b_ih = vb.get(4 * dim, &format!("{p}.bias_ih_l{i}"))?;
|
||||
let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
|
||||
let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
|
||||
let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
|
||||
let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
|
||||
layers.push((w_hh, w_ih, b_hh, b_ih))
|
||||
}
|
||||
Ok(Self { layers })
|
||||
@ -203,14 +203,13 @@ impl EncodecConvTranspose1d {
|
||||
out_c: usize,
|
||||
k: usize,
|
||||
_stride: usize,
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
vb: VarBuilder,
|
||||
_cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let p = format!("{p}.conv");
|
||||
let weight_g = vb.get((in_c, 1, 1), &format!("{p}.weight_g"))?;
|
||||
let weight_v = vb.get((in_c, out_c, k), &format!("{p}.weight_v"))?;
|
||||
let bias = vb.get(out_c, &format!("{p}.bias"))?;
|
||||
let vb = &vb.pp("conv");
|
||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Self {
|
||||
weight_g,
|
||||
weight_v,
|
||||
@ -230,8 +229,7 @@ impl EncodecConv1d {
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
vb: VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let conv = match cfg.norm_type {
|
||||
@ -240,16 +238,14 @@ impl EncodecConv1d {
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig { padding: 0, stride },
|
||||
&format!("{p}.conv"),
|
||||
vb,
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig { padding: 0, stride },
|
||||
&format!("{p}.conv"),
|
||||
vb,
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
Ok(Self { conv })
|
||||
@ -264,15 +260,9 @@ struct EncodecResnetBlock {
|
||||
}
|
||||
|
||||
impl EncodecResnetBlock {
|
||||
fn load(
|
||||
dim: usize,
|
||||
dilations: &[usize],
|
||||
p: &str,
|
||||
vb: &VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new(format!("{p}.block"));
|
||||
let mut layer = Layer::new("block");
|
||||
if dilations.len() != 2 {
|
||||
anyhow::bail!("expected dilations of size 2")
|
||||
}
|
||||
@ -283,14 +273,13 @@ impl EncodecResnetBlock {
|
||||
h,
|
||||
cfg.residual_kernel_size,
|
||||
1,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
layer.inc();
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, &layer.next_name(), vb, cfg)?;
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?;
|
||||
let shortcut = if cfg.use_conv_shortcut {
|
||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, &format!("{p}.shortcut"), vb, cfg)?;
|
||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
@ -310,8 +299,11 @@ struct Layer {
|
||||
}
|
||||
|
||||
impl Layer {
|
||||
fn new(prefix: String) -> Self {
|
||||
Self { prefix, cnt: 0 }
|
||||
fn new(prefix: &str) -> Self {
|
||||
Self {
|
||||
prefix: prefix.to_string(),
|
||||
cnt: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn inc(&mut self) {
|
||||
@ -334,15 +326,14 @@ struct EncodecEncoder {
|
||||
}
|
||||
|
||||
impl EncodecEncoder {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(format!("{p}.layers"));
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new("layers");
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.audio_channels,
|
||||
cfg.num_filters,
|
||||
cfg.kernel_size,
|
||||
1,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
let mut sampling_layers = vec![];
|
||||
@ -354,8 +345,7 @@ impl EncodecEncoder {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
@ -366,22 +356,21 @@ impl EncodecEncoder {
|
||||
current_scale * 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
sampling_layers.push((resnets, conv1d));
|
||||
scaling *= 2;
|
||||
}
|
||||
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
|
||||
let final_lstm =
|
||||
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
||||
layer.inc(); // ELU
|
||||
let final_conv = EncodecConv1d::load(
|
||||
cfg.num_filters * scaling,
|
||||
cfg.hidden_size,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
@ -402,19 +391,19 @@ struct EncodecDecoder {
|
||||
}
|
||||
|
||||
impl EncodecDecoder {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(format!("{p}.layers"));
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new("layers");
|
||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.hidden_size,
|
||||
cfg.num_filters * scaling,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
|
||||
let init_lstm =
|
||||
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
||||
let mut sampling_layers = vec![];
|
||||
for &ratio in cfg.upsampling_ratios.iter() {
|
||||
let current_scale = scaling * cfg.num_filters;
|
||||
@ -424,8 +413,7 @@ impl EncodecDecoder {
|
||||
current_scale / 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
let mut resnets = vec![];
|
||||
@ -433,8 +421,7 @@ impl EncodecDecoder {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale / 2,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
@ -448,8 +435,7 @@ impl EncodecDecoder {
|
||||
cfg.audio_channels,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
&layer.next_name(),
|
||||
vb,
|
||||
vb.pp(&layer.next_name()),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
@ -469,10 +455,10 @@ pub struct EncodecModel {
|
||||
}
|
||||
|
||||
impl EncodecModel {
|
||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = EncodecEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
|
||||
let decoder = EncodecDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
|
||||
let quantizer = EncodecResidualVectorQuantizer::load(&format!("{p}.quantizer"), vb, cfg)?;
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
|
||||
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
|
||||
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
|
Reference in New Issue
Block a user