Quantized version of StableLM. (#1058)

* Quantized version of StableLM.

* Adapt the stable-lm example to support quantized models.

* Use a separate hub repo.

* Another repo name tweak.
Author: Laurent Mazare
Date: 2023-10-08 15:42:38 +01:00
Committed by: GitHub
Parent: 783735cf22
Commit: 59ab6d7832
4 changed files with 331 additions and 14 deletions
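
The heart of the change is a small wrapper enum so the regular and quantized models share one `forward` call and the generation loop stays untouched. Below is a minimal, self-contained sketch of that dispatch pattern; the stub types stand in for the real candle models, and folding the match into an `impl Model` block (the diff keeps it inline in the generation loop) is an equivalent restructuring, not the commit's literal code.

```rust
use candle::{DType, Device, Result, Tensor};

// Stubs standing in for candle_transformers::models::stable_lm::Model and
// candle_transformers::models::quantized_stable_lm::Model; both backends
// expose the same `forward(&mut self, &Tensor, usize)` signature.
struct StableLM;
struct QStableLM;

impl StableLM {
    fn forward(&mut self, input: &Tensor, _pos: usize) -> Result<Tensor> {
        Ok(input.clone()) // placeholder: the real model returns logits
    }
}

impl QStableLM {
    fn forward(&mut self, input: &Tensor, _pos: usize) -> Result<Tensor> {
        Ok(input.clone()) // placeholder: the real model returns logits
    }
}

// One variant per backend, as introduced by this commit.
enum Model {
    StableLM(StableLM),
    Quantized(QStableLM),
}

impl Model {
    // Dispatch once here so callers never care which backend is active.
    fn forward(&mut self, input: &Tensor, pos: usize) -> Result<Tensor> {
        match self {
            Model::StableLM(m) => m.forward(input, pos),
            Model::Quantized(m) => m.forward(input, pos),
        }
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let input = Tensor::zeros((1, 4), DType::U32, &device)?;
    let mut model = Model::Quantized(QStableLM);
    let logits = model.forward(&input, 0)?;
    println!("{:?}", logits.shape());
    Ok(())
}
```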


@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::stable_lm::{Config, Model};
+use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
+use candle_transformers::models::stable_lm::{Config, Model as StableLM};
 
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,11 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+enum Model {
+    StableLM(StableLM),
+    Quantized(QStableLM),
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -76,7 +82,10 @@ impl TextGeneration {
             let start_pos = tokens.len().saturating_sub(context_size);
             let ctxt = &tokens[start_pos..];
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input, start_pos)?;
+            let logits = match &mut self.model {
+                Model::StableLM(m) => m.forward(&input, start_pos)?,
+                Model::Quantized(m) => m.forward(&input, start_pos)?,
+            };
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
                 logits
@@ -146,7 +155,7 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "stabilityai/stablelm-3b-4e1t")]
+    #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
     model_id: String,
 
     #[arg(long, default_value = "main")]
@@ -213,7 +222,11 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => {
-            vec![repo.get("model.safetensors")?]
+            if args.quantized {
+                vec![repo.get("model-q4k.gguf")?]
+            } else {
+                vec![repo.get("model.safetensors")?]
+            }
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
@@ -221,7 +234,12 @@ fn main() -> Result<()> {
 
     let start = std::time::Instant::now();
     let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
-    let (model, device) = {
+    let (model, device) = if args.quantized {
+        let filename = &filenames[0];
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
+        let model = QStableLM::new(&config, vb)?;
+        (Model::Quantized(model), Device::Cpu)
+    } else {
         let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
@@ -229,8 +247,8 @@ fn main() -> Result<()> {
             DType::F32
         };
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-        let model = Model::new(&config, vb)?;
-        (model, device)
+        let model = StableLM::new(&config, vb)?;
+        (Model::StableLM(model), device)
     };
     println!("loaded the model in {:?}", start.elapsed());