Fix the musicgen example. (#724)
* Fix the musicgen example.
* Retrieve the weights from the hub.
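The second bullet refers to fetching the T5 weights at run time instead of expecting a local checkpoint. A minimal sketch of what that retrieval can look like with the hf-hub crate; the repo id and file name below are illustrative assumptions, not taken from this diff:

    use anyhow::Result;
    use hf_hub::api::sync::Api;

    fn fetch_weights() -> Result<std::path::PathBuf> {
        // Resolve the model repo on the Hugging Face Hub and download the
        // weight file (or reuse the locally cached copy).
        let api = Api::new()?;
        let repo = api.model("t5-small".to_string());
        Ok(repo.get("model.safetensors")?)
    }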
@@ -1,10 +1,8 @@
 // T5 Text Encoder
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

-use crate::nn::{embedding, linear, Dropout, Embedding, HiddenAct, Linear, VarBuilder};
-use anyhow::Result;
-use candle::{DType, Tensor, D};
-use candle_nn::Module;
+use candle::{DType, Result, Tensor, D};
+use candle_nn::{embedding, linear_no_bias, Activation, Embedding, Linear, Module, VarBuilder};
 use std::sync::Arc;

 #[derive(Debug, Clone, PartialEq)]
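The import swap replaces the example-local `crate::nn` shims and `anyhow::Result` with the upstream `candle_nn` layers and `candle::Result`. A quick sketch of the two replacement constructors, using a zero-initialized `VarBuilder` and arbitrary dimensions:

    use candle::{DType, Device, Result};
    use candle_nn::{embedding, linear_no_bias, VarBuilder};

    fn build_layers() -> Result<()> {
        // Zero-backed VarBuilder for illustration; the example loads real weights.
        let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
        let _emb = embedding(32, 64, vb.pp("emb"))?; // 32-entry vocab, 64-dim embeddings
        let _proj = linear_no_bias(64, 128, vb.pp("proj"))?; // Linear without a bias tensor
        Ok(())
    }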
@@ -21,7 +19,7 @@ pub struct Config {
     dropout_rate: f64,
     layer_norm_epsilon: f64,
     initializer_factor: f64,
-    feed_forward_proj: HiddenAct,
+    feed_forward_proj: Activation,
     is_decoder: bool,
     is_encoder_decoder: bool,
     use_cache: bool,
@@ -44,7 +42,7 @@ impl Default for Config {
             dropout_rate: 0.1,
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
-            feed_forward_proj: HiddenAct::Relu,
+            feed_forward_proj: Activation::Relu,
             is_decoder: false,
             is_encoder_decoder: true,
             use_cache: true,
@@ -63,7 +61,7 @@ impl Config {
             d_model: 768,
             dropout_rate: 0.1,
             eos_token_id: 1,
-            feed_forward_proj: HiddenAct::Relu,
+            feed_forward_proj: Activation::Relu,
             initializer_factor: 1.0,
             is_decoder: false,
             is_encoder_decoder: true,
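The three Config hunks are the same one-token change: `HiddenAct` from the local shim becomes `candle_nn::Activation`, which implements `Module` and is applied like any other layer. A small demonstration, assuming nothing beyond `candle_nn`:

    use candle::{Device, Result, Tensor};
    use candle_nn::{Activation, Module};

    fn relu_demo() -> Result<Tensor> {
        let xs = Tensor::new(&[-1f32, 0., 2.], &Device::Cpu)?;
        // Activation implements Module, so forward() applies it element-wise.
        Activation::Relu.forward(&xs) // -> [0., 0., 2.]
    }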
@@ -112,27 +110,23 @@ impl T5LayerNorm {
 struct T5DenseActDense {
     wi: Linear,
     wo: Linear,
-    dropout: Dropout,
-    act: HiddenAct,
+    act: Activation,
 }

 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = linear(cfg.d_model, cfg.d_ff, false, vb.pp("wi"))?;
-        let wo = linear(cfg.d_ff, cfg.d_model, false, vb.pp("wo"))?;
-        let dropout = Dropout::new(cfg.dropout_rate);
+        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi,
             wo,
-            dropout,
-            act: HiddenAct::Relu,
+            act: Activation::Relu,
         })
     }

     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let xs = self.wi.forward(xs)?;
         let xs = self.act.forward(&xs)?;
-        let xs = self.dropout.forward(&xs)?;
         let xs = self.wo.forward(&xs)?;
         Ok(xs)
     }
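Two independent simplifications land in T5DenseActDense: the four-argument `linear(.., false, ..)` shim becomes `candle_nn::linear_no_bias` (these T5 projections carry no bias tensors), and the dropout layer disappears because the example only runs inference, where dropout is the identity. A condensed sketch of the resulting forward path, with illustrative dimensions standing in for the Config values:

    use candle::{Result, Tensor};
    use candle_nn::{linear_no_bias, Activation, Module, VarBuilder};

    fn dense_act_dense(xs: &Tensor, vb: VarBuilder) -> Result<Tensor> {
        // d_model = 8, d_ff = 16 are placeholders for cfg.d_model / cfg.d_ff.
        let wi = linear_no_bias(8, 16, vb.pp("wi"))?;
        let wo = linear_no_bias(16, 8, vb.pp("wo"))?;
        // wi -> ReLU -> wo; the removed dropout would be a no-op here anyway.
        wo.forward(&Activation::Relu.forward(&wi.forward(xs)?)?)
    }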
@@ -142,7 +136,6 @@ impl T5DenseActDense {
 struct T5LayerFF {
     dense_relu_dense: T5DenseActDense,
     layer_norm: T5LayerNorm,
-    dropout: Dropout,
 }

 impl T5LayerFF {
@@ -151,18 +144,16 @@ impl T5LayerFF {
         let dense_relu_dense = T5DenseActDense::load(vb.pp("DenseReluDense"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
-        let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             dense_relu_dense,
             layer_norm,
-            dropout,
         })
     }

     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let ys = self.layer_norm.forward(xs)?;
         let ys = self.dense_relu_dense.forward(&ys)?;
-        let xs = (xs + self.dropout.forward(&ys)?)?;
+        let xs = (xs + ys)?;
         Ok(xs)
     }
 }
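The residual connection in T5LayerFF drops its dropout wrapper for the same reason: at inference, xs + dropout(f(xs)) equals xs + f(xs). Candle implements the arithmetic operators on tensor references and returns Result, so the new line reads almost like the math. A self-contained check of that pattern:

    use candle::{DType, Device, Result, Tensor};

    fn residual_demo() -> Result<Tensor> {
        let xs = Tensor::ones((2, 4), DType::F32, &Device::Cpu)?;
        let ys = Tensor::zeros((2, 4), DType::F32, &Device::Cpu)?;
        // Element-wise residual add; `&xs + &ys` yields a Result<Tensor>.
        &xs + &ys
    }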
@@ -181,10 +172,10 @@ struct T5Attention {
 impl T5Attention {
     fn load(h: bool, vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = linear(cfg.d_model, inner_dim, false, vb.pp("q"))?;
-        let k = linear(cfg.d_model, inner_dim, false, vb.pp("k"))?;
-        let v = linear(cfg.d_model, inner_dim, false, vb.pp("v"))?;
-        let o = linear(inner_dim, cfg.d_model, false, vb.pp("o"))?;
+        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
+        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
+        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
+        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if h {
             let emb = embedding(
                 cfg.relative_attention_num_buckets,
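All four attention projections move to `linear_no_bias` in one stroke; `inner_dim` is the concatenated per-head width. For a t5-base-style setup this works out to num_heads * d_kv = 12 * 64 = 768 = d_model, although T5 does not require the product to match d_model. The head-count values here are assumptions for illustration, not read from this diff:

    // Shape bookkeeping for the q/k/v/o projections (illustrative values).
    fn projection_dims() -> (usize, usize) {
        let (num_heads, d_kv, d_model) = (12usize, 64usize, 768usize);
        let inner_dim = num_heads * d_kv; // 768: q/k/v map d_model -> inner_dim
        assert_eq!(inner_dim, d_model); // holds for t5-base, not guaranteed in general
        (d_model, inner_dim) // o maps inner_dim back to d_model
    }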
@@ -235,7 +226,6 @@ impl T5Attention {
 struct T5LayerSelfAttention {
     self_attention: T5Attention,
     layer_norm: T5LayerNorm,
-    dropout: Dropout,
 }

 impl T5LayerSelfAttention {
@@ -243,11 +233,9 @@ impl T5LayerSelfAttention {
         let self_attention = T5Attention::load(h, vb.pp("SelfAttention"), cfg)?;
         let layer_norm =
             T5LayerNorm::load(cfg.d_model, cfg.layer_norm_epsilon, vb.pp("layer_norm"))?;
-        let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             self_attention,
             layer_norm,
-            dropout,
         })
     }

@@ -315,7 +303,6 @@ struct T5Stack {
     block: Vec<T5Block>,
     shared: Arc<Embedding>,
     final_layer_norm: T5LayerNorm,
-    dropout: Dropout,
 }

 impl T5Stack {
@@ -328,12 +315,10 @@ impl T5Stack {
             cfg.layer_norm_epsilon,
             vb.pp("final_layer_norm"),
         )?;
-        let dropout = Dropout::new(cfg.dropout_rate);
         Ok(Self {
             block,
             shared: shared.clone(),
             final_layer_norm,
-            dropout,
         })
     }

@@ -341,12 +326,11 @@ impl T5Stack {
         let input_embeds = self.shared.as_ref().forward(input_ids)?;
         let (_b_sz, _seq_len) = input_embeds.dims2()?;

-        let mut hidden_states = self.dropout.forward(&input_embeds)?;
+        let mut hidden_states = input_embeds;
         for block in self.block.iter() {
             hidden_states = block.forward(&hidden_states)?
         }
         let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
-        let hidden_states = self.dropout.forward(&hidden_states)?;
         Ok(hidden_states)
     }
 }
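With the embedding-side and output-side dropout gone, T5Stack::forward is a pure fold: embed the tokens, run each block, apply the final layer norm. A generic restatement of that control flow (the type parameters stand in for the concrete T5 layers; this is a sketch, not the file's implementation):

    use candle::{Result, Tensor};
    use candle_nn::Module;

    fn encode<E: Module, B: Module, N: Module>(
        embed: &E,
        blocks: &[B],
        norm: &N,
        input_ids: &Tensor,
    ) -> Result<Tensor> {
        let mut hidden_states = embed.forward(input_ids)?; // token embeddings
        for block in blocks {
            hidden_states = block.forward(&hidden_states)?; // attention + feed-forward
        }
        norm.forward(&hidden_states) // no trailing dropout at inference
    }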