mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add some text generation pipeline for falcon. (#98)
This commit is contained in:
@ -544,8 +544,9 @@ impl FalconDecoderLayer {
|
||||
#[derive(Debug)]
|
||||
pub struct Falcon {
|
||||
word_embeddings: Embedding,
|
||||
h: Vec<FalconDecoderLayer>,
|
||||
blocks: Vec<FalconDecoderLayer>,
|
||||
ln_f: LayerNorm,
|
||||
lm_head: Linear,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
@ -572,7 +573,7 @@ impl Falcon {
|
||||
"transformer.word_embeddings",
|
||||
vb,
|
||||
)?;
|
||||
let h = (0..cfg.num_hidden_layers)
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| FalconDecoderLayer::load(&format!("transformer.h.{i}"), vb, &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = LayerNorm::load(
|
||||
@ -581,10 +582,12 @@ impl Falcon {
|
||||
"transformer.ln_f",
|
||||
vb,
|
||||
)?;
|
||||
let lm_head = Linear::load(cfg.hidden_size, cfg.vocab_size, false, "lm_head", vb)?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
h,
|
||||
blocks,
|
||||
ln_f,
|
||||
lm_head,
|
||||
config: cfg,
|
||||
})
|
||||
}
|
||||
@ -593,10 +596,12 @@ impl Falcon {
|
||||
let (b_sz, seq_len) = input_ids.shape().r2()?;
|
||||
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
||||
let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(&input_ids.device())?;
|
||||
for block in self.h.iter_mut() {
|
||||
for block in self.blocks.iter_mut() {
|
||||
hidden_state = block.forward(&hidden_state, &causal_mask)?;
|
||||
}
|
||||
let hidden_state = self.ln_f.forward(&hidden_state)?;
|
||||
Ok(hidden_state)
|
||||
let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
|
||||
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user