mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
add quantized rwkv v5 model (#1743)
* and quantized rwkv v5 model * Integrate the quantized rwkv model in the initial example. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -7,13 +7,28 @@ extern crate accelerate_src;
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
|
||||
use candle_transformers::models::quantized_rwkv_v5::Model as Q;
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
enum Model {
|
||||
M(M),
|
||||
Q(Q),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::M(m) => m.forward(xs, state),
|
||||
Self::Q(m) => m.forward(xs, state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
config: Config,
|
||||
@ -176,6 +191,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
@ -236,7 +254,16 @@ fn main() -> Result<()> {
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
if args.quantized {
|
||||
let file = match args.which {
|
||||
Which::World1b5 => "world1b5-q4k.gguf",
|
||||
Which::World3b => "world3b-q4k.gguf",
|
||||
Which::Eagle7b => "eagle7b-q4k.gguf",
|
||||
};
|
||||
vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
|
||||
} else {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
@ -245,8 +272,15 @@ fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
let model = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
Model::Q(Q::new(&config, vb)?)
|
||||
} else {
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
Model::M(M::new(&config, vb)?)
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
|
Reference in New Issue
Block a user