add models of rwkv v6 and quantized rwkv v6 (#1781)

* add models of rwkv v6 and quantized rwkv v6

* fix ci clippy fail
This commit is contained in:
Jack Shih
2024-03-01 15:37:56 +08:00
committed by GitHub
parent 2c95b7394a
commit b485e4b6ee
4 changed files with 670 additions and 15 deletions

View File

@ -7,8 +7,10 @@ extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use candle_transformers::models::quantized_rwkv_v5::Model as Q;
use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
@ -16,15 +18,19 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
enum Model {
M(M),
Q(Q),
M5(M5),
Q5(Q5),
M6(M6),
Q6(Q6),
}
impl Model {
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
match self {
Self::M(m) => m.forward(xs, state),
Self::Q(m) => m.forward(xs, state),
Self::M5(m) => m.forward(xs, state),
Self::Q5(m) => m.forward(xs, state),
Self::M6(m) => m.forward(xs, state),
Self::Q6(m) => m.forward(xs, state),
}
}
}
@ -118,6 +124,7 @@ enum Which {
Eagle7b,
World1b5,
World3b,
World6_1b6,
}
impl std::fmt::Display for Which {
@ -132,6 +139,7 @@ impl Which {
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
Self::World3b => "RWKV/rwkv-5-world-3b",
Self::World6_1b6 => "paperfun/rwkv",
}
}
@ -139,6 +147,7 @@ impl Which {
match self {
Self::Eagle7b => "refs/pr/1",
Self::World1b5 | Self::World3b => "refs/pr/2",
Self::World6_1b6 => "main",
}
}
}
@ -255,14 +264,25 @@ fn main() -> Result<()> {
.collect::<Vec<_>>(),
None => {
if args.quantized {
let file = match args.which {
Which::World1b5 => "world1b5-q4k.gguf",
Which::World3b => "world3b-q4k.gguf",
Which::Eagle7b => "eagle7b-q4k.gguf",
};
vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
vec![match args.which {
Which::World1b5 => api
.model("lmz/candle-rwkv".to_string())
.get("world1b5-q4k.gguf")?,
Which::World3b => api
.model("lmz/candle-rwkv".to_string())
.get("world3b-q4k.gguf")?,
Which::Eagle7b => api
.model("lmz/candle-rwkv".to_string())
.get("eagle7b-q4k.gguf")?,
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
}]
} else {
vec![repo.get("model.safetensors")?]
vec![match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => {
repo.get("model.safetensors")?
}
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
}]
}
}
};
@ -276,10 +296,16 @@ fn main() -> Result<()> {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
Model::Q(Q::new(&config, vb)?)
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
}
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
Model::M(M::new(&config, vb)?)
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
}
};
println!("loaded the model in {:?}", start.elapsed());