Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Quantized GGUF style (#1523)
* Metal quantized modifications proposal.

  - Add a device param, wherever needed.
  - Create new QMetal storage thing that implements QuantizedType.
  - Update everywhere needed.

  Fix Python.
  Fixing examples.
  Fix: fmt + clippy + stub.
  Moving everything around.
  Only missing the actual implems.
  Fixing everything + adding dequantized kernels.
  More work.
  Fixing matmul.
  Fmt + Clippy
  Some clippy fixes.
  Working state.

  Q2K Metal -> Bugged (also present in GGML).
  Q4K CPU -> Bugged (present previously, new test catch it).
  Q5K CPU -> Bugged (present previously).
  Q8_1 Both -> Never really implemented it seems.
  Q8K metal -> Never implemented in metal.

  Fixing Q2K bug (present in ggml).

* Cleanup.
* Fix the rebase.
* Removing the fences speeds everything up and *is* correct this time...
* Cleanup the fence.
* After rebase.
* Bad code removal.
* Rebase after phi2 merge + fix replit default to CPU.
* Making the CI happy.
* More happy tests.

---------

Co-authored-by: Nicolas Patry <nicolas@Nicolass-MacBook-Pro.local>
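Taken together, the hunks below apply one mechanical change across the examples (blip, llama2-c, mistral, phi, quantized, quantized-t5, replit-code, stablelm and whisper, judging by the code they touch): the quantized GGUF/GGML loaders now take the target device instead of silently assuming the CPU, and `blck_size` becomes `block_size`. A minimal sketch of the new calling pattern, assuming the crates these examples already pull in (candle, candle-transformers, candle-examples, anyhow); `load_quantized` is a hypothetical helper, not part of the diff:

    use candle_transformers::quantized_var_builder::VarBuilder;

    fn load_quantized(gguf_path: &str, cpu: bool) -> anyhow::Result<VarBuilder> {
        // Pick the device once, up front; the quantized branch no longer hardcodes Device::Cpu.
        let device = candle_examples::device(cpu)?;
        // from_gguf now takes the target device as an extra argument.
        let vb = VarBuilder::from_gguf(gguf_path, &device)?;
        Ok(vb)
    }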
@@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
     let config = blip::Config::image_captioning_large();
 
+    let device = candle_examples::device(args.cpu)?;
     let (image_embeds, device, mut model) = if args.quantized {
         let device = Device::Cpu;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
 
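The `candle_examples::device(args.cpu)?` call that the hunks keep hoisting out of the `if args.quantized` branch is the usual device-selection helper from the candle-examples crate. Roughly what it does (a sketch from memory; treat the exact fallback order as an assumption rather than a quote of the crate):

    use candle::utils::{cuda_is_available, metal_is_available};
    use candle::{Device, Result};

    // Approximate behaviour of candle_examples::device: honour --cpu, otherwise
    // prefer an accelerator when one was compiled in and is present.
    fn pick_device(cpu: bool) -> Result<Device> {
        if cpu {
            Ok(Device::Cpu)
        } else if cuda_is_available() {
            Device::new_cuda(0)
        } else if metal_is_available() {
            Device::new_metal(0)
        } else {
            Ok(Device::Cpu)
        }
    }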
@@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
             .shape()
@@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_real",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
         let freq_cis_imag = vb
             .get(
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_imag",
             )?
-            .dequantize(&candle::Device::Cpu)?;
+            .dequantize(&device)?;
 
         let fake_vb = candle_nn::VarBuilder::from_tensors(
             [
@@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             .into_iter()
             .collect(),
             candle::DType::F32,
-            &candle::Device::Cpu,
+            &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
         let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
@@ -244,13 +244,14 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let config = Config::config_7b_v0_1(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QMistral::new(&config, vb)?;
-        (Model::Quantized(model), Device::Cpu)
+        (Model::Quantized(model), device)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
@@ -307,18 +307,21 @@ fn main() -> Result<()> {
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
-    let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
+    let device = candle_examples::device(args.cpu)?;
+    let model = if args.quantized {
         let config = config();
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &filenames[0],
+            &device,
+        )?;
         let model = match args.model {
             WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
             _ => QMixFormer::new(&config, vb)?,
         };
-        (Model::Quantized(model), Device::Cpu)
+        Model::Quantized(model)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-        let model = match args.model {
+        match args.model {
             WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
                 let config_filename = repo.get("config.json")?;
                 let config = std::fs::read_to_string(config_filename)?;
@@ -334,8 +337,7 @@ fn main() -> Result<()> {
                 let config = config();
                 Model::MixFormer(MixFormer::new(&config, vb)?)
             }
-        };
-        (model, device)
+        }
     };
     println!("loaded the model in {:?}", start.elapsed());
 
@@ -132,7 +132,8 @@ impl T5ModelBuilder {
     }
 
     pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
+        let device = Device::Cpu;
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
         Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }
 
@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::Tensor;
 use candle_transformers::generation::LogitsProcessor;
 
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -361,6 +361,7 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
+    let device = candle_examples::device(false)?;
 
     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
@@ -369,7 +370,7 @@ fn main() -> anyhow::Result<()> {
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@@ -377,15 +378,16 @@ fn main() -> anyhow::Result<()> {
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file)?
+            ModelWeights::from_gguf(model, &mut file, &device)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
+            let model = ggml_file::Content::read(&mut file, &device)
+                .map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
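The blck_size -> block_size rename above only touches the size accounting printed at load time; the arithmetic is unchanged. As a standalone sketch (hypothetical helper, same formula as the loop in the hunk): a quantized tensor stores its elements in blocks of `block_size()` elements, each block occupying `type_size()` bytes.

    use candle::quantized::gguf_file;

    // Bytes taken by one quantized tensor, as in the loop above:
    // elem_count elements, block_size() elements per block, type_size() bytes per block.
    fn tensor_size_in_bytes(info: &gguf_file::TensorInfo) -> usize {
        let elem_count = info.shape.elem_count();
        elem_count * info.ggml_dtype.type_size() / info.ggml_dtype.block_size()
    }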
@@ -486,7 +488,7 @@ fn main() -> anyhow::Result<()> {
 
         let start_prompt_processing = std::time::Instant::now();
         let mut next_token = {
-            let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
             let logits = model.forward(&input, 0)?;
             let logits = logits.squeeze(0)?;
             logits_processor.sample(&logits)?
@@ -507,7 +509,7 @@ fn main() -> anyhow::Result<()> {
         let start_post_prompt = std::time::Instant::now();
         let mut sampled = 0;
         for index in 0..to_sample {
-            let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
             let logits = model.forward(&input, prompt_tokens.len() + index)?;
             let logits = logits.squeeze(0)?;
             let logits = if args.repeat_penalty == 1. {
@@ -236,16 +236,15 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
+    let device = candle_examples::device(args.cpu)?;
     let config = Config::replit_code_v1_5_3b();
-    let (model, device) = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
-        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
-        (model, Device::Cpu)
+    let model = if args.quantized {
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
+        Model::Q(Q::new(&config, vb.pp("transformer"))?)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
-        (model, device)
+        Model::M(M::new(&config, vb.pp("transformer"))?)
     };
     println!("loaded the model in {:?}", start.elapsed());
 
@@ -234,13 +234,14 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
    let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
+    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
         let model = QStableLM::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
-        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
@@ -557,8 +557,10 @@ fn main() -> Result<()> {
     println!("loaded mel: {:?}", mel.dims());
 
     let mut model = if args.quantized {
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &weights_filename,
+            &device,
+        )?;
         Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
     } else {
         let vb =