Use the llama weight names for the Yi example. (#1381)

Laurent Mazare
2023-11-27 20:42:52 +00:00
committed by GitHub
parent e2eb6590ed
commit 7c3cfd1086
2 changed files with 8 additions and 4 deletions


@@ -74,9 +74,9 @@ impl TextGeneration {
         std::io::stdout().flush()?;
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_token("</s>") {
+        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
             Some(token) => token,
-            None => anyhow::bail!("cannot find the </s> token"),
+            None => anyhow::bail!("cannot find the <|endoftext|> token"),
         };
         let start_gen = std::time::Instant::now();
         for index in 0..sample_len {
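The change in the example binary only swaps the end-of-sequence marker: the Yi tokenizer does not define the llama-style `</s>` token, so the lookup has to use `<|endoftext|>` instead. Below is a minimal standalone sketch of that lookup using the `tokenizers` crate directly (the `get_token` call above is a small helper in the example code); the `tokenizer.json` path is a placeholder, in the example the file comes from the Hugging Face hub:

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder path; the Yi example downloads this file from the hub instead.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    // token_to_id returns None when the literal string is not in the vocabulary,
    // which is what happened for "</s>" with the Yi tokenizer before this commit.
    let eos_token = tokenizer
        .token_to_id("<|endoftext|>")
        .ok_or("cannot find the <|endoftext|> token")?;
    println!("eos token id: {eos_token}");
    Ok(())
}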


@@ -277,8 +277,12 @@ impl DecoderLayer {
     fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
-        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln1"))?;
-        let ln2 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("ln2"))?;
+        let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let ln2 = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
         Ok(Self {
             self_attn,
             mlp,
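The second file holds the actual rename: the Yi checkpoints on the hub ship llama-style tensor names, so the per-layer norms have to be loaded as `input_layernorm` and `post_attention_layernorm` rather than `ln1`/`ln2`. The sketch below is not candle's VarBuilder, just a hypothetical stand-in that mimics how nested pp calls compose the dotted keys the loader looks up in the checkpoint:

// Hypothetical stand-in for candle's VarBuilder prefixing, for illustration only.
#[derive(Clone)]
struct PathBuilder {
    prefix: Vec<String>,
}

impl PathBuilder {
    fn root() -> Self {
        Self { prefix: Vec::new() }
    }

    // Push one path segment, the way vb.pp("...") does in the diff above.
    fn pp(&self, name: &str) -> Self {
        let mut prefix = self.prefix.clone();
        prefix.push(name.to_string());
        Self { prefix }
    }

    // Full key for a tensor, e.g. "model.layers.0.input_layernorm.weight".
    fn key(&self, tensor: &str) -> String {
        let mut parts = self.prefix.clone();
        parts.push(tensor.to_string());
        parts.join(".")
    }
}

fn main() {
    let layer = PathBuilder::root().pp("model").pp("layers").pp("0");
    // After this commit the Yi layer norms resolve to the llama-style keys:
    assert_eq!(
        layer.pp("input_layernorm").key("weight"),
        "model.layers.0.input_layernorm.weight"
    );
    assert_eq!(
        layer.pp("post_attention_layernorm").key("weight"),
        "model.layers.0.post_attention_layernorm.weight"
    );
    println!("weight names line up with the llama layout");
}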