Extract T5 module and add main function to use it (#829)

* Extract t5 out of musicgen

* Add main for t5 module
This commit is contained in:
Juarez Bochi
2023-09-12 23:14:05 -07:00
committed by GitHub
parent e82fcf1c59
commit 9daa6dbe87
8 changed files with 184 additions and 21 deletions

View File

@ -1,9 +1,10 @@
use crate::{encodec_model, t5_model};
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@ -370,7 +371,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
pub text_encoder: crate::t5_model::T5EncoderModel,
pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
@ -379,7 +380,7 @@ pub struct MusicgenForConditionalGeneration {
#[derive(Debug, Clone, PartialEq)]
pub struct GenConfig {
musicgen: Config,
t5: crate::t5_model::Config,
t5: t5::Config,
encodec: crate::encodec_model::Config,
}
@ -387,7 +388,7 @@ impl GenConfig {
pub fn small() -> Self {
Self {
musicgen: Config::musicgen_small(),
t5: t5_model::Config::musicgen_small(),
t5: t5::Config::musicgen_small(),
encodec: encodec_model::Config::musicgen_small(),
}
}
@ -399,7 +400,7 @@ impl MusicgenForConditionalGeneration {
}
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5_model::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;