Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 02:16:37 +00:00
Add a quantized version of recurrent-gemma. (#2054)
* Add a quantized version of recurrent-gemma.
* Share the rglru part.
* Get the quantized gemma model to work.
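The hunks below appear to come from the recurrent-gemma example's `main.rs` (the extraction dropped the diff's file header, so the path is inferred from the contents). The change adds a `--quantized` flag that swaps the memory-mapped safetensors weights for a pre-quantized gguf checkpoint; presumably something along the lines of `cargo run --example recurrent-gemma --release -- --quantized --prompt "..."` exercises the new path, though no invocation is shown in the commit itself.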
@@ -7,7 +7,8 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::recurrent_gemma::{Config, Model};
+use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
+use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};
 
 use candle::{DType, Device, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
@@ -16,6 +17,20 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+enum Model {
+    B(BModel),
+    Q(QModel),
+}
+
+impl Model {
+    fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
+        match self {
+            Self::B(m) => m.forward(xs, pos),
+            Self::Q(m) => m.forward(xs, pos),
+        }
+    }
+}
+
 #[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
 enum Which {
     #[value(name = "2b")]
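The new `Model` wrapper is plain enum dispatch: `BModel` and `QModel` expose the same `forward` signature but share no trait, so a two-variant enum with a hand-written `forward` lets the rest of `main` drive either backend through one call site. A minimal standalone sketch of the pattern, with illustrative stand-in types (nothing below is from the commit):

```rust
// Enum dispatch: one public `forward`, two concrete implementations.
struct Dense;
struct Quantized;

impl Dense {
    fn forward(&mut self, x: f32) -> f32 {
        x + 1.0 // placeholder computation
    }
}

impl Quantized {
    fn forward(&mut self, x: f32) -> f32 {
        x + 1.0 // same interface, different backing representation
    }
}

enum Model {
    Dense(Dense),
    Quantized(Quantized),
}

impl Model {
    // Callers never match themselves; the enum picks the implementation.
    fn forward(&mut self, x: f32) -> f32 {
        match self {
            Self::Dense(m) => m.forward(x),
            Self::Quantized(m) => m.forward(x),
        }
    }
}

fn main() {
    let mut model = Model::Quantized(Quantized);
    println!("{}", model.forward(1.0));
}
```

Compared to a `Box<dyn Trait>`, this keeps static dispatch and avoids inventing a trait for exactly two types.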
@@ -195,6 +210,9 @@ struct Args {
     /// The model to use.
     #[arg(long, default_value = "2b")]
     which: Which,
+
+    #[arg(long)]
+    quantized: bool,
 }
 
 fn main() -> Result<()> {
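`#[arg(long)]` on a `bool` field is clap's standard presence flag: absent means `false`, passing `--quantized` means `true`, no value required. A self-contained sketch, assuming clap with the `derive` feature enabled:

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// Load gguf quantized weights instead of safetensors.
    #[arg(long)]
    quantized: bool,
}

fn main() {
    // `--quantized` on the command line flips this to true.
    let args = Args::parse();
    println!("quantized: {}", args.quantized);
}
```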
@@ -250,7 +268,18 @@ fn main() -> Result<()> {
             .split(',')
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+        None => {
+            if args.quantized {
+                let filename = match args.which {
+                    Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
+                    Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
+                };
+                let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
+                vec![filename]
+            } else {
+                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
+            }
+        }
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
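In the quantized case the example fetches a single pre-quantized gguf file from the `lmz/candle-gemma` hub repo rather than resolving the sharded `model.safetensors.index.json`. The same fetch in isolation looks roughly like the sketch below (repo and file names are copied from the diff; wrapping the hf-hub errors in anyhow is an assumption):

```rust
use hf_hub::api::sync::Api;

fn main() -> anyhow::Result<()> {
    // Downloads into the local HF cache, or reuses an existing copy.
    let api = Api::new()?;
    let path = api
        .model("lmz/candle-gemma".to_string())
        .get("recurrent-gemma-2b-q4k.gguf")?;
    println!("gguf weights at {:?}", path);
    Ok(())
}
```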
@@ -263,8 +292,16 @@ fn main() -> Result<()> {
     } else {
         DType::F32
     };
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb.pp("model"))?;
+    let model = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
+            &filenames[0],
+            &device,
+        )?;
+        Model::Q(QModel::new(&config, vb.pp("model"))?)
+    } else {
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        Model::B(BModel::new(&config, vb.pp("model"))?)
+    };
 
     println!("loaded the model in {:?}", start.elapsed());