Add some text generation pipeline for falcon. (#98)

This commit is contained in:
Laurent Mazare
2023-07-07 06:34:22 +01:00
committed by GitHub
parent 2b8e8c9f14
commit bac4ef40f3
2 changed files with 93 additions and 16 deletions

View File

@ -544,8 +544,9 @@ impl FalconDecoderLayer {
#[derive(Debug)]
pub struct Falcon {
word_embeddings: Embedding,
h: Vec<FalconDecoderLayer>,
blocks: Vec<FalconDecoderLayer>,
ln_f: LayerNorm,
lm_head: Linear,
config: Config,
}
@ -572,7 +573,7 @@ impl Falcon {
"transformer.word_embeddings",
vb,
)?;
let h = (0..cfg.num_hidden_layers)
let blocks = (0..cfg.num_hidden_layers)
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
.collect::<Result<Vec<_>>>()?;
let ln_f = LayerNorm::load(
@ -581,10 +582,12 @@ impl Falcon {
"transformer.ln_f",
vb,
)?;
let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
Ok(Self {
word_embeddings,
h,
blocks,
ln_f,
lm_head,
config: cfg,
})
}
@ -593,10 +596,12 @@ impl Falcon {
let (b_sz, seq_len) = input_ids.shape().r2()?;
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
for block in self.h.iter_mut() {
for block in self.blocks.iter_mut() {
hidden_state = block.forward(&hidden_state, &causal_mask)?;
}
let hidden_state = self.ln_f.forward(&hidden_state)?;
Ok(hidden_state)
let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
Ok(logits)
}
}