Add the quantized mpt model. (#1123)

* Add the quantized mpt model.

* Support the quantized model for replit-code.
This commit is contained in:
Laurent Mazare
2023-10-18 16:29:38 +01:00
committed by GitHub
parent cb034506cd
commit 86e7d539d2
5 changed files with 247 additions and 9 deletions

View File

@ -7,7 +7,8 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mpt::{Config, Model};
use candle_transformers::models::mpt::{Config, Model as M};
use candle_transformers::models::quantized_mpt::Model as Q;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@ -15,6 +16,20 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
M(M),
Q(Q),
}
impl Model {
fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::M(model) => model.forward(xs),
Self::Q(model) => model.forward(xs),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
@ -148,6 +163,9 @@ struct Args {
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
#[arg(long)]
weight_file: Option<String>,
@ -206,16 +224,29 @@ fn main() -> Result<()> {
};
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => repo.get("model.safetensors")?,
None => {
if args.quantized {
repo.get("model-replit-code-v1_5-q4k.gguf")?
} else {
repo.get("model.safetensors")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::replit_code_v1_5_3b();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::new(&config, vb.pp("transformer"))?;
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(