Add the quantized mixformer model. (#953)

* Add the quantized mixformer model.

* Add the quantized option in the phi example.
This commit is contained in:
Laurent Mazare
2023-09-24 15:03:48 +01:00
committed by GitHub
parent e15862cfdb
commit 0007ae9c11
6 changed files with 418 additions and 48 deletions

View File

@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@ -15,6 +16,11 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
MixFormer(MixFormer),
Quantized(QMixFormer),
}
struct TextGeneration {
model: Model,
device: Device,
@ -58,7 +64,10 @@ impl TextGeneration {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = match &mut self.model {
Model::MixFormer(m) => m.forward(&input)?,
Model::Quantized(m) => m.forward(&input)?,
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = self.logits_processor.sample(&logits)?;
@ -115,6 +124,9 @@ struct Args {
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
quantized: bool,
}
fn main() -> Result<()> {
@ -150,10 +162,18 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::v1_5();
let model = Model::new(&config, vb)?;
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
let config = Config::v1_5();
let model = QMixFormer::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::v1_5();
let model = MixFormer::new(&config, vb)?;
(Model::MixFormer(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(