mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Add a repeat penalty to the llama2.c wasm example. (#709)
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
use candle::{Device, Tensor};
|
||||
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_wasm_example_llama2::worker::{Model as M, ModelData};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
#[wasm_bindgen]
|
||||
@ -7,14 +8,26 @@ pub struct Model {
|
||||
inner: M,
|
||||
logits_processor: LogitsProcessor,
|
||||
tokens: Vec<u32>,
|
||||
repeat_penalty: f32,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
|
||||
const REPEAT_LAST_N: usize = 64;
|
||||
let dev = Device::Cpu;
|
||||
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
|
||||
let logits = self.inner.llama.forward(&input, tokens.len())?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = self.tokens.len().saturating_sub(REPEAT_LAST_N);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
self.tokens.push(next_token);
|
||||
@ -40,13 +53,19 @@ impl Model {
|
||||
inner,
|
||||
logits_processor,
|
||||
tokens: vec![],
|
||||
repeat_penalty: 1.,
|
||||
}),
|
||||
Err(e) => Err(JsError::new(&e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
|
||||
pub fn init_with_prompt(
|
||||
&mut self,
|
||||
prompt: String,
|
||||
temp: f64,
|
||||
repeat_penalty: f32,
|
||||
) -> Result<String, JsError> {
|
||||
// First reset the cache.
|
||||
{
|
||||
let mut cache = self.inner.cache.kvs.lock().unwrap();
|
||||
@ -56,6 +75,7 @@ impl Model {
|
||||
}
|
||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||
self.logits_processor = LogitsProcessor::new(299792458, temp);
|
||||
self.repeat_penalty = repeat_penalty;
|
||||
self.tokens.clear();
|
||||
let tokens = self
|
||||
.inner
|
||||
|
Reference in New Issue
Block a user