Minor cleanup (#1194)

* Add some missing backtraces.

* Small cleanup.
Author: Laurent Mazare
Date: 2023-10-27 14:08:29 +01:00
Committed by: GitHub
Parent: 9b1158b315
Commit: 916619f70b


@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{embedding, linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
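
For reference, the embedding and linear builders that this import brings in construct their layers from a VarBuilder prefix. A minimal sketch of that usage, with illustrative shapes and a VarMap-backed builder standing in for real checkpoint weights:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{embedding, linear, Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    // A VarMap-backed VarBuilder creates tensors on demand, so the builder
    // functions can be exercised without a checkpoint file.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

    // linear(in_dim, out_dim, vb) reads "weight" (and "bias") under the
    // builder's current prefix; embedding(vocab_size, hidden_size, vb)
    // reads the embedding table under "weight".
    let proj = linear(288, 512, vb.pp("proj"))?;
    let tok_emb = embedding(32_000, 288, vb.pp("embed"))?;

    let ids = Tensor::zeros((1, 4), DType::U32, &Device::Cpu)?;
    let xs = tok_emb.forward(&ids)?; // (1, 4, 288)
    let ys = proj.forward(&xs)?; // (1, 4, 512)
    println!("{:?}", ys.shape());
    Ok(())
}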
@@ -57,20 +57,6 @@ impl Cache {
     }
 }
-
-fn silu(xs: &Tensor) -> Result<Tensor> {
-    xs / (xs.neg()?.exp()? + 1.0)?
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
-}
-
-fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
-    let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
-    Ok(Embedding::new(embeddings, cfg.dim))
-}
 
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
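
The removed local silu computed x / (1 + exp(-x)), which is x * sigmoid(x), the same function the library exposes as candle_nn::ops::silu. A small check that the two formulations agree up to floating-point rounding (the sample values are arbitrary):

use candle::{Device, Result, Tensor};

// The helper deleted above: silu(x) = x / (1 + e^(-x)) = x * sigmoid(x).
fn local_silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[-2.0f32, -0.5, 0.0, 0.5, 2.0], &Device::Cpu)?;
    let a = local_silu(&xs)?;
    let b = candle_nn::ops::silu(&xs)?;
    // Both should print the same values modulo rounding.
    println!("{a}\n{b}");
    Ok(())
}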
@@ -198,7 +184,7 @@ impl Mlp {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
+        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }
 }
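
This forward pass is the gated, SwiGLU-style feed-forward block used by Llama-family models: the silu-activated c_fc1 output gates c_fc2 elementwise before the c_proj down-projection. A standalone sketch under assumed dimensions, with linear_no_bias and randomly initialized weights standing in for the example's real loading code:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear_no_bias, Linear, Module, VarBuilder, VarMap};

struct Mlp {
    c_fc1: Linear,  // gate projection
    c_fc2: Linear,  // up projection
    c_proj: Linear, // down projection
}

impl Mlp {
    fn new(dim: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            c_fc1: linear_no_bias(dim, hidden, vb.pp("c_fc1"))?,
            c_fc2: linear_no_bias(dim, hidden, vb.pp("c_fc2"))?,
            c_proj: linear_no_bias(hidden, dim, vb.pp("c_proj"))?,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // silu(c_fc1(x)) gates c_fc2(x) elementwise, then project back down.
        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }
}

fn main() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let mlp = Mlp::new(288, 768, vb)?;
    let x = Tensor::randn(0f32, 1f32, (1, 4, 288), &Device::Cpu)?;
    println!("{:?}", mlp.forward(&x)?.shape()); // (1, 4, 288)
    Ok(())
}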
@@ -283,7 +269,7 @@ impl Llama {
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
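
With the Config-coupled helpers gone, the shapes are passed to the candle_nn builders explicitly at the call site. A sketch of the new call pattern; the Config struct below is a hypothetical subset carrying only the fields these calls need:

use candle::{DType, Device, Result};
use candle_nn::{embedding, linear, rms_norm, VarBuilder, VarMap};

// Hypothetical stand-in for the example's Config, reduced to the fields used here.
struct Config {
    dim: usize,
    vocab_size: usize,
    norm_eps: f64,
}

fn main() -> Result<()> {
    let cfg = Config { dim: 288, vocab_size: 32_000, norm_eps: 1e-5 };
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);

    // Shapes now come from the call site rather than from inside the helper.
    let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
    let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
    let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
    let _ = (wte, lm_head, norm);
    Ok(())
}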