From 05626ef492300a4f99c87555304ec863071722d5 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Tue, 19 Sep 2023 14:36:47 -0700
Subject: [PATCH] Flan T5: Read lm_head when word embeddings are not tied (#903)

* Read lm_head when word embeddings are not tied

* Fix formatting

* Address comments
---
 candle-transformers/src/models/t5.rs | 50 ++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 7 deletions(-)

diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs
index b1f3a3aa..fd2720d3 100644
--- a/candle-transformers/src/models/t5.rs
+++ b/candle-transformers/src/models/t5.rs
@@ -18,12 +18,15 @@ fn default_use_cache() -> bool {
     true
 }
 
+fn default_tie_word_embeddings() -> bool {
+    true
+}
+
 fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
         .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
         .collect();
-    let result = Tensor::from_slice(&mask, (size, size), device)?;
-    Ok(result)
+    Tensor::from_slice(&mask, (size, size), device)
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -50,6 +53,8 @@ pub struct Config {
     initializer_factor: f64,
     #[serde(default)]
     feed_forward_proj: Activation,
+    #[serde(default = "default_tie_word_embeddings")]
+    tie_word_embeddings: bool,
     #[serde(default = "default_is_decoder")]
     is_decoder: bool,
     is_encoder_decoder: bool,
@@ -75,6 +80,7 @@ impl Default for Config {
             layer_norm_epsilon: 1e-6,
             initializer_factor: 1.0,
             feed_forward_proj: Activation::Relu,
+            tie_word_embeddings: true,
             is_decoder: false,
             is_encoder_decoder: true,
             use_cache: true,
@@ -94,6 +100,7 @@ impl Config {
             dropout_rate: 0.1,
             eos_token_id: 1,
             feed_forward_proj: Activation::Relu,
+            tie_word_embeddings: true,
             initializer_factor: 1.0,
             is_decoder: false,
             is_encoder_decoder: true,
@@ -611,6 +618,9 @@ impl T5EncoderModel {
 pub struct T5ForConditionalGeneration {
     encoder: T5Stack,
     decoder: T5Stack,
+    d_model: usize,
+    tie_word_embeddings: bool,
+    lm_head: Option<Linear>,
     shared: Arc<Embedding>,
     device: Device,
 }
@@ -618,6 +628,7 @@ pub struct T5ForConditionalGeneration {
 impl T5ForConditionalGeneration {
     pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         assert!(cfg.is_encoder_decoder);
+        let d_model = cfg.d_model;
         let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
         let shared = Arc::new(shared);
 
@@ -633,9 +644,23 @@ impl T5ForConditionalGeneration {
         decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
         let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
 
+        let tie_word_embeddings = cfg.tie_word_embeddings;
+        let lm_head = if tie_word_embeddings {
+            None
+        } else {
+            Some(linear_no_bias(
+                cfg.d_model,
+                cfg.vocab_size,
+                vb.pp("lm_head"),
+            )?)
+        };
+
         Ok(Self {
             encoder,
             decoder,
+            d_model,
+            tie_word_embeddings,
+            lm_head,
             shared,
             device: vb.device().clone(),
         })
@@ -653,12 +678,23 @@ impl T5ForConditionalGeneration {
         let decoder_output = self
             .decoder
             .forward(decoder_input_ids, Some(encoder_output))?;
-        let sequence_output = decoder_output
+
+        let scaling_factor = if self.tie_word_embeddings {
+            // Rescale output before projecting on vocab
+            // See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            (self.d_model as f64).sqrt()
+        } else {
+            1.0
+        };
+        let sequence_output = ((decoder_output
             .narrow(1, decoder_output.dim(1)? - 1, 1)?
-            .squeeze(1)?;
-        // TODO: check cfg.tie_word_embeddings to load from model instead.
-        let lm_head_weights = self.shared.embeddings().t()?;
-        let output = sequence_output.matmul(&lm_head_weights)?;
+            .squeeze(1)?)
+            * scaling_factor)?;
+        let output = match self.lm_head {
+            None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
+            Some(ref lm_head) => lm_head.forward(&sequence_output)?,
+        };
+        // TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
         Ok(output)
     }
 
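
Below is a minimal, self-contained sketch of the projection logic this patch introduces,
separate from the patch itself. With tied word embeddings (the T5 default), the last
decoder hidden state is rescaled by sqrt(d_model) and projected through the transposed
shared embedding matrix; with untied embeddings (Flan-T5 checkpoints set
"tie_word_embeddings": false in config.json), the dedicated lm_head weights are applied
without rescaling. The helper name project_to_vocab and the toy shapes are illustrative
assumptions, not part of the candle API; candle-core is imported under its in-workspace
alias `candle`, as in t5.rs itself.

use candle::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

// Project the last decoder hidden state to vocabulary logits.
// NOTE: illustrative helper, not part of the patch or the candle API.
fn project_to_vocab(
    hidden: &Tensor,          // (batch, d_model), last decoded position
    shared: &Tensor,          // (vocab_size, d_model) shared embedding matrix
    lm_head: Option<&Linear>, // present only when embeddings are untied
    d_model: usize,
) -> Result<Tensor> {
    match lm_head {
        // Untied (e.g. Flan-T5): dedicated projection, no rescaling.
        Some(head) => head.forward(hidden),
        // Tied (original T5): rescale by sqrt(d_model), then reuse the
        // embedding matrix transposed, as the patched forward() does.
        None => (hidden * (d_model as f64).sqrt())?.matmul(&shared.t()?),
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (vocab_size, d_model) = (32, 8); // toy sizes for the demo
    let shared = Tensor::randn(0f32, 1f32, (vocab_size, d_model), &dev)?;
    let hidden = Tensor::randn(0f32, 1f32, (1, d_model), &dev)?;
    let logits = project_to_vocab(&hidden, &shared, None, d_model)?;
    assert_eq!(logits.dims2()?, (1, vocab_size));
    Ok(())
}

One design note: keeping lm_head as Option<Linear> rather than always materializing a
projection mirrors the checkpoint layout. Tied models simply do not store an lm_head
tensor, so loading it unconditionally would fail for them.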