mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
add models of rwkv v6 and quantized rwkv v6 (#1781)
* add models of rwkv v6 and quantized rwkv v6 * fix ci clippy fail
This commit is contained in:
@ -7,8 +7,10 @@ extern crate accelerate_src;
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::quantized_rwkv_v5::Model as Q;
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
|
||||
use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
|
||||
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
|
||||
use candle_transformers::models::rwkv_v6::Model as M6;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
@ -16,15 +18,19 @@ use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
enum Model {
|
||||
M(M),
|
||||
Q(Q),
|
||||
M5(M5),
|
||||
Q5(Q5),
|
||||
M6(M6),
|
||||
Q6(Q6),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::M(m) => m.forward(xs, state),
|
||||
Self::Q(m) => m.forward(xs, state),
|
||||
Self::M5(m) => m.forward(xs, state),
|
||||
Self::Q5(m) => m.forward(xs, state),
|
||||
Self::M6(m) => m.forward(xs, state),
|
||||
Self::Q6(m) => m.forward(xs, state),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -118,6 +124,7 @@ enum Which {
|
||||
Eagle7b,
|
||||
World1b5,
|
||||
World3b,
|
||||
World6_1b6,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
@ -132,6 +139,7 @@ impl Which {
|
||||
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
|
||||
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
|
||||
Self::World3b => "RWKV/rwkv-5-world-3b",
|
||||
Self::World6_1b6 => "paperfun/rwkv",
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,6 +147,7 @@ impl Which {
|
||||
match self {
|
||||
Self::Eagle7b => "refs/pr/1",
|
||||
Self::World1b5 | Self::World3b => "refs/pr/2",
|
||||
Self::World6_1b6 => "main",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -255,14 +264,25 @@ fn main() -> Result<()> {
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
let file = match args.which {
|
||||
Which::World1b5 => "world1b5-q4k.gguf",
|
||||
Which::World3b => "world3b-q4k.gguf",
|
||||
Which::Eagle7b => "eagle7b-q4k.gguf",
|
||||
};
|
||||
vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
|
||||
vec![match args.which {
|
||||
Which::World1b5 => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("world1b5-q4k.gguf")?,
|
||||
Which::World3b => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("world3b-q4k.gguf")?,
|
||||
Which::Eagle7b => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("eagle7b-q4k.gguf")?,
|
||||
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
|
||||
}]
|
||||
} else {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
vec![match args.which {
|
||||
Which::World1b5 | Which::World3b | Which::Eagle7b => {
|
||||
repo.get("model.safetensors")?
|
||||
}
|
||||
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
|
||||
}]
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -276,10 +296,16 @@ fn main() -> Result<()> {
|
||||
let filename = &filenames[0];
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
Model::Q(Q::new(&config, vb)?)
|
||||
match args.which {
|
||||
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
|
||||
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
|
||||
}
|
||||
} else {
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
Model::M(M::new(&config, vb)?)
|
||||
match args.which {
|
||||
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
|
||||
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
|
||||
}
|
||||
};
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
Reference in New Issue
Block a user