Add an eval mode to llama2-c (#288)

* Add an eval mode to llama2-c.

* Encode line by line.

* Get the eval to run.
Author: Laurent Mazare
Date: 2023-07-31 17:22:14 +01:00
Committed by: GitHub
Parent: 1064b9b031
Commit: 9ae1f6afee
2 changed files with 87 additions and 35 deletions
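
The eval mode scores a text file line by line instead of sampling from the model. A minimal sketch of what such a loop can look like, assuming the in-repo `candle` alias for candle-core plus the `tokenizers` and `anyhow` crates; the `eval` helper and its exact scoring are illustrative, not the commit's actual code:

use anyhow::Result;
use candle::{Device, Tensor};

fn eval(model: &Llama, tokenizer: &tokenizers::Tokenizer, text: &str) -> Result<()> {
    let device = Device::Cpu;
    for line in text.lines() {
        // Encode one line at a time, per the commit message.
        let encoded = tokenizer.encode(line, true).map_err(anyhow::Error::msg)?;
        let ids = encoded.get_ids();
        if ids.len() < 2 {
            continue;
        }
        // `forward` in the diff below only returns logits for the last
        // position, so score the final token of the line given the rest as
        // context. Assumes the KV cache is disabled so that each line is
        // evaluated independently.
        let context = Tensor::new(&ids[..ids.len() - 1], &device)?.unsqueeze(0)?;
        let target = Tensor::new(&ids[ids.len() - 1..], &device)?;
        let logits = model.forward(&context, 0)?; // shape (1, vocab_size), f32
        let loss = candle_nn::loss::cross_entropy(&logits, &target)?;
        println!("{line}: loss {:.4}", loss.to_scalar::<f32>()?);
    }
    Ok(())
}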

@@ -286,18 +286,10 @@ pub struct Llama {
     blocks: Vec<Block>,
     ln_f: RmsNorm,
     lm_head: Linear,
+    pub config: Config,
 }
 
 impl Llama {
-    fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
-        Self {
-            wte,
-            blocks,
-            ln_f,
-            lm_head,
-        }
-    }
-
     pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
@@ -305,18 +297,23 @@ impl Llama {
             x = block.forward(&x, index_pos, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
         let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
+        let wte = embedding(&cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
-        let norm = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
+        let ln_f = RmsNorm::load(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
+            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
             .collect();
-        Ok(Self::new(wte, blocks, norm, lm_head))
+        Ok(Self {
+            wte,
+            blocks,
+            ln_f,
+            lm_head,
+            config: cfg,
+        })
     }
 }
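
With the new signature, `load` consumes the `Config` and keeps it on the model, so downstream code such as the eval loop can read model dimensions without threading the config separately. A hypothetical call site (variable names are illustrative):

// The Config is moved into the model and stays readable as `model.config`.
let model = Llama::load(vb, &cache, config)?;
println!(
    "layers: {}, vocab: {}",
    model.config.n_layers, model.config.vocab_size
);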