mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
MusicGen var-store path cleanup. (#132)
This commit is contained in:
@ -127,12 +127,12 @@ struct EncodecEuclideanCodebook {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecEuclideanCodebook {
|
impl EncodecEuclideanCodebook {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let inited = vb.get(1, &format!("{p}.inited"))?;
|
let inited = vb.get(1, "inited")?;
|
||||||
let cluster_size = vb.get(cfg.codebook_size, &format!("{p}.cluster_size"))?;
|
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||||
let embed = vb.get(e_shape, &format!("{p}.embed"))?;
|
let embed = vb.get(e_shape, "embed")?;
|
||||||
let embed_avg = vb.get(e_shape, &format!("{p}.embed_avg"))?;
|
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inited,
|
inited,
|
||||||
cluster_size,
|
cluster_size,
|
||||||
@ -148,8 +148,8 @@ struct EncodecVectorQuantization {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecVectorQuantization {
|
impl EncodecVectorQuantization {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let codebook = EncodecEuclideanCodebook::load(&format!("{p}.codebook"), vb, cfg)?;
|
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
||||||
Ok(Self { codebook })
|
Ok(Self { codebook })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -160,10 +160,10 @@ struct EncodecResidualVectorQuantizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecResidualVectorQuantizer {
|
impl EncodecResidualVectorQuantizer {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let p = format!("{p}.layers");
|
let vb = &vb.pp("layers");
|
||||||
let layers = (0..cfg.num_quantizers())
|
let layers = (0..cfg.num_quantizers())
|
||||||
.map(|i| EncodecVectorQuantization::load(&format!("{p}.{i}"), vb, cfg))
|
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
Ok(Self { layers })
|
Ok(Self { layers })
|
||||||
}
|
}
|
||||||
@ -176,14 +176,14 @@ struct EncodecLSTM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecLSTM {
|
impl EncodecLSTM {
|
||||||
fn load(dim: usize, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let p = format!("{p}.lstm");
|
let vb = &vb.pp("lstm");
|
||||||
let mut layers = vec![];
|
let mut layers = vec![];
|
||||||
for i in 0..cfg.num_lstm_layers {
|
for i in 0..cfg.num_lstm_layers {
|
||||||
let w_hh = vb.get((4 * dim, dim), &format!("{p}.weight_hh_l{i}"))?;
|
let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
|
||||||
let w_ih = vb.get((4 * dim, dim), &format!("{p}.weight_ih_l{i}"))?;
|
let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
|
||||||
let b_hh = vb.get(4 * dim, &format!("{p}.bias_hh_l{i}"))?;
|
let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
|
||||||
let b_ih = vb.get(4 * dim, &format!("{p}.bias_ih_l{i}"))?;
|
let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
|
||||||
layers.push((w_hh, w_ih, b_hh, b_ih))
|
layers.push((w_hh, w_ih, b_hh, b_ih))
|
||||||
}
|
}
|
||||||
Ok(Self { layers })
|
Ok(Self { layers })
|
||||||
@ -203,14 +203,13 @@ impl EncodecConvTranspose1d {
|
|||||||
out_c: usize,
|
out_c: usize,
|
||||||
k: usize,
|
k: usize,
|
||||||
_stride: usize,
|
_stride: usize,
|
||||||
p: &str,
|
vb: VarBuilder,
|
||||||
vb: &VarBuilder,
|
|
||||||
_cfg: &Config,
|
_cfg: &Config,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let p = format!("{p}.conv");
|
let vb = &vb.pp("conv");
|
||||||
let weight_g = vb.get((in_c, 1, 1), &format!("{p}.weight_g"))?;
|
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||||
let weight_v = vb.get((in_c, out_c, k), &format!("{p}.weight_v"))?;
|
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
||||||
let bias = vb.get(out_c, &format!("{p}.bias"))?;
|
let bias = vb.get(out_c, "bias")?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
weight_g,
|
weight_g,
|
||||||
weight_v,
|
weight_v,
|
||||||
@ -230,8 +229,7 @@ impl EncodecConv1d {
|
|||||||
out_c: usize,
|
out_c: usize,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
stride: usize,
|
stride: usize,
|
||||||
p: &str,
|
vb: VarBuilder,
|
||||||
vb: &VarBuilder,
|
|
||||||
cfg: &Config,
|
cfg: &Config,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let conv = match cfg.norm_type {
|
let conv = match cfg.norm_type {
|
||||||
@ -240,16 +238,14 @@ impl EncodecConv1d {
|
|||||||
out_c,
|
out_c,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
Conv1dConfig { padding: 0, stride },
|
Conv1dConfig { padding: 0, stride },
|
||||||
&format!("{p}.conv"),
|
vb.pp("conv"),
|
||||||
vb,
|
|
||||||
)?,
|
)?,
|
||||||
NormType::None => conv1d(
|
NormType::None => conv1d(
|
||||||
in_c,
|
in_c,
|
||||||
out_c,
|
out_c,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
Conv1dConfig { padding: 0, stride },
|
Conv1dConfig { padding: 0, stride },
|
||||||
&format!("{p}.conv"),
|
vb.pp("conv"),
|
||||||
vb,
|
|
||||||
)?,
|
)?,
|
||||||
};
|
};
|
||||||
Ok(Self { conv })
|
Ok(Self { conv })
|
||||||
@ -264,15 +260,9 @@ struct EncodecResnetBlock {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecResnetBlock {
|
impl EncodecResnetBlock {
|
||||||
fn load(
|
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
dim: usize,
|
|
||||||
dilations: &[usize],
|
|
||||||
p: &str,
|
|
||||||
vb: &VarBuilder,
|
|
||||||
cfg: &Config,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let h = dim / cfg.compress;
|
let h = dim / cfg.compress;
|
||||||
let mut layer = Layer::new(format!("{p}.block"));
|
let mut layer = Layer::new("block");
|
||||||
if dilations.len() != 2 {
|
if dilations.len() != 2 {
|
||||||
anyhow::bail!("expected dilations of size 2")
|
anyhow::bail!("expected dilations of size 2")
|
||||||
}
|
}
|
||||||
@ -283,14 +273,13 @@ impl EncodecResnetBlock {
|
|||||||
h,
|
h,
|
||||||
cfg.residual_kernel_size,
|
cfg.residual_kernel_size,
|
||||||
1,
|
1,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
layer.inc();
|
layer.inc();
|
||||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, &layer.next_name(), vb, cfg)?;
|
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, vb.pp(&layer.next_name()), cfg)?;
|
||||||
let shortcut = if cfg.use_conv_shortcut {
|
let shortcut = if cfg.use_conv_shortcut {
|
||||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, &format!("{p}.shortcut"), vb, cfg)?;
|
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||||
Some(conv)
|
Some(conv)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@ -310,8 +299,11 @@ struct Layer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Layer {
|
impl Layer {
|
||||||
fn new(prefix: String) -> Self {
|
fn new(prefix: &str) -> Self {
|
||||||
Self { prefix, cnt: 0 }
|
Self {
|
||||||
|
prefix: prefix.to_string(),
|
||||||
|
cnt: 0,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inc(&mut self) {
|
fn inc(&mut self) {
|
||||||
@ -334,15 +326,14 @@ struct EncodecEncoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecEncoder {
|
impl EncodecEncoder {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let mut layer = Layer::new(format!("{p}.layers"));
|
let mut layer = Layer::new("layers");
|
||||||
let init_conv = EncodecConv1d::load(
|
let init_conv = EncodecConv1d::load(
|
||||||
cfg.audio_channels,
|
cfg.audio_channels,
|
||||||
cfg.num_filters,
|
cfg.num_filters,
|
||||||
cfg.kernel_size,
|
cfg.kernel_size,
|
||||||
1,
|
1,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
let mut sampling_layers = vec![];
|
let mut sampling_layers = vec![];
|
||||||
@ -354,8 +345,7 @@ impl EncodecEncoder {
|
|||||||
let resnet = EncodecResnetBlock::load(
|
let resnet = EncodecResnetBlock::load(
|
||||||
current_scale,
|
current_scale,
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
resnets.push(resnet)
|
resnets.push(resnet)
|
||||||
@ -366,22 +356,21 @@ impl EncodecEncoder {
|
|||||||
current_scale * 2,
|
current_scale * 2,
|
||||||
ratio * 2,
|
ratio * 2,
|
||||||
ratio,
|
ratio,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
sampling_layers.push((resnets, conv1d));
|
sampling_layers.push((resnets, conv1d));
|
||||||
scaling *= 2;
|
scaling *= 2;
|
||||||
}
|
}
|
||||||
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
|
let final_lstm =
|
||||||
|
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
||||||
layer.inc(); // ELU
|
layer.inc(); // ELU
|
||||||
let final_conv = EncodecConv1d::load(
|
let final_conv = EncodecConv1d::load(
|
||||||
cfg.num_filters * scaling,
|
cfg.num_filters * scaling,
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -402,19 +391,19 @@ struct EncodecDecoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecDecoder {
|
impl EncodecDecoder {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let mut layer = Layer::new(format!("{p}.layers"));
|
let mut layer = Layer::new("layers");
|
||||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||||
let init_conv = EncodecConv1d::load(
|
let init_conv = EncodecConv1d::load(
|
||||||
cfg.hidden_size,
|
cfg.hidden_size,
|
||||||
cfg.num_filters * scaling,
|
cfg.num_filters * scaling,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, &layer.next_name(), vb, cfg)?;
|
let init_lstm =
|
||||||
|
EncodecLSTM::load(cfg.num_filters * scaling, vb.pp(&layer.next_name()), cfg)?;
|
||||||
let mut sampling_layers = vec![];
|
let mut sampling_layers = vec![];
|
||||||
for &ratio in cfg.upsampling_ratios.iter() {
|
for &ratio in cfg.upsampling_ratios.iter() {
|
||||||
let current_scale = scaling * cfg.num_filters;
|
let current_scale = scaling * cfg.num_filters;
|
||||||
@ -424,8 +413,7 @@ impl EncodecDecoder {
|
|||||||
current_scale / 2,
|
current_scale / 2,
|
||||||
ratio * 2,
|
ratio * 2,
|
||||||
ratio,
|
ratio,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
let mut resnets = vec![];
|
let mut resnets = vec![];
|
||||||
@ -433,8 +421,7 @@ impl EncodecDecoder {
|
|||||||
let resnet = EncodecResnetBlock::load(
|
let resnet = EncodecResnetBlock::load(
|
||||||
current_scale / 2,
|
current_scale / 2,
|
||||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
resnets.push(resnet)
|
resnets.push(resnet)
|
||||||
@ -448,8 +435,7 @@ impl EncodecDecoder {
|
|||||||
cfg.audio_channels,
|
cfg.audio_channels,
|
||||||
cfg.last_kernel_size,
|
cfg.last_kernel_size,
|
||||||
1,
|
1,
|
||||||
&layer.next_name(),
|
vb.pp(&layer.next_name()),
|
||||||
vb,
|
|
||||||
cfg,
|
cfg,
|
||||||
)?;
|
)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -469,10 +455,10 @@ pub struct EncodecModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl EncodecModel {
|
impl EncodecModel {
|
||||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let encoder = EncodecEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
|
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
|
||||||
let decoder = EncodecDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
|
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
|
||||||
let quantizer = EncodecResidualVectorQuantizer::load(&format!("{p}.quantizer"), vb, cfg)?;
|
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
encoder,
|
encoder,
|
||||||
decoder,
|
decoder,
|
||||||
|
@ -54,6 +54,6 @@ fn main() -> Result<()> {
|
|||||||
let model = model.deserialize()?;
|
let model = model.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||||
let config = GenConfig::small();
|
let config = GenConfig::small();
|
||||||
let _model = MusicgenForConditionalGeneration::load(&vb, config)?;
|
let _model = MusicgenForConditionalGeneration::load(vb, config)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -111,7 +111,7 @@ struct MusicgenSinusoidalPositionalEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MusicgenSinusoidalPositionalEmbedding {
|
impl MusicgenSinusoidalPositionalEmbedding {
|
||||||
fn load(_vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(_vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let num_positions = cfg.max_position_embeddings;
|
let num_positions = cfg.max_position_embeddings;
|
||||||
let embedding_dim = cfg.hidden_size;
|
let embedding_dim = cfg.hidden_size;
|
||||||
let weights = get_embedding(num_positions, embedding_dim)?;
|
let weights = get_embedding(num_positions, embedding_dim)?;
|
||||||
@ -144,14 +144,14 @@ struct MusicgenAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MusicgenAttention {
|
impl MusicgenAttention {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let num_heads = cfg.num_attention_heads;
|
let num_heads = cfg.num_attention_heads;
|
||||||
let head_dim = h / num_heads;
|
let head_dim = h / num_heads;
|
||||||
let k_proj = linear(h, h, false, &format!("{p}.k_proj"), vb)?;
|
let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
|
||||||
let v_proj = linear(h, h, false, &format!("{p}.v_proj"), vb)?;
|
let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
|
||||||
let q_proj = linear(h, h, false, &format!("{p}.q_proj"), vb)?;
|
let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
|
||||||
let out_proj = linear(h, h, false, &format!("{p}.out_proj"), vb)?;
|
let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
scaling: 1. / (head_dim as f64).sqrt(),
|
scaling: 1. / (head_dim as f64).sqrt(),
|
||||||
is_decoder: true,
|
is_decoder: true,
|
||||||
@ -212,16 +212,15 @@ struct MusicgenDecoderLayer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MusicgenDecoderLayer {
|
impl MusicgenDecoderLayer {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let self_attn = MusicgenAttention::load(&format!("{p}.self_attn"), vb, cfg)?;
|
let self_attn = MusicgenAttention::load(vb.pp("self_attn"), cfg)?;
|
||||||
let self_attn_layer_norm = layer_norm(h, 1e-5, &format!("{p}.self_attn_layer_norm"), vb)?;
|
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||||
let encoder_attn = MusicgenAttention::load(&format!("{p}.encoder_attn"), vb, cfg)?;
|
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
|
||||||
let encoder_attn_layer_norm =
|
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||||
layer_norm(h, 1e-5, &format!("{p}.encoder_attn_layer_norm"), vb)?;
|
let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
|
||||||
let fc1 = linear(h, cfg.ffn_dim, false, &format!("{p}.fc1"), vb)?;
|
let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
|
||||||
let fc2 = linear(cfg.ffn_dim, h, false, &format!("{p}.fc2"), vb)?;
|
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
|
||||||
let final_layer_norm = layer_norm(h, 1e-5, &format!("{p}.final_layer_norm"), vb)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attn,
|
self_attn,
|
||||||
self_attn_layer_norm,
|
self_attn_layer_norm,
|
||||||
@ -276,7 +275,7 @@ struct MusicgenDecoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MusicgenDecoder {
|
impl MusicgenDecoder {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let embed_scale = if cfg.scale_embedding {
|
let embed_scale = if cfg.scale_embedding {
|
||||||
(h as f64).sqrt()
|
(h as f64).sqrt()
|
||||||
@ -285,13 +284,13 @@ impl MusicgenDecoder {
|
|||||||
};
|
};
|
||||||
let embed_dim = cfg.vocab_size + 1;
|
let embed_dim = cfg.vocab_size + 1;
|
||||||
let embed_tokens = (0..cfg.num_codebooks)
|
let embed_tokens = (0..cfg.num_codebooks)
|
||||||
.map(|i| embedding(embed_dim, h, &format!("{p}.embed_tokens.{i}"), vb))
|
.map(|i| embedding(embed_dim, h, vb.pp(&format!("embed_tokens.{i}"))))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb, cfg)?;
|
let embed_positions = MusicgenSinusoidalPositionalEmbedding::load(vb.clone(), cfg)?;
|
||||||
let layers = (0..cfg.num_hidden_layers)
|
let layers = (0..cfg.num_hidden_layers)
|
||||||
.map(|i| MusicgenDecoderLayer::load(&format!("{p}.layers.{i}"), vb, cfg))
|
.map(|i| MusicgenDecoderLayer::load(vb.pp(&format!("layers.{i}")), cfg))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let layer_norm = layer_norm(h, 1e-5, &format!("{p}.layer_norm"), vb)?;
|
let layer_norm = layer_norm(h, 1e-5, vb.pp("layer_norm"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
embed_tokens,
|
embed_tokens,
|
||||||
embed_positions,
|
embed_positions,
|
||||||
@ -338,11 +337,11 @@ pub struct MusicgenForCausalLM {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MusicgenForCausalLM {
|
impl MusicgenForCausalLM {
|
||||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let decoder = MusicgenDecoder::load(&format!("{p}.model.decoder"), vb, cfg)?;
|
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
|
||||||
let lm_heads = (0..cfg.num_codebooks)
|
let lm_heads = (0..cfg.num_codebooks)
|
||||||
.map(|i| linear(h, cfg.vocab_size, false, &format!("{p}.lm_heads.{i}"), vb))
|
.map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
decoder,
|
decoder,
|
||||||
@ -399,10 +398,11 @@ impl MusicgenForConditionalGeneration {
|
|||||||
&self.cfg
|
&self.cfg
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load(vb: &VarBuilder, cfg: GenConfig) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||||
let text_encoder = t5_model::T5EncoderModel::load("text_encoder", vb, &cfg.t5)?;
|
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||||
let audio_encoder = encodec_model::EncodecModel::load("audio_encoder", vb, &cfg.encodec)?;
|
let audio_encoder =
|
||||||
let decoder = MusicgenForCausalLM::load("decoder", vb, &cfg.musicgen)?;
|
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||||
|
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
text_encoder,
|
text_encoder,
|
||||||
audio_encoder,
|
audio_encoder,
|
||||||
|
@ -8,10 +8,10 @@ const MAX_SEQ_LEN: usize = 5000;
|
|||||||
pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
|
pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
|
||||||
pub type Linear = candle_nn::Linear;
|
pub type Linear = candle_nn::Linear;
|
||||||
|
|
||||||
pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder) -> Result<Linear> {
|
pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), &format!("{p}.weight"))?;
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
let bias = if bias {
|
let bias = if bias {
|
||||||
Some(vb.get(size2, &format!("{p}.bias"))?)
|
Some(vb.get(size2, "bias")?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
@ -20,17 +20,11 @@ pub fn linear(size1: usize, size2: usize, bias: bool, p: &str, vb: &VarBuilder)
|
|||||||
|
|
||||||
pub type LayerNorm = candle_nn::LayerNorm;
|
pub type LayerNorm = candle_nn::LayerNorm;
|
||||||
|
|
||||||
pub fn layer_norm(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<LayerNorm> {
|
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let (weight, bias) = match (
|
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||||
vb.get(size, &format!("{p}.weight")),
|
|
||||||
vb.get(size, &format!("{p}.bias")),
|
|
||||||
) {
|
|
||||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
(Err(err), _) | (_, Err(err)) => {
|
||||||
if let (Ok(weight), Ok(bias)) = (
|
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||||
vb.get(size, &format!("{p}.gamma")),
|
|
||||||
vb.get(size, &format!("{p}.beta")),
|
|
||||||
) {
|
|
||||||
(weight, bias)
|
(weight, bias)
|
||||||
} else {
|
} else {
|
||||||
return Err(err.into());
|
return Err(err.into());
|
||||||
@ -58,13 +52,8 @@ impl Dropout {
|
|||||||
|
|
||||||
pub type Embedding = candle_nn::Embedding;
|
pub type Embedding = candle_nn::Embedding;
|
||||||
|
|
||||||
pub fn embedding(
|
pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||||
vocab_size: usize,
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
hidden_size: usize,
|
|
||||||
p: &str,
|
|
||||||
vb: &VarBuilder,
|
|
||||||
) -> Result<Embedding> {
|
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
Ok(Embedding::new(embeddings, hidden_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,14 +68,13 @@ pub fn conv1d_weight_norm(
|
|||||||
out_c: usize,
|
out_c: usize,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
config: Conv1dConfig,
|
config: Conv1dConfig,
|
||||||
p: &str,
|
vb: VarBuilder,
|
||||||
vb: &VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
) -> Result<Conv1d> {
|
||||||
let weight_g = vb.get((out_c, 1, 1), &format!("{p}.weight_g"))?;
|
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||||
let weight_v = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight_v"))?;
|
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||||
let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
|
let norm_v = (&weight_v * &weight_v)?.sum(&[1, 2])?.sqrt()?;
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
let bias = vb.get(out_c, &format!("{p}.bias"))?;
|
let bias = vb.get(out_c, "bias")?;
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,11 +83,10 @@ pub fn conv1d(
|
|||||||
out_c: usize,
|
out_c: usize,
|
||||||
kernel_size: usize,
|
kernel_size: usize,
|
||||||
config: Conv1dConfig,
|
config: Conv1dConfig,
|
||||||
p: &str,
|
vb: VarBuilder,
|
||||||
vb: &VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
) -> Result<Conv1d> {
|
||||||
let weight = vb.get((out_c, in_c, kernel_size), &format!("{p}.weight"))?;
|
let weight = vb.get((out_c, in_c, kernel_size), "weight")?;
|
||||||
let bias = vb.get(out_c, &format!("{p}.bias"))?;
|
let bias = vb.get(out_c, "bias")?;
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,8 +85,8 @@ struct T5LayerNorm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5LayerNorm {
|
impl T5LayerNorm {
|
||||||
fn load(h: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
|
fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
||||||
let weight = vb.get(h, &format!("{p}.weight"))?;
|
let weight = vb.get(h, "weight")?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
weight,
|
weight,
|
||||||
variance_epsilon: eps,
|
variance_epsilon: eps,
|
||||||
@ -103,9 +103,9 @@ struct T5DenseActDense {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5DenseActDense {
|
impl T5DenseActDense {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let wi = linear(cfg.d_model, cfg.d_ff, false, &format!("{p}.wi"), vb)?;
|
let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
|
||||||
let wo = linear(cfg.d_ff, cfg.d_model, false, &format!("{p}.wo"), vb)?;
|
let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
let dropout = Dropout::new(cfg.dropout_rate);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
wi,
|
wi,
|
||||||
@ -124,15 +124,11 @@ struct T5LayerFF {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5LayerFF {
|
impl T5LayerFF {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
// is_gated_act is not supported.
|
// is_gated_act is not supported.
|
||||||
let dense_relu_dense = T5DenseActDense::load(&format!("{p}.DenseReluDense"), vb, cfg)?;
|
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
|
||||||
let layer_norm = T5LayerNorm::load(
|
let layer_norm =
|
||||||
cfg.d_model,
|
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||||
cfg.layer_norm_epsilon,
|
|
||||||
&format!("{p}.layer_norm"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
let dropout = Dropout::new(cfg.dropout_rate);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
dense_relu_dense,
|
dense_relu_dense,
|
||||||
@ -152,18 +148,17 @@ struct T5Attention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5Attention {
|
impl T5Attention {
|
||||||
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||||
let q = linear(cfg.d_model, inner_dim, false, &format!("{p}.q"), vb)?;
|
let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
|
||||||
let k = linear(cfg.d_model, inner_dim, false, &format!("{p}.k"), vb)?;
|
let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
|
||||||
let v = linear(cfg.d_model, inner_dim, false, &format!("{p}.v"), vb)?;
|
let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
|
||||||
let o = linear(inner_dim, cfg.d_model, false, &format!("{p}.o"), vb)?;
|
let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
|
||||||
let relative_attention_bias = if h {
|
let relative_attention_bias = if h {
|
||||||
let emb = embedding(
|
let emb = embedding(
|
||||||
cfg.relative_attention_num_buckets,
|
cfg.relative_attention_num_buckets,
|
||||||
cfg.num_heads,
|
cfg.num_heads,
|
||||||
&format!("{p}.relative_attention_bias"),
|
vb.pp("relative_attention_bias"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
Some(emb)
|
Some(emb)
|
||||||
} else {
|
} else {
|
||||||
@ -187,14 +182,10 @@ struct T5LayerSelfAttention {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5LayerSelfAttention {
|
impl T5LayerSelfAttention {
|
||||||
fn load(h: bool, p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let self_attention = T5Attention::load(h, &format!("{p}.SelfAttention"), vb, cfg)?;
|
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
||||||
let layer_norm = T5LayerNorm::load(
|
let layer_norm =
|
||||||
cfg.d_model,
|
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||||
cfg.layer_norm_epsilon,
|
|
||||||
&format!("{p}.layer_norm"),
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
let dropout = Dropout::new(cfg.dropout_rate);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attention,
|
self_attention,
|
||||||
@ -208,7 +199,7 @@ impl T5LayerSelfAttention {
|
|||||||
struct T5LayerCrossAttention {}
|
struct T5LayerCrossAttention {}
|
||||||
|
|
||||||
impl T5LayerCrossAttention {
|
impl T5LayerCrossAttention {
|
||||||
fn load(_p: &str, _vb: &VarBuilder, _cfg: &Config) -> Result<Self> {
|
fn load(_vb: VarBuilder, _cfg: &Config) -> Result<Self> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -221,22 +212,16 @@ struct T5Block {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5Block {
|
impl T5Block {
|
||||||
fn load(
|
fn load(has_relative_attention_bias: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
has_relative_attention_bias: bool,
|
let vb = vb.pp("layer");
|
||||||
p: &str,
|
let self_attn = T5LayerSelfAttention::load(has_relative_attention_bias, vb.pp("0"), cfg)?;
|
||||||
vb: &VarBuilder,
|
|
||||||
cfg: &Config,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let p = &format!("{p}.layer");
|
|
||||||
let self_attn =
|
|
||||||
T5LayerSelfAttention::load(has_relative_attention_bias, &format!("{p}.0"), vb, cfg)?;
|
|
||||||
let cross_attn = if cfg.is_decoder {
|
let cross_attn = if cfg.is_decoder {
|
||||||
Some(T5LayerCrossAttention::load(&format!("{p}.1"), vb, cfg)?)
|
Some(T5LayerCrossAttention::load(vb.pp("1"), cfg)?)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
|
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
|
||||||
let ff = T5LayerFF::load(&format!("{p}.{ff_i}"), vb, cfg)?;
|
let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attn,
|
self_attn,
|
||||||
cross_attn,
|
cross_attn,
|
||||||
@ -254,15 +239,14 @@ struct T5Stack {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5Stack {
|
impl T5Stack {
|
||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let block = (0..cfg.num_layers)
|
let block = (0..cfg.num_layers)
|
||||||
.map(|i| T5Block::load(i == 0, &format!("{p}.block.{i}"), vb, cfg))
|
.map(|i| T5Block::load(i == 0, vb.pp(&format!("block.{i}")), cfg))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let final_layer_norm = T5LayerNorm::load(
|
let final_layer_norm = T5LayerNorm::load(
|
||||||
cfg.d_model,
|
cfg.d_model,
|
||||||
cfg.layer_norm_epsilon,
|
cfg.layer_norm_epsilon,
|
||||||
&format!("{p}.final_layer_norm"),
|
vb.pp("final_layer_norm"),
|
||||||
vb,
|
|
||||||
)?;
|
)?;
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
let dropout = Dropout::new(cfg.dropout_rate);
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -280,9 +264,9 @@ pub struct T5EncoderModel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl T5EncoderModel {
|
impl T5EncoderModel {
|
||||||
pub fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let shared = embedding(cfg.vocab_size, cfg.d_model, &format!("{p}.shared"), vb)?;
|
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||||
let encoder = T5Stack::load(&format!("{p}.encoder"), vb, cfg)?;
|
let encoder = T5Stack::load(vb.pp("encoder"), cfg)?;
|
||||||
Ok(Self { shared, encoder })
|
Ok(Self { shared, encoder })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user