mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Metal quantized modifications proposal.
- Add a device param wherever needed. - Create a new QMetal storage type that implements QuantizedType. - Update everywhere needed. Fix Python. Fix examples. Fix: fmt + clippy + stub. Move everything around. Only missing the actual implementations. Fix everything + add dequantize kernels. More work. Fix matmul. Fmt + Clippy. Some clippy fixes. Working state. Q2K Metal -> Bugged (also present in GGML). Q4K CPU -> Bugged (present previously; a new test catches it). Q5K CPU -> Bugged (present previously). Q8_1 Both -> Never really implemented, it seems. Q8K Metal -> Never implemented in Metal. Fix Q2K bug (present in GGML).
This commit is contained in:

committed by
Nicolas Patry

parent
3a7304cb0d
commit
f97fcd4712
@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
|
||||
use candle_wasm_example_t5::console_log;
|
||||
use tokenizers::Tokenizer;
|
||||
use wasm_bindgen::prelude::*;
|
||||
const DEVICE: Device = Device::Cpu;
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub struct ModelEncoder {
|
||||
@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
|
||||
) -> Result<ModelConditionalGeneration, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
let tokenizer =
|
||||
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
|
||||
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let input: ConditionalGenerationParams =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
let device = &Device::Cpu;
|
||||
let device = &DEVICE;
|
||||
self.model.clear_kv_cache();
|
||||
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
|
||||
let prompt = input.prompt;
|
||||
@ -128,7 +129,7 @@ impl ModelEncoder {
|
||||
) -> Result<ModelEncoder, JsError> {
|
||||
console_error_panic_hook::set_once();
|
||||
console_log!("loading model");
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights)?;
|
||||
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
|
||||
let mut config: Config = serde_json::from_slice(&config)?;
|
||||
config.use_cache = false;
|
||||
let tokenizer =
|
||||
@ -138,7 +139,7 @@ impl ModelEncoder {
|
||||
}
|
||||
|
||||
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
|
||||
let device = &Device::Cpu;
|
||||
let device = &DEVICE;
|
||||
let input: DecoderParams =
|
||||
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
|
||||
|
||||
|
Reference in New Issue
Block a user