Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Flan T5: Read lm_head when word embeddings are not tied (#903)
* Read lm_head when word embeddings are not tied
* Fix formatting
* Address comments
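For context: classic T5 checkpoints reuse the shared input embedding matrix as the output projection, so they ship no lm_head tensor, while the Flan-T5 checkpoints store a separate lm_head and flag this in their config. The relevant config.json entry (abridged; the surrounding layout is an assumption) looks like:

    {
      "tie_word_embeddings": false
    }

When the flag is absent, the field defaults to true below, which preserves the old tied behavior for existing T5 models.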
@@ -18,12 +18,15 @@ fn default_use_cache() -> bool {
     true
 }
 
+fn default_tie_word_embeddings() -> bool {
+    true
+}
+
 fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
         .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
         .collect();
-    let result = Tensor::from_slice(&mask, (size, size), device)?;
-    Ok(result)
+    Tensor::from_slice(&mask, (size, size), device)
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
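A minimal standalone sketch of the causal mask that get_mask above builds; only the imports and the main wrapper are added here. A 1 marks a future position (j > i) that masked_fill later replaces with negative infinity before the attention softmax:

    use candle_core::{Device, Result, Tensor};

    fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
        let mask: Vec<_> = (0..size)
            .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
            .collect();
        Tensor::from_slice(&mask, (size, size), device)
    }

    fn main() -> Result<()> {
        let mask = get_mask(3, &Device::Cpu)?;
        // Prints [[0, 1, 1], [0, 0, 1], [0, 0, 0]]:
        // row i may attend to columns j <= i only.
        println!("{:?}", mask.to_vec2::<u8>()?);
        Ok(())
    }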
@@ -50,6 +53,8 @@ pub struct Config {
     initializer_factor: f64,
     #[serde(default)]
     feed_forward_proj: Activation,
+    #[serde(default = "default_tie_word_embeddings")]
+    tie_word_embeddings: bool,
     #[serde(default = "default_is_decoder")]
     is_decoder: bool,
     is_encoder_decoder: bool,
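The serde attribute above is what makes the new flag backward compatible. A self-contained sketch of the mechanism, using a trimmed-down config struct (the real Config has many more fields):

    use serde::Deserialize;

    fn default_tie_word_embeddings() -> bool {
        true
    }

    #[derive(Debug, Deserialize)]
    struct MiniConfig {
        #[serde(default = "default_tie_word_embeddings")]
        tie_word_embeddings: bool,
    }

    fn main() {
        // Older T5 configs omit the field entirely: it defaults to true.
        let legacy: MiniConfig = serde_json::from_str("{}").unwrap();
        assert!(legacy.tie_word_embeddings);

        // Flan-T5 style configs set it to false: a separate lm_head is expected.
        let flan: MiniConfig =
            serde_json::from_str(r#"{"tie_word_embeddings": false}"#).unwrap();
        assert!(!flan.tie_word_embeddings);
    }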
@@ -75,6 +80,7 @@ impl Default for Config {
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
+            tie_word_embeddings: true,
             is_decoder: false,
             is_encoder_decoder: true,
             use_cache: true,
@@ -94,6 +100,7 @@ impl Config {
             dropout_rate: 0.1,
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
+            tie_word_embeddings: true,
             initializer_factor: 1.0,
             is_decoder: false,
             is_encoder_decoder: true,
@@ -611,6 +618,9 @@ impl T5EncoderModel {
 pub struct T5ForConditionalGeneration {
     encoder: T5Stack,
     decoder: T5Stack,
+    d_model: usize,
+    tie_word_embeddings: bool,
+    lm_head: Option<Linear>,
     shared: Arc<Embedding>,
     device: Device,
 }
@@ -618,6 +628,7 @@ pub struct T5ForConditionalGeneration {
 impl T5ForConditionalGeneration {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         assert!(cfg.is_encoder_decoder);
+        let d_model = cfg.d_model;
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
 
@@ -633,9 +644,23 @@ impl T5ForConditionalGeneration {
         decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
         let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
 
+        let tie_word_embeddings = cfg.tie_word_embeddings;
+        let lm_head = if tie_word_embeddings {
+            None
+        } else {
+            Some(linear_no_bias(
+                cfg.d_model,
+                cfg.vocab_size,
+                vb.pp("lm_head"),
+            )?)
+        };
+
         Ok(Self {
             encoder,
             decoder,
+            d_model,
+            tie_word_embeddings,
+            lm_head,
             shared,
             device: vb.device().clone(),
         })
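A hypothetical loading sketch (not part of the commit; the file names, error handling, and the candle_transformers::models::t5 import path are assumptions) showing how the branch above plays out: with tie_word_embeddings = false the VarBuilder must find an "lm_head" tensor in the checkpoint, while tied checkpoints need no such tensor:

    use candle_core::{DType, Device};
    use candle_nn::VarBuilder;
    use candle_transformers::models::t5::{Config, T5ForConditionalGeneration};

    fn main() -> anyhow::Result<()> {
        let device = Device::Cpu;
        // Flan-T5 configs carry tie_word_embeddings = false.
        let config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
        // Safety: the safetensors file must not be mutated while mmapped.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
        };
        let _model = T5ForConditionalGeneration::load(vb, &config)?;
        Ok(())
    }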
@@ -653,12 +678,23 @@ impl T5ForConditionalGeneration {
         let decoder_output = self
             .decoder
             .forward(decoder_input_ids, Some(encoder_output))?;
-        let sequence_output = decoder_output
-            .narrow(1, decoder_output.dim(1)? - 1, 1)?
-            .squeeze(1)?;
-        // TODO: check cfg.tie_word_embeddings to load from model instead.
-        let lm_head_weights = self.shared.embeddings().t()?;
-        let output = sequence_output.matmul(&lm_head_weights)?;
-        // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
+        let scaling_factor = if self.tie_word_embeddings {
+            // Rescale output before projecting on vocab
+            // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            (self.d_model as f64).sqrt()
+        } else {
+            1.0
+        };
+
+        let sequence_output = ((decoder_output
+            .narrow(1, decoder_output.dim(1)? - 1, 1)?
+            .squeeze(1)?)
+            * scaling_factor)?;
+
+        let output = match self.lm_head {
+            None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+            Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+        };
+
         Ok(output)
     }
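A minimal sketch (not from the commit) of why the match above is the right dispatch: with tied weights, multiplying by the transposed embedding table is exactly the linear map that an untied lm_head carrying the same weight matrix would apply. The scaling_factor branch mirrors the upstream implementations, which rescale the hidden state before the vocabulary projection only when weights are tied.

    use candle_core::{Device, Result, Tensor};
    use candle_nn::{Linear, Module};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (vocab, d_model) = (8, 4);
        // Stand-ins for the shared embedding table and a last-position hidden state.
        let emb = Tensor::randn(0f32, 1.0, (vocab, d_model), &dev)?;
        let h = Tensor::randn(0f32, 1.0, (1, d_model), &dev)?;

        let tied = h.matmul(&emb.t()?)?;              // None branch: h @ E^T
        let lm_head = Linear::new(emb.clone(), None); // Some branch with weight = E
        let untied = lm_head.forward(&h)?;

        let diff = (&tied - &untied)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-5); // identical logits either way
        Ok(())
    }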