Add a repeat penalty to the llama2.c wasm example. (#709)

This commit is contained in:
Laurent Mazare
2023-09-01 20:32:28 +02:00
committed by GitHub
parent 1e5b2cc1d5
commit 2fef14cb14
4 changed files with 26 additions and 39 deletions

View File

@ -1,8 +1,8 @@
use crate::model::{Cache, Config, Llama};
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
use rand::{distributions::Distribution, SeedableRng};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@ -56,40 +56,6 @@ pub struct Model {
pub tokenizer: Tokenizer,
}
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
temperature: Option<f64>,
}
impl LogitsProcessor {
pub fn new(seed: u64, temperature: Option<f64>) -> Self {
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
}
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr =
rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
Ok(next_token)
}
}
impl Model {
fn run(
&self,