Metal quantized modifications proposal.

- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).
This commit is contained in:
Nicolas Patry
2023-11-18 14:49:30 +01:00
committed by Nicolas Patry
parent 3a7304cb0d
commit f97fcd4712
32 changed files with 6408 additions and 420 deletions

View File

@ -61,7 +61,7 @@ impl Model {
let start = Date::now();
let model: SelectedModel = if quantized {
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
SelectedModel::Q(model)
} else {

View File

@ -41,6 +41,7 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let device = Device::Cpu;
let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
@ -50,8 +51,9 @@ impl Model {
let start = Date::now();
console_log!("weights len: {:?}", weights.len());
let model = if quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&weights, &device,
)?;
console_log!("weights loaded");
if name._name_or_path == "microsoft/phi-2" {
let model = QMixFormer::new_v2(&config, vb)?;

View File

@ -7,6 +7,7 @@ pub use candle_transformers::models::quantized_t5::{
use candle_wasm_example_t5::console_log;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
const DEVICE: Device = Device::Cpu;
#[wasm_bindgen]
pub struct ModelEncoder {
@ -31,7 +32,7 @@ impl ModelConditionalGeneration {
) -> Result<ModelConditionalGeneration, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let vb = VarBuilder::from_gguf_buffer(&weights)?;
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
let mut config: Config = serde_json::from_slice(&config)?;
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@ -46,7 +47,7 @@ impl ModelConditionalGeneration {
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let input: ConditionalGenerationParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
let device = &Device::Cpu;
let device = &DEVICE;
self.model.clear_kv_cache();
let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
let prompt = input.prompt;
@ -128,7 +129,7 @@ impl ModelEncoder {
) -> Result<ModelEncoder, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let vb = VarBuilder::from_gguf_buffer(&weights)?;
let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
let mut config: Config = serde_json::from_slice(&config)?;
config.use_cache = false;
let tokenizer =
@ -138,7 +139,7 @@ impl ModelEncoder {
}
pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
let device = &Device::Cpu;
let device = &DEVICE;
let input: DecoderParams =
serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;

View File

@ -315,6 +315,7 @@ impl Decoder {
let model = if md.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
&md.weights,
&device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else {