mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a repeat penalty to the llama2.c wasm example. (#709)
This commit is contained in:
@ -11,6 +11,7 @@ license.workspace = true
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
|
candle = { path = "../../candle-core", version = "0.2.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../../candle-nn", version = "0.2.1" }
|
candle-nn = { path = "../../candle-nn", version = "0.2.1" }
|
||||||
|
candle-transformers = { path = "../../candle-transformers", version = "0.2.1" }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
tokenizers = { workspace = true, features = ["unstable_wasm"] }
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use candle::{Device, Tensor};
|
use candle::{Device, Tensor};
|
||||||
use candle_wasm_example_llama2::worker::{LogitsProcessor, Model as M, ModelData};
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use candle_wasm_example_llama2::worker::{Model as M, ModelData};
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
@ -7,14 +8,26 @@ pub struct Model {
|
|||||||
inner: M,
|
inner: M,
|
||||||
logits_processor: LogitsProcessor,
|
logits_processor: LogitsProcessor,
|
||||||
tokens: Vec<u32>,
|
tokens: Vec<u32>,
|
||||||
|
repeat_penalty: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
|
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
|
||||||
|
const REPEAT_LAST_N: usize = 64;
|
||||||
let dev = Device::Cpu;
|
let dev = Device::Cpu;
|
||||||
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
|
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
|
||||||
let logits = self.inner.llama.forward(&input, tokens.len())?;
|
let logits = self.inner.llama.forward(&input, tokens.len())?;
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = self.tokens.len().saturating_sub(REPEAT_LAST_N);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
self.tokens.push(next_token);
|
self.tokens.push(next_token);
|
||||||
@ -40,13 +53,19 @@ impl Model {
|
|||||||
inner,
|
inner,
|
||||||
logits_processor,
|
logits_processor,
|
||||||
tokens: vec![],
|
tokens: vec![],
|
||||||
|
repeat_penalty: 1.,
|
||||||
}),
|
}),
|
||||||
Err(e) => Err(JsError::new(&e.to_string())),
|
Err(e) => Err(JsError::new(&e.to_string())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub fn init_with_prompt(&mut self, prompt: String, temp: f64) -> Result<String, JsError> {
|
pub fn init_with_prompt(
|
||||||
|
&mut self,
|
||||||
|
prompt: String,
|
||||||
|
temp: f64,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
) -> Result<String, JsError> {
|
||||||
// First reset the cache.
|
// First reset the cache.
|
||||||
{
|
{
|
||||||
let mut cache = self.inner.cache.kvs.lock().unwrap();
|
let mut cache = self.inner.cache.kvs.lock().unwrap();
|
||||||
@ -56,6 +75,7 @@ impl Model {
|
|||||||
}
|
}
|
||||||
let temp = if temp <= 0. { None } else { Some(temp) };
|
let temp = if temp <= 0. { None } else { Some(temp) };
|
||||||
self.logits_processor = LogitsProcessor::new(299792458, temp);
|
self.logits_processor = LogitsProcessor::new(299792458, temp);
|
||||||
|
self.repeat_penalty = repeat_penalty;
|
||||||
self.tokens.clear();
|
self.tokens.clear();
|
||||||
let tokens = self
|
let tokens = self
|
||||||
.inner
|
.inner
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
use crate::model::{Cache, Config, Llama};
|
use crate::model::{Cache, Config, Llama};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
use candle::{DType, Device, IndexOp, Result, Shape, Tensor, D};
|
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
|
||||||
use candle_nn::{ops::softmax, VarBuilder};
|
use candle_nn::VarBuilder;
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
@ -56,40 +56,6 @@ pub struct Model {
|
|||||||
pub tokenizer: Tokenizer,
|
pub tokenizer: Tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitsProcessor {
|
|
||||||
rng: rand::rngs::StdRng,
|
|
||||||
temperature: Option<f64>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LogitsProcessor {
|
|
||||||
pub fn new(seed: u64, temperature: Option<f64>) -> Self {
|
|
||||||
Self {
|
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
|
||||||
temperature,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
|
||||||
let logits = logits.to_dtype(DType::F32)?;
|
|
||||||
let next_token = if let Some(temperature) = self.temperature {
|
|
||||||
let prs = softmax(&(&logits / temperature)?, D::Minus1)?;
|
|
||||||
let prs: Vec<f32> = prs.to_vec1()?;
|
|
||||||
let distr =
|
|
||||||
rand::distributions::WeightedIndex::new(prs).map_err(candle::Error::wrap)?;
|
|
||||||
distr.sample(&mut self.rng) as u32
|
|
||||||
} else {
|
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
|
||||||
logits_v
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
|
||||||
.map(|(i, _)| i as u32)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
Ok(next_token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn run(
|
fn run(
|
||||||
&self,
|
&self,
|
||||||
|
BIN
candle-wasm-examples/whisper/b.tgz
Normal file
BIN
candle-wasm-examples/whisper/b.tgz
Normal file
Binary file not shown.
Reference in New Issue
Block a user