add quantized rwkv v5 model (#1743)

* add quantized rwkv v5 model

* Integrate the quantized rwkv model into the existing example.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Author: Jack Shih
Date: 2024-02-26 04:43:40 +08:00
Committed by: GitHub
Parent: 1a6043af51
Commit: 918136ba46
4 changed files with 326 additions and 6 deletions

@@ -7,13 +7,28 @@ extern crate accelerate_src;
 use anyhow::Result;
 use clap::{Parser, ValueEnum};
 
-use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
+use candle_transformers::models::quantized_rwkv_v5::Model as Q;
+use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
 
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 
+enum Model {
+    M(M),
+    Q(Q),
+}
+
+impl Model {
+    fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
+        match self {
+            Self::M(m) => m.forward(xs, state),
+            Self::Q(m) => m.forward(xs, state),
+        }
+    }
+}
+
 struct TextGeneration {
     model: Model,
     config: Config,
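
The wrapper enum above lets the rest of the example stay agnostic about which backend was loaded. A minimal, self-contained sketch of the same dispatch pattern (FullModel, QuantModel, and the f32 signature are illustrative stand-ins, not candle APIs):

// Two stand-in backends exposing the same `forward` signature.
struct FullModel;
struct QuantModel;

impl FullModel {
    fn forward(&self, x: f32) -> f32 {
        x * 2.0 // placeholder for the full-precision computation
    }
}

impl QuantModel {
    fn forward(&self, x: f32) -> f32 {
        (x * 2.0).round() // placeholder for the quantized computation
    }
}

// One enum dispatches to whichever backend was built at load time.
enum Model {
    Full(FullModel),
    Quant(QuantModel),
}

impl Model {
    fn forward(&self, x: f32) -> f32 {
        match self {
            Self::Full(m) => m.forward(x),
            Self::Quant(m) => m.forward(x),
        }
    }
}

fn main() {
    for model in [Model::Full(FullModel), Model::Quant(QuantModel)] {
        println!("{}", model.forward(1.4)); // 2.8, then 3
    }
}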
@@ -176,6 +191,9 @@ struct Args {
     #[arg(long)]
     config_file: Option<String>,
 
+    #[arg(long)]
+    quantized: bool,
+
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
     #[arg(long, default_value_t = 1.1)]
     repeat_penalty: f32,
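
The new --quantized switch combines with the example's existing --which model selector (its variants are visible in the next hunk). A hedged sketch of the relevant clap pieces, assuming a derive layout like the rest of the example; the exact attributes in the real file may differ:

use clap::{Parser, ValueEnum};

// Variant names inferred from the `match args.which` below.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    World1b5,
    World3b,
    Eagle7b,
}

#[derive(Parser)]
struct Args {
    /// Which RWKV v5 checkpoint to run.
    #[arg(long, value_enum, default_value = "world1b5")]
    which: Which,

    /// Load the q4k GGUF weights instead of the F32 safetensors.
    #[arg(long)]
    quantized: bool,
}

fn main() {
    let args = Args::parse();
    println!("which={:?} quantized={}", args.which, args.quantized);
}

A quantized run would then look roughly like `cargo run --example rwkv --release -- --which world1b5 --quantized` (the `rwkv` example name is an assumption about the file this diff touches).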
@@ -236,7 +254,16 @@ fn main() -> Result<()> {
             .map(std::path::PathBuf::from)
             .collect::<Vec<_>>(),
         None => {
-            vec![repo.get("model.safetensors")?]
+            if args.quantized {
+                let file = match args.which {
+                    Which::World1b5 => "world1b5-q4k.gguf",
+                    Which::World3b => "world3b-q4k.gguf",
+                    Which::Eagle7b => "eagle7b-q4k.gguf",
+                };
+                vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
+            } else {
+                vec![repo.get("model.safetensors")?]
+            }
         }
     };
     println!("retrieved the files in {:?}", start.elapsed());
@@ -245,8 +272,15 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = if args.quantized {
+        let filename = &filenames[0];
+        let vb =
+            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+        Model::Q(Q::new(&config, vb)?)
+    } else {
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        Model::M(M::new(&config, vb)?)
+    };
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
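
The load path is the other spot that branches: GGUF weights go through candle's quantized VarBuilder while the F32 path keeps the mmap'ed safetensors route. A condensed sketch of the quantized side as a standalone helper (the helper name is hypothetical; the from_gguf and Q::new calls match the diff):

use candle::Device;
use candle_transformers::models::quantized_rwkv_v5::Model as Q;
use candle_transformers::models::rwkv_v5::Config;

// Hypothetical helper: build the quantized RWKV model from a GGUF file.
fn load_quantized(
    filename: &std::path::Path,
    config: &Config,
    device: &Device,
) -> anyhow::Result<Q> {
    // The GGUF tensors stay in their quantized form; no F32 mmap involved.
    let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, device)?;
    Ok(Q::new(config, vb)?)
}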