Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 03:54:56 +00:00
Llama2c WASM UI improvements (#732)
* pass seed, expose model seq_len
* wip new llama2.c ui
* final new UI example
* small copy
* copy
@@ -58,6 +58,11 @@ impl Model {
             Err(e) => Err(JsError::new(&e.to_string())),
         }
     }
+    #[wasm_bindgen]
+    pub fn get_seq_len(&mut self) -> usize {
+        let seq_len = self.inner.config.seq_len;
+        seq_len
+    }
 
     #[wasm_bindgen]
     pub fn init_with_prompt(
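The getter exists so the UI can cap how many tokens it requests: asking for more than the model's context window (`seq_len`) is pointless for this llama2.c cache. A minimal sketch of the clamping a caller can now do, written in Rust for illustration (the helper name and both inputs are hypothetical, not part of this commit):

// Clamp a requested generation length to the model's context window.
// `prompt_len` and `requested` are hypothetical caller-side values.
fn max_new_tokens(seq_len: usize, prompt_len: usize, requested: usize) -> usize {
    requested.min(seq_len.saturating_sub(prompt_len))
}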
@@ -65,6 +70,7 @@ impl Model {
         prompt: String,
         temp: f64,
         repeat_penalty: f32,
+        seed: u64,
     ) -> Result<String, JsError> {
         // First reset the cache.
         {
@@ -74,7 +80,7 @@ impl Model {
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(299792458, temp);
+        self.logits_processor = LogitsProcessor::new(seed, temp);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
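Threading the caller's `seed` through to `LogitsProcessor` replaces the hardcoded 299792458 (the speed of light in m/s, previously a fixed Easter-egg seed), so the UI can make sampling reproducible or vary it at will. A standalone sketch of that property, assuming the two-argument constructor shown in the diff and the `candle` / `candle_transformers` paths used by candle's other examples:

use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;

fn main() -> candle::Result<()> {
    let logits = Tensor::new(&[0.1f32, 2.0, 0.5], &Device::Cpu)?;
    // Same seed and temperature: both processors sample the same token id.
    let mut a = LogitsProcessor::new(42, Some(0.8));
    let mut b = LogitsProcessor::new(42, Some(0.8));
    assert_eq!(a.sample(&logits)?, b.sample(&logits)?);
    Ok(())
}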
@@ -51,7 +51,7 @@ fn read_tensor<R: std::io::Read, S: Into<Shape>>(
 
 pub struct Model {
     pub cache: Cache,
-    config: Config,
+    pub config: Config,
     pub llama: Llama,
     pub tokenizer: Tokenizer,
 }
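This one-word change is what the new `get_seq_len` binding depends on: `self.inner.config.seq_len` is read from outside the model module, which a private field would reject at compile time. A self-contained sketch of the visibility rule (module and field names here are illustrative, not candle's):

mod model {
    pub struct Config { pub seq_len: usize }
    // The field must be `pub` to be readable outside `model`, as in the diff.
    pub struct Model { pub config: Config }
}

fn main() {
    let m = model::Model { config: model::Config { seq_len: 256 } };
    // Compiles only because `config` and `seq_len` are public.
    assert_eq!(m.config.seq_len, 256);
}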