Flan T5: Read lm_head when word embeddings are not tied (#903)

* Read lm_head when word embeddings are not tied

* Fix formatting

* Address comments
This commit is contained in:
Juarez Bochi
2023-09-19 14:36:47 -07:00
committed by GitHub
parent 67a486d18d
commit 05626ef492

View File

@ -18,12 +18,15 @@ fn default_use_cache() -> bool {
true
}
fn default_tie_word_embeddings() -> bool {
true
}
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
let result = Tensor::from_slice(&mask, (size, size), device)?;
Ok(result)
Tensor::from_slice(&mask, (size, size), device)
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@ -50,6 +53,8 @@ pub struct Config {
initializer_factor: f64,
#[serde(default)]
feed_forward_proj: Activation,
#[serde(default = "default_tie_word_embeddings")]
tie_word_embeddings: bool,
#[serde(default = "default_is_decoder")]
is_decoder: bool,
is_encoder_decoder: bool,
@ -75,6 +80,7 @@ impl Default for Config {
layer_norm_epsilon: 1e-6,
initializer_factor: 1.0,
feed_forward_proj: Activation::Relu,
tie_word_embeddings: true,
is_decoder: false,
is_encoder_decoder: true,
use_cache: true,
@ -94,6 +100,7 @@ impl Config {
dropout_rate: 0.1,
eos_token_id: 1,
feed_forward_proj: Activation::Relu,
tie_word_embeddings: true,
initializer_factor: 1.0,
is_decoder: false,
is_encoder_decoder: true,
@ -611,6 +618,9 @@ impl T5EncoderModel {
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,
d_model: usize,
tie_word_embeddings: bool,
lm_head: Option<Linear>,
shared: Arc<Embedding>,
device: Device,
}
@ -618,6 +628,7 @@ pub struct T5ForConditionalGeneration {
impl T5ForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
assert!(cfg.is_encoder_decoder);
let d_model = cfg.d_model;
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let shared = Arc::new(shared);
@ -633,9 +644,23 @@ impl T5ForConditionalGeneration {
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
let tie_word_embeddings = cfg.tie_word_embeddings;
let lm_head = if tie_word_embeddings {
None
} else {
Some(linear_no_bias(
cfg.d_model,
cfg.vocab_size,
vb.pp("lm_head"),
)?)
};
Ok(Self {
encoder,
decoder,
d_model,
tie_word_embeddings,
lm_head,
shared,
device: vb.device().clone(),
})
@ -653,12 +678,23 @@ impl T5ForConditionalGeneration {
let decoder_output = self
.decoder
.forward(decoder_input_ids, Some(encoder_output))?;
let sequence_output = decoder_output
let scaling_factor = if self.tie_word_embeddings {
// Rescale output before projecting on vocab
// See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
(self.d_model as f64).sqrt()
} else {
1.0
};
let sequence_output = ((decoder_output
.narrow(1, decoder_output.dim(1)? - 1, 1)?
.squeeze(1)?;
// TODO: check cfg.tie_word_embeddings to load from model instead.
let lm_head_weights = self.shared.embeddings().t()?;
let output = sequence_output.matmul(&lm_head_weights)?;
.squeeze(1)?)
* scaling_factor)?;
let output = match self.lm_head {
None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
Some(ref lm_head) => lm_head.forward(&sequence_output)?,
};
// TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
Ok(output)
}