mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Flan T5: Read lm_head when word embeddings are not tied (#903)
* Read lm_head when word embeddings are not tied
* Fix formatting
* Address comments
This commit is contained in:
@ -18,12 +18,15 @@ fn default_use_cache() -> bool {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_tie_word_embeddings() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||||
let mask: Vec<_> = (0..size)
|
let mask: Vec<_> = (0..size)
|
||||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||||
.collect();
|
.collect();
|
||||||
let result = Tensor::from_slice(&mask, (size, size), device)?;
|
Tensor::from_slice(&mask, (size, size), device)
|
||||||
Ok(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
@ -50,6 +53,8 @@ pub struct Config {
|
|||||||
initializer_factor: f64,
|
initializer_factor: f64,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
feed_forward_proj: Activation,
|
feed_forward_proj: Activation,
|
||||||
|
#[serde(default = "default_tie_word_embeddings")]
|
||||||
|
tie_word_embeddings: bool,
|
||||||
#[serde(default = "default_is_decoder")]
|
#[serde(default = "default_is_decoder")]
|
||||||
is_decoder: bool,
|
is_decoder: bool,
|
||||||
is_encoder_decoder: bool,
|
is_encoder_decoder: bool,
|
||||||
@ -75,6 +80,7 @@ impl Default for Config {
|
|||||||
layer_norm_epsilon: 1e-6,
|
layer_norm_epsilon: 1e-6,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
feed_forward_proj: Activation::Relu,
|
feed_forward_proj: Activation::Relu,
|
||||||
|
tie_word_embeddings: true,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
is_encoder_decoder: true,
|
is_encoder_decoder: true,
|
||||||
use_cache: true,
|
use_cache: true,
|
||||||
@ -94,6 +100,7 @@ impl Config {
|
|||||||
dropout_rate: 0.1,
|
dropout_rate: 0.1,
|
||||||
eos_token_id: 1,
|
eos_token_id: 1,
|
||||||
feed_forward_proj: Activation::Relu,
|
feed_forward_proj: Activation::Relu,
|
||||||
|
tie_word_embeddings: true,
|
||||||
initializer_factor: 1.0,
|
initializer_factor: 1.0,
|
||||||
is_decoder: false,
|
is_decoder: false,
|
||||||
is_encoder_decoder: true,
|
is_encoder_decoder: true,
|
||||||
@ -611,6 +618,9 @@ impl T5EncoderModel {
|
|||||||
pub struct T5ForConditionalGeneration {
|
pub struct T5ForConditionalGeneration {
|
||||||
encoder: T5Stack,
|
encoder: T5Stack,
|
||||||
decoder: T5Stack,
|
decoder: T5Stack,
|
||||||
|
d_model: usize,
|
||||||
|
tie_word_embeddings: bool,
|
||||||
|
lm_head: Option<Linear>,
|
||||||
shared: Arc<Embedding>,
|
shared: Arc<Embedding>,
|
||||||
device: Device,
|
device: Device,
|
||||||
}
|
}
|
||||||
@ -618,6 +628,7 @@ pub struct T5ForConditionalGeneration {
|
|||||||
impl T5ForConditionalGeneration {
|
impl T5ForConditionalGeneration {
|
||||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
assert!(cfg.is_encoder_decoder);
|
assert!(cfg.is_encoder_decoder);
|
||||||
|
let d_model = cfg.d_model;
|
||||||
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
let shared = embedding(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
|
||||||
let shared = Arc::new(shared);
|
let shared = Arc::new(shared);
|
||||||
|
|
||||||
@ -633,9 +644,23 @@ impl T5ForConditionalGeneration {
|
|||||||
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
|
decoder_cfg.num_layers = cfg.num_decoder_layers.unwrap_or(cfg.num_layers);
|
||||||
let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
|
let decoder = T5Stack::load(true, vb.pp("decoder"), &shared, &decoder_cfg)?;
|
||||||
|
|
||||||
|
let tie_word_embeddings = cfg.tie_word_embeddings;
|
||||||
|
let lm_head = if tie_word_embeddings {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(linear_no_bias(
|
||||||
|
cfg.d_model,
|
||||||
|
cfg.vocab_size,
|
||||||
|
vb.pp("lm_head"),
|
||||||
|
)?)
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
encoder,
|
encoder,
|
||||||
decoder,
|
decoder,
|
||||||
|
d_model,
|
||||||
|
tie_word_embeddings,
|
||||||
|
lm_head,
|
||||||
shared,
|
shared,
|
||||||
device: vb.device().clone(),
|
device: vb.device().clone(),
|
||||||
})
|
})
|
||||||
@ -653,12 +678,23 @@ impl T5ForConditionalGeneration {
|
|||||||
let decoder_output = self
|
let decoder_output = self
|
||||||
.decoder
|
.decoder
|
||||||
.forward(decoder_input_ids, Some(encoder_output))?;
|
.forward(decoder_input_ids, Some(encoder_output))?;
|
||||||
let sequence_output = decoder_output
|
|
||||||
|
let scaling_factor = if self.tie_word_embeddings {
|
||||||
|
// Rescale output before projecting on vocab
|
||||||
|
// See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
||||||
|
(self.d_model as f64).sqrt()
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
};
|
||||||
|
let sequence_output = ((decoder_output
|
||||||
.narrow(1, decoder_output.dim(1)? - 1, 1)?
|
.narrow(1, decoder_output.dim(1)? - 1, 1)?
|
||||||
.squeeze(1)?;
|
.squeeze(1)?)
|
||||||
// TODO: check cfg.tie_word_embeddings to load from model instead.
|
* scaling_factor)?;
|
||||||
let lm_head_weights = self.shared.embeddings().t()?;
|
let output = match self.lm_head {
|
||||||
let output = sequence_output.matmul(&lm_head_weights)?;
|
None => sequence_output.matmul(&self.shared.embeddings().t()?)?,
|
||||||
|
Some(ref lm_head) => lm_head.forward(&sequence_output)?,
|
||||||
|
};
|
||||||
|
|
||||||
// TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
|
// TODO: Rescale output before projecting on vocab? * (self.model_dim**-0.5)
|
||||||
Ok(output)
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user