mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Fix the musicgen example. (#724)
* Fix the musicgen example. * Retrieve the weights from the hub.
This commit is contained in:
@ -1,7 +1,6 @@
|
|||||||
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
|
use crate::nn::conv1d_weight_norm;
|
||||||
use anyhow::Result;
|
use candle::{DType, IndexOp, Result, Tensor};
|
||||||
use candle::{DType, IndexOp, Tensor};
|
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
|
||||||
use candle_nn::Module;
|
|
||||||
|
|
||||||
// Encodec Model
|
// Encodec Model
|
||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||||
@ -183,7 +182,7 @@ impl EncodecResidualVectorQuantizer {
|
|||||||
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||||
if codes.dim(0)? != self.layers.len() {
|
if codes.dim(0)? != self.layers.len() {
|
||||||
anyhow::bail!(
|
candle::bail!(
|
||||||
"codes shape {:?} does not match the number of quantization layers {}",
|
"codes shape {:?} does not match the number of quantization layers {}",
|
||||||
codes.shape(),
|
codes.shape(),
|
||||||
self.layers.len()
|
self.layers.len()
|
||||||
@ -321,7 +320,7 @@ impl EncodecResnetBlock {
|
|||||||
let h = dim / cfg.compress;
|
let h = dim / cfg.compress;
|
||||||
let mut layer = Layer::new(vb.pp("block"));
|
let mut layer = Layer::new(vb.pp("block"));
|
||||||
if dilations.len() != 2 {
|
if dilations.len() != 2 {
|
||||||
anyhow::bail!("expected dilations of size 2")
|
candle::bail!("expected dilations of size 2")
|
||||||
}
|
}
|
||||||
// TODO: Apply dilations!
|
// TODO: Apply dilations!
|
||||||
layer.inc();
|
layer.inc();
|
||||||
|
@ -16,11 +16,12 @@ mod nn;
|
|||||||
mod t5_model;
|
mod t5_model;
|
||||||
|
|
||||||
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
|
||||||
use nn::VarBuilder;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::DType;
|
use candle::DType;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
|
||||||
const DTYPE: DType = DType::F32;
|
const DTYPE: DType = DType::F32;
|
||||||
|
|
||||||
@ -33,11 +34,11 @@ struct Args {
|
|||||||
|
|
||||||
/// The model weight file, in safetensor format.
|
/// The model weight file, in safetensor format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: String,
|
model: Option<String>,
|
||||||
|
|
||||||
/// The tokenizer config.
|
/// The tokenizer config.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer: String,
|
tokenizer: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -45,10 +46,26 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer).map_err(E::msg)?;
|
let tokenizer = match args.tokenizer {
|
||||||
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("facebook/musicgen-small".to_string())
|
||||||
|
.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let mut tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||||
|
|
||||||
let model = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
|
let model = match args.model {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.repo(Repo::with_revision(
|
||||||
|
"facebook/musicgen-small".to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
"refs/pr/13".to_string(),
|
||||||
|
))
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
};
|
||||||
|
let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
|
||||||
let model = model.deserialize()?;
|
let model = model.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
|
||||||
let config = GenConfig::small();
|
let config = GenConfig::small();
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
use crate::nn::{
|
|
||||||
embedding, layer_norm, linear, Embedding, HiddenAct, LayerNorm, Linear, VarBuilder,
|
|
||||||
};
|
|
||||||
use crate::{encodec_model, t5_model};
|
use crate::{encodec_model, t5_model};
|
||||||
use anyhow::Result;
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
use candle::{DType, Device, Tensor, D};
|
use candle_nn::{
|
||||||
use candle_nn::Module;
|
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
|
||||||
|
VarBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@ -16,7 +15,7 @@ pub struct Config {
|
|||||||
num_attention_heads: usize,
|
num_attention_heads: usize,
|
||||||
layerdrop: f64,
|
layerdrop: f64,
|
||||||
use_cache: bool,
|
use_cache: bool,
|
||||||
activation_function: HiddenAct,
|
activation_function: Activation,
|
||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
dropout: f64,
|
dropout: f64,
|
||||||
attention_dropout: f64,
|
attention_dropout: f64,
|
||||||
@ -40,7 +39,7 @@ impl Default for Config {
|
|||||||
num_attention_heads: 16,
|
num_attention_heads: 16,
|
||||||
layerdrop: 0.0,
|
layerdrop: 0.0,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
|
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||||
hidden_size: 1024,
|
hidden_size: 1024,
|
||||||
dropout: 0.1,
|
dropout: 0.1,
|
||||||
attention_dropout: 0.0,
|
attention_dropout: 0.0,
|
||||||
@ -66,7 +65,7 @@ impl Config {
|
|||||||
num_attention_heads: 16,
|
num_attention_heads: 16,
|
||||||
layerdrop: 0.0,
|
layerdrop: 0.0,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
activation_function: HiddenAct::Gelu, // TODO: Handle old style gelu.
|
activation_function: Activation::Gelu, // TODO: Handle old style gelu.
|
||||||
hidden_size: 1024,
|
hidden_size: 1024,
|
||||||
dropout: 0.1,
|
dropout: 0.1,
|
||||||
attention_dropout: 0.0,
|
attention_dropout: 0.0,
|
||||||
@ -128,7 +127,7 @@ impl MusicgenSinusoidalPositionalEmbedding {
|
|||||||
if seq_len > self.weights.dim(0)? {
|
if seq_len > self.weights.dim(0)? {
|
||||||
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
self.weights = get_embedding(seq_len, self.embedding_dim)?
|
||||||
}
|
}
|
||||||
Ok(self.weights.narrow(0, 0, seq_len)?)
|
self.weights.narrow(0, 0, seq_len)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,10 +148,10 @@ impl MusicgenAttention {
|
|||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let num_heads = cfg.num_attention_heads;
|
let num_heads = cfg.num_attention_heads;
|
||||||
let head_dim = h / num_heads;
|
let head_dim = h / num_heads;
|
||||||
let k_proj = linear(h, h, false, vb.pp("k_proj"))?;
|
let k_proj = linear_no_bias(h, h, vb.pp("k_proj"))?;
|
||||||
let v_proj = linear(h, h, false, vb.pp("v_proj"))?;
|
let v_proj = linear_no_bias(h, h, vb.pp("v_proj"))?;
|
||||||
let q_proj = linear(h, h, false, vb.pp("q_proj"))?;
|
let q_proj = linear_no_bias(h, h, vb.pp("q_proj"))?;
|
||||||
let out_proj = linear(h, h, false, vb.pp("out_proj"))?;
|
let out_proj = linear_no_bias(h, h, vb.pp("out_proj"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
scaling: 1. / (head_dim as f64).sqrt(),
|
scaling: 1. / (head_dim as f64).sqrt(),
|
||||||
is_decoder: true,
|
is_decoder: true,
|
||||||
@ -209,7 +208,7 @@ struct MusicgenDecoderLayer {
|
|||||||
fc1: Linear,
|
fc1: Linear,
|
||||||
fc2: Linear,
|
fc2: Linear,
|
||||||
final_layer_norm: LayerNorm,
|
final_layer_norm: LayerNorm,
|
||||||
activation_fn: HiddenAct,
|
activation_fn: Activation,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MusicgenDecoderLayer {
|
impl MusicgenDecoderLayer {
|
||||||
@ -219,8 +218,8 @@ impl MusicgenDecoderLayer {
|
|||||||
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
let self_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||||
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
|
let encoder_attn = MusicgenAttention::load(vb.pp("encoder_attn"), cfg)?;
|
||||||
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
let encoder_attn_layer_norm = layer_norm(h, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||||
let fc1 = linear(h, cfg.ffn_dim, false, vb.pp("fc1"))?;
|
let fc1 = linear_no_bias(h, cfg.ffn_dim, vb.pp("fc1"))?;
|
||||||
let fc2 = linear(cfg.ffn_dim, h, false, vb.pp("fc2"))?;
|
let fc2 = linear_no_bias(cfg.ffn_dim, h, vb.pp("fc2"))?;
|
||||||
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
|
let final_layer_norm = layer_norm(h, 1e-5, vb.pp("final_layer_norm"))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attn,
|
self_attn,
|
||||||
@ -342,7 +341,7 @@ impl MusicgenForCausalLM {
|
|||||||
let h = cfg.hidden_size;
|
let h = cfg.hidden_size;
|
||||||
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
|
let decoder = MusicgenDecoder::load(vb.pp("model.decoder"), cfg)?;
|
||||||
let lm_heads = (0..cfg.num_codebooks)
|
let lm_heads = (0..cfg.num_codebooks)
|
||||||
.map(|i| linear(h, cfg.vocab_size, false, vb.pp(&format!("lm_heads.{i}"))))
|
.map(|i| linear_no_bias(h, cfg.vocab_size, vb.pp(&format!("lm_heads.{i}"))))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
decoder,
|
decoder,
|
||||||
@ -358,7 +357,7 @@ impl MusicgenForCausalLM {
|
|||||||
let lm_logits = self
|
let lm_logits = self
|
||||||
.lm_heads
|
.lm_heads
|
||||||
.iter()
|
.iter()
|
||||||
.map(|h| Ok(h.forward(&hidden_states)?))
|
.map(|h| h.forward(&hidden_states))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
|
let lm_logits = Tensor::stack(&lm_logits, 1)?.reshape((
|
||||||
b_sz * self.num_codebooks,
|
b_sz * self.num_codebooks,
|
||||||
|
@ -1,62 +1,5 @@
|
|||||||
use anyhow::Result;
|
use candle::Result;
|
||||||
use candle::Tensor;
|
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
|
||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 5000;
|
|
||||||
|
|
||||||
pub type VarBuilder<'a> = candle_nn::VarBuilder<'a>;
|
|
||||||
pub type Linear = candle_nn::Linear;
|
|
||||||
|
|
||||||
pub fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
|
||||||
let bias = if bias {
|
|
||||||
Some(vb.get(size2, "bias")?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(Linear::new(weight, bias))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type LayerNorm = candle_nn::LayerNorm;
|
|
||||||
|
|
||||||
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
|
||||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
|
||||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
|
||||||
(Err(err), _) | (_, Err(err)) => {
|
|
||||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
|
||||||
(weight, bias)
|
|
||||||
} else {
|
|
||||||
return Err(err.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(LayerNorm::new(weight, bias, eps))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Dropout {
|
|
||||||
pr: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Dropout {
|
|
||||||
pub fn new(pr: f64) -> Self {
|
|
||||||
Self { pr }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
||||||
// TODO
|
|
||||||
Ok(x.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type Embedding = candle_nn::Embedding;
|
|
||||||
|
|
||||||
pub fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
|
||||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type Conv1d = candle_nn::Conv1d;
|
|
||||||
pub type Conv1dConfig = candle_nn::Conv1dConfig;
|
|
||||||
|
|
||||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||||
// does not apply to training.
|
// does not apply to training.
|
||||||
@ -75,17 +18,3 @@ pub fn conv1d_weight_norm(
|
|||||||
let bias = vb.get(out_c, "bias")?;
|
let bias = vb.get(out_c, "bias")?;
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn conv1d(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
config: Conv1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
|
||||||
let weight = vb.get((out_c, in_c, kernel_size), "weight")?;
|
|
||||||
let bias = vb.get(out_c, "bias")?;
|
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub type HiddenAct = candle_nn::Activation;
|
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
// T5 Text Encoder
|
// T5 Text Encoder
|
||||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||||
|
|
||||||
use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
|
use candle::{DType, Result, Tensor, D};
|
||||||
use anyhow::Result;
|
use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
|
||||||
use candle::{DType, Tensor, D};
|
|
||||||
use candle_nn::Module;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
@ -21,7 +19,7 @@ pub struct Config {
|
|||||||
dropout_rate: f64,
|
dropout_rate: f64,
|
||||||
layer_norm_epsilon: f64,
|
layer_norm_epsilon: f64,
|
||||||
initializer_factor: f64,
|
initializer_factor: f64,
|
||||||
feed_forward_proj: HiddenAct,
|
feed_forward_proj: Activation,
|
||||||
is_decoder: bool,
|
is_decoder: bool,
|
||||||
is_encoder_decoder: bool,
|
is_encoder_decoder: bool,
|
||||||
use_cache: bool,
|
use_cache: bool,
|
||||||
@ -44,7 +42,7 @@ impl Default for Config {
|
|||||||
dropout_rate: 0.1,
|
dropout_rate: 0.1,
|
||||||
layer_norm_epsilon: 1e-6,
|
layer_norm_epsilon: 1e-6,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
feed_forward_proj: HiddenAct::Relu,
|
feed_forward_proj: Activation::Relu,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
is_encoder_decoder: true,
|
is_encoder_decoder: true,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
@ -63,7 +61,7 @@ impl Config {
|
|||||||
d_model: 768,
|
d_model: 768,
|
||||||
dropout_rate: 0.1,
|
dropout_rate: 0.1,
|
||||||
eos_token_id: 1,
|
eos_token_id: 1,
|
||||||
feed_forward_proj: HiddenAct::Relu,
|
feed_forward_proj: Activation::Relu,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
is_encoder_decoder: true,
|
is_encoder_decoder: true,
|
||||||
@ -112,27 +110,23 @@ impl T5LayerNorm {
|
|||||||
struct T5DenseActDense {
|
struct T5DenseActDense {
|
||||||
wi: Linear,
|
wi: Linear,
|
||||||
wo: Linear,
|
wo: Linear,
|
||||||
dropout: Dropout,
|
act: Activation,
|
||||||
act: HiddenAct,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl T5DenseActDense {
|
impl T5DenseActDense {
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
|
let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
|
||||||
let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
|
let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
wi,
|
wi,
|
||||||
wo,
|
wo,
|
||||||
dropout,
|
act: Activation::Relu,
|
||||||
act: HiddenAct::Relu,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let xs = self.wi.forward(xs)?;
|
let xs = self.wi.forward(xs)?;
|
||||||
let xs = self.act.forward(&xs)?;
|
let xs = self.act.forward(&xs)?;
|
||||||
let xs = self.dropout.forward(&xs)?;
|
|
||||||
let xs = self.wo.forward(&xs)?;
|
let xs = self.wo.forward(&xs)?;
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
@ -142,7 +136,6 @@ impl T5DenseActDense {
|
|||||||
struct T5LayerFF {
|
struct T5LayerFF {
|
||||||
dense_relu_dense: T5DenseActDense,
|
dense_relu_dense: T5DenseActDense,
|
||||||
layer_norm: T5LayerNorm,
|
layer_norm: T5LayerNorm,
|
||||||
dropout: Dropout,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl T5LayerFF {
|
impl T5LayerFF {
|
||||||
@ -151,18 +144,16 @@ impl T5LayerFF {
|
|||||||
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
|
let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
|
||||||
let layer_norm =
|
let layer_norm =
|
||||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
dense_relu_dense,
|
dense_relu_dense,
|
||||||
layer_norm,
|
layer_norm,
|
||||||
dropout,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let ys = self.layer_norm.forward(xs)?;
|
let ys = self.layer_norm.forward(xs)?;
|
||||||
let ys = self.dense_relu_dense.forward(&ys)?;
|
let ys = self.dense_relu_dense.forward(&ys)?;
|
||||||
let xs = (xs + self.dropout.forward(&ys)?)?;
|
let xs = (xs + ys)?;
|
||||||
Ok(xs)
|
Ok(xs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -181,10 +172,10 @@ struct T5Attention {
|
|||||||
impl T5Attention {
|
impl T5Attention {
|
||||||
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let inner_dim = cfg.num_heads * cfg.d_kv;
|
let inner_dim = cfg.num_heads * cfg.d_kv;
|
||||||
let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
|
let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
|
||||||
let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
|
let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
|
||||||
let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
|
let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
|
||||||
let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
|
let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
|
||||||
let relative_attention_bias = if h {
|
let relative_attention_bias = if h {
|
||||||
let emb = embedding(
|
let emb = embedding(
|
||||||
cfg.relative_attention_num_buckets,
|
cfg.relative_attention_num_buckets,
|
||||||
@ -235,7 +226,6 @@ impl T5Attention {
|
|||||||
struct T5LayerSelfAttention {
|
struct T5LayerSelfAttention {
|
||||||
self_attention: T5Attention,
|
self_attention: T5Attention,
|
||||||
layer_norm: T5LayerNorm,
|
layer_norm: T5LayerNorm,
|
||||||
dropout: Dropout,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl T5LayerSelfAttention {
|
impl T5LayerSelfAttention {
|
||||||
@ -243,11 +233,9 @@ impl T5LayerSelfAttention {
|
|||||||
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
|
||||||
let layer_norm =
|
let layer_norm =
|
||||||
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
self_attention,
|
self_attention,
|
||||||
layer_norm,
|
layer_norm,
|
||||||
dropout,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -315,7 +303,6 @@ struct T5Stack {
|
|||||||
block: Vec<T5Block>,
|
block: Vec<T5Block>,
|
||||||
shared: Arc<Embedding>,
|
shared: Arc<Embedding>,
|
||||||
final_layer_norm: T5LayerNorm,
|
final_layer_norm: T5LayerNorm,
|
||||||
dropout: Dropout,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl T5Stack {
|
impl T5Stack {
|
||||||
@ -328,12 +315,10 @@ impl T5Stack {
|
|||||||
cfg.layer_norm_epsilon,
|
cfg.layer_norm_epsilon,
|
||||||
vb.pp("final_layer_norm"),
|
vb.pp("final_layer_norm"),
|
||||||
)?;
|
)?;
|
||||||
let dropout = Dropout::new(cfg.dropout_rate);
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
block,
|
block,
|
||||||
shared: shared.clone(),
|
shared: shared.clone(),
|
||||||
final_layer_norm,
|
final_layer_norm,
|
||||||
dropout,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -341,12 +326,11 @@ impl T5Stack {
|
|||||||
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
let input_embeds = self.shared.as_ref().forward(input_ids)?;
|
||||||
let (_b_sz, _seq_len) = input_embeds.dims2()?;
|
let (_b_sz, _seq_len) = input_embeds.dims2()?;
|
||||||
|
|
||||||
let mut hidden_states = self.dropout.forward(&input_embeds)?;
|
let mut hidden_states = input_embeds;
|
||||||
for block in self.block.iter() {
|
for block in self.block.iter() {
|
||||||
hidden_states = block.forward(&hidden_states)?
|
hidden_states = block.forward(&hidden_states)?
|
||||||
}
|
}
|
||||||
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
|
let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
|
||||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
|
||||||
Ok(hidden_states)
|
Ok(hidden_states)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user