mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00

* Metal quantized modifications proposal. - Add a device param, wherever needed. - Create new QMetal storage thing that implements QuantizedType. - Update everywhere needed. Fix Python. Fixing examples. Fix: fmt + clippy + stub. Moving everything around. Only missing the actual implems. Fixing everything + adding dequantized kernels. More work. Fixing matmul. Fmt + Clippy Some clippy fixes. Working state. Q2K Metal -> Bugged (also present in GGML). Q4K CPU -> Bugged (present previously, new test catch it). Q5K CPU -> Bugged (present previously). Q8_1 Both -> Never really implemented it seems Q8K metal -> Never implemented in metal Fixing Q2K bug (present in ggml). * Cleanup. * Fix the rebase. * Removing the fences speeds everything up and *is* correct this time... * Cleanup the fence. * After rebase. * Bad code removal. * Rebase after phi2 merge + fix replit default to CPU. * Making the CI happy. * More happy tests. --------- Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
164 lines
5.2 KiB
Rust
164 lines
5.2 KiB
Rust
use candle::{DType, Device, Tensor};
|
|
use candle_nn::VarBuilder;
|
|
use candle_transformers::generation::LogitsProcessor;
|
|
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
|
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
|
use candle_wasm_example_phi::console_log;
|
|
use js_sys::Date;
|
|
use serde::Deserialize;
|
|
use tokenizers::Tokenizer;
|
|
use wasm_bindgen::prelude::*;
|
|
|
|
enum SelectedModel {
|
|
MixFormer(MixFormer),
|
|
Quantized(QMixFormer),
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
pub struct Model {
|
|
model: SelectedModel,
|
|
tokenizer: Tokenizer,
|
|
logits_processor: LogitsProcessor,
|
|
tokens: Vec<u32>,
|
|
repeat_penalty: f32,
|
|
repeat_last_n: usize,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
|
|
|
pub struct ModelName {
|
|
pub _name_or_path: String,
|
|
}
|
|
|
|
#[wasm_bindgen]
|
|
impl Model {
|
|
#[wasm_bindgen(constructor)]
|
|
pub fn load(
|
|
weights: Vec<u8>,
|
|
tokenizer: Vec<u8>,
|
|
config: Vec<u8>,
|
|
quantized: bool,
|
|
) -> Result<Model, JsError> {
|
|
console_error_panic_hook::set_once();
|
|
console_log!("loading model");
|
|
let device = Device::Cpu;
|
|
let name: ModelName = serde_json::from_slice(&config)?;
|
|
let config: Config = serde_json::from_slice(&config)?;
|
|
|
|
console_log!("config loaded {:?}", name);
|
|
let tokenizer =
|
|
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
|
let start = Date::now();
|
|
console_log!("weights len: {:?}", weights.len());
|
|
let model = if quantized {
|
|
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
|
|
&weights, &device,
|
|
)?;
|
|
console_log!("weights loaded");
|
|
if name._name_or_path == "microsoft/phi-2" {
|
|
let model = QMixFormer::new_v2(&config, vb)?;
|
|
SelectedModel::Quantized(model)
|
|
} else {
|
|
let model = QMixFormer::new(&config, vb)?;
|
|
SelectedModel::Quantized(model)
|
|
}
|
|
} else {
|
|
let device = &Device::Cpu;
|
|
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;
|
|
let model = MixFormer::new(&config, vb)?;
|
|
SelectedModel::MixFormer(model)
|
|
};
|
|
console_log!("model loaded in {:?}s", (Date::now() - start) / 1000.);
|
|
let logits_processor = LogitsProcessor::new(299792458, None, None);
|
|
Ok(Self {
|
|
model,
|
|
tokenizer,
|
|
tokens: vec![],
|
|
logits_processor,
|
|
repeat_penalty: 1.,
|
|
repeat_last_n: 64,
|
|
})
|
|
}
|
|
#[wasm_bindgen]
|
|
pub fn init_with_prompt(
|
|
&mut self,
|
|
prompt: String,
|
|
temp: f64,
|
|
top_p: f64,
|
|
repeat_penalty: f32,
|
|
repeat_last_n: usize,
|
|
seed: u64,
|
|
) -> Result<String, JsError> {
|
|
match &mut self.model {
|
|
SelectedModel::MixFormer(m) => m.clear_kv_cache(),
|
|
SelectedModel::Quantized(m) => m.clear_kv_cache(),
|
|
};
|
|
let temp = if temp <= 0. { None } else { Some(temp) };
|
|
let top_p = if top_p <= 0. || top_p >= 1. {
|
|
None
|
|
} else {
|
|
Some(top_p)
|
|
};
|
|
self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
self.repeat_penalty = repeat_penalty;
|
|
self.repeat_last_n = repeat_last_n;
|
|
self.tokens.clear();
|
|
let tokens = self
|
|
.tokenizer
|
|
.encode(prompt, true)
|
|
.map_err(|m| JsError::new(&m.to_string()))?
|
|
.get_ids()
|
|
.to_vec();
|
|
let text = self
|
|
.process(&tokens)
|
|
.map_err(|m| JsError::new(&m.to_string()))?;
|
|
Ok(text)
|
|
}
|
|
#[wasm_bindgen]
|
|
pub fn next_token(&mut self) -> Result<String, JsError> {
|
|
let last_token = *self.tokens.last().unwrap();
|
|
let text = self
|
|
.process(&[last_token])
|
|
.map_err(|m| JsError::new(&m.to_string()))?;
|
|
Ok(text)
|
|
}
|
|
}
|
|
|
|
impl Model {
|
|
fn process(&mut self, tokens: &[u32]) -> candle::Result<String> {
|
|
let dev = Device::Cpu;
|
|
let input = Tensor::new(tokens, &dev)?.unsqueeze(0)?;
|
|
let logits = match &mut self.model {
|
|
SelectedModel::MixFormer(m) => m.forward(&input)?,
|
|
SelectedModel::Quantized(m) => m.forward(&input)?,
|
|
};
|
|
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
|
let logits = if self.repeat_penalty == 1. {
|
|
logits
|
|
} else {
|
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
candle_transformers::utils::apply_repeat_penalty(
|
|
&logits,
|
|
self.repeat_penalty,
|
|
&tokens[start_at..],
|
|
)?
|
|
};
|
|
|
|
let next_token = self.logits_processor.sample(&logits)?;
|
|
self.tokens.push(next_token);
|
|
let token = match self.tokenizer.decode(&[next_token], false) {
|
|
Ok(token) => token,
|
|
Err(e) => {
|
|
console_log!("error decoding token: {:?}", e);
|
|
"".to_string()
|
|
}
|
|
};
|
|
// console_log!("token: {:?}: {:?}", token, next_token);
|
|
Ok(token)
|
|
}
|
|
}
|
|
|
|
fn main() {
|
|
console_error_panic_hook::set_once();
|
|
}
|