Add the quantized mpt model. (#1123)
* Add the quantized mpt model.
* Support the quantized model for replit-code.
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;

-use candle_transformers::models::mpt::{Config, Model};
+use candle_transformers::models::mpt::{Config, Model as M};
+use candle_transformers::models::quantized_mpt::Model as Q;

 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
@@ -15,6 +16,20 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;

+enum Model {
+    M(M),
+    Q(Q),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
+        match self {
+            Self::M(model) => model.forward(xs),
+            Self::Q(model) => model.forward(xs),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
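The wrapper enum above gives the rest of the example a single `forward` entry point over both the full-precision and quantized backends. A minimal sketch of a caller, using the imports from the example above and assuming the prompt has already been tokenized into `token_ids` (the helper name and signature are illustrative, not part of the commit):

// Illustrative helper, not part of the commit: runs one forward pass
// through whichever backend the wrapper enum currently holds.
fn forward_prompt(model: &mut Model, device: &Device, token_ids: &[u32]) -> candle::Result<Tensor> {
    // Shape the prompt as a batch of one, matching what the example feeds the model.
    let input = Tensor::new(token_ids, device)?.unsqueeze(0)?;
    // The same call compiles against both the Model::M and Model::Q arms.
    model.forward(&input)
}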
@@ -148,6 +163,9 @@ struct Args {
     #[arg(long)]
     revision: Option<String>,

+    #[arg(long)]
+    quantized: bool,
+
     #[arg(long)]
     weight_file: Option<String>,

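With clap's derive API, the new `#[arg(long)] quantized: bool` field becomes a `--quantized` flag that stays false unless passed on the command line. A standalone sketch of just that behavior, with `Args` trimmed to the two fields shown above for illustration:

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Load the gguf weights instead of the safetensors ones.
    #[arg(long)]
    quantized: bool,

    #[arg(long)]
    weight_file: Option<String>,
}

fn main() {
    // Passing `--quantized` flips the flag to true; omitting it keeps the default false.
    let args = Args::parse();
    println!("quantized: {}", args.quantized);
}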
@@ -206,16 +224,29 @@ fn main() -> Result<()> {
     };
     let filename = match args.weight_file {
         Some(weight_file) => std::path::PathBuf::from(weight_file),
-        None => repo.get("model.safetensors")?,
+        None => {
+            if args.quantized {
+                repo.get("model-replit-code-v1_5-q4k.gguf")?
+            } else {
+                repo.get("model.safetensors")?
+            }
+        }
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

     let start = std::time::Instant::now();
     let config = Config::replit_code_v1_5_3b();
-    let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb.pp("transformer"))?;
+    let (model, device) = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
+        let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
+        (model, Device::Cpu)
+    } else {
+        let device = candle_examples::device(args.cpu)?;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
+        let model = Model::M(M::new(&config, vb.pp("transformer"))?);
+        (model, device)
+    };
     println!("loaded the model in {:?}", start.elapsed());

     let mut pipeline = TextGeneration::new(
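In the quantized branch, the weights come from a gguf file through `candle_transformers::quantized_var_builder::VarBuilder`, and the device is pinned to `Device::Cpu` regardless of `args.cpu`; the full-precision branch keeps the mmap'ed safetensors path on whichever device `candle_examples::device` selects. Assuming the example binary is named `replit-code`, as the commit message suggests, the quantized path would be exercised with something like:

cargo run --example replit-code --release -- --quantized

Passing `--weight-file /path/to/model.gguf` alongside `--quantized` would skip the hub download and load a local gguf file through the same branch.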