mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
add models of rwkv v6 and quantized rwkv v6 (#1781)
* add models of rwkv v6 and quantized rwkv v6 * fix ci clippy fail
This commit is contained in:
@ -7,8 +7,10 @@ extern crate accelerate_src;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
use candle_transformers::models::quantized_rwkv_v5::Model as Q;
|
use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
|
||||||
use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
|
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
|
||||||
|
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
|
||||||
|
use candle_transformers::models::rwkv_v6::Model as M6;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
@ -16,15 +18,19 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
M(M),
|
M5(M5),
|
||||||
Q(Q),
|
Q5(Q5),
|
||||||
|
M6(M6),
|
||||||
|
Q6(Q6),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
|
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
|
||||||
match self {
|
match self {
|
||||||
Self::M(m) => m.forward(xs, state),
|
Self::M5(m) => m.forward(xs, state),
|
||||||
Self::Q(m) => m.forward(xs, state),
|
Self::Q5(m) => m.forward(xs, state),
|
||||||
|
Self::M6(m) => m.forward(xs, state),
|
||||||
|
Self::Q6(m) => m.forward(xs, state),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -118,6 +124,7 @@ enum Which {
|
|||||||
Eagle7b,
|
Eagle7b,
|
||||||
World1b5,
|
World1b5,
|
||||||
World3b,
|
World3b,
|
||||||
|
World6_1b6,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for Which {
|
impl std::fmt::Display for Which {
|
||||||
@ -132,6 +139,7 @@ impl Which {
|
|||||||
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
|
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
|
||||||
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
|
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
|
||||||
Self::World3b => "RWKV/rwkv-5-world-3b",
|
Self::World3b => "RWKV/rwkv-5-world-3b",
|
||||||
|
Self::World6_1b6 => "paperfun/rwkv",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,6 +147,7 @@ impl Which {
|
|||||||
match self {
|
match self {
|
||||||
Self::Eagle7b => "refs/pr/1",
|
Self::Eagle7b => "refs/pr/1",
|
||||||
Self::World1b5 | Self::World3b => "refs/pr/2",
|
Self::World1b5 | Self::World3b => "refs/pr/2",
|
||||||
|
Self::World6_1b6 => "main",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -255,14 +264,25 @@ fn main() -> Result<()> {
|
|||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => {
|
None => {
|
||||||
if args.quantized {
|
if args.quantized {
|
||||||
let file = match args.which {
|
vec![match args.which {
|
||||||
Which::World1b5 => "world1b5-q4k.gguf",
|
Which::World1b5 => api
|
||||||
Which::World3b => "world3b-q4k.gguf",
|
.model("lmz/candle-rwkv".to_string())
|
||||||
Which::Eagle7b => "eagle7b-q4k.gguf",
|
.get("world1b5-q4k.gguf")?,
|
||||||
};
|
Which::World3b => api
|
||||||
vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
|
.model("lmz/candle-rwkv".to_string())
|
||||||
|
.get("world3b-q4k.gguf")?,
|
||||||
|
Which::Eagle7b => api
|
||||||
|
.model("lmz/candle-rwkv".to_string())
|
||||||
|
.get("eagle7b-q4k.gguf")?,
|
||||||
|
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
|
||||||
|
}]
|
||||||
} else {
|
} else {
|
||||||
vec![repo.get("model.safetensors")?]
|
vec![match args.which {
|
||||||
|
Which::World1b5 | Which::World3b | Which::Eagle7b => {
|
||||||
|
repo.get("model.safetensors")?
|
||||||
|
}
|
||||||
|
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
|
||||||
|
}]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -276,10 +296,16 @@ fn main() -> Result<()> {
|
|||||||
let filename = &filenames[0];
|
let filename = &filenames[0];
|
||||||
let vb =
|
let vb =
|
||||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||||
Model::Q(Q::new(&config, vb)?)
|
match args.which {
|
||||||
|
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
|
||||||
|
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||||
Model::M(M::new(&config, vb)?)
|
match args.which {
|
||||||
|
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
|
||||||
|
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
@ -32,12 +32,14 @@ pub mod quantized_mistral;
|
|||||||
pub mod quantized_mixformer;
|
pub mod quantized_mixformer;
|
||||||
pub mod quantized_mpt;
|
pub mod quantized_mpt;
|
||||||
pub mod quantized_rwkv_v5;
|
pub mod quantized_rwkv_v5;
|
||||||
|
pub mod quantized_rwkv_v6;
|
||||||
pub mod quantized_stable_lm;
|
pub mod quantized_stable_lm;
|
||||||
pub mod quantized_t5;
|
pub mod quantized_t5;
|
||||||
pub mod qwen2;
|
pub mod qwen2;
|
||||||
pub mod repvgg;
|
pub mod repvgg;
|
||||||
pub mod resnet;
|
pub mod resnet;
|
||||||
pub mod rwkv_v5;
|
pub mod rwkv_v5;
|
||||||
|
pub mod rwkv_v6;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
pub mod stable_diffusion;
|
pub mod stable_diffusion;
|
||||||
pub mod stable_lm;
|
pub mod stable_lm;
|
||||||
|
332
candle-transformers/src/models/quantized_rwkv_v6.rs
Normal file
332
candle-transformers/src/models/quantized_rwkv_v6.rs
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
use crate::{
|
||||||
|
quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
|
||||||
|
quantized_var_builder::VarBuilder,
|
||||||
|
};
|
||||||
|
use candle::{IndexOp, Result, Tensor};
|
||||||
|
use candle_nn::{GroupNorm, LayerNorm, Module};
|
||||||
|
|
||||||
|
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct SelfAttention {
|
||||||
|
key: Linear,
|
||||||
|
receptance: Linear,
|
||||||
|
value: Linear,
|
||||||
|
gate: Linear,
|
||||||
|
output: Linear,
|
||||||
|
ln_x: candle_nn::GroupNorm,
|
||||||
|
time_mix_x: Tensor,
|
||||||
|
time_mix_w: Tensor,
|
||||||
|
time_mix_key: Tensor,
|
||||||
|
time_mix_value: Tensor,
|
||||||
|
time_mix_receptance: Tensor,
|
||||||
|
time_decay: Tensor,
|
||||||
|
time_faaaa: Tensor,
|
||||||
|
time_mix_gate: Tensor,
|
||||||
|
time_decay_w1: Tensor,
|
||||||
|
time_decay_w2: Tensor,
|
||||||
|
time_mix_w1: Tensor,
|
||||||
|
time_mix_w2: Tensor,
|
||||||
|
layer_id: usize,
|
||||||
|
n_attn_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SelfAttention {
|
||||||
|
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_size = cfg.hidden_size;
|
||||||
|
let attn_hidden_size = cfg.attention_hidden_size;
|
||||||
|
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
|
||||||
|
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
|
||||||
|
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
|
||||||
|
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
|
||||||
|
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
|
||||||
|
|
||||||
|
let vb_x = vb.pp("ln_x");
|
||||||
|
let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
|
||||||
|
let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
|
||||||
|
|
||||||
|
let ln_x = GroupNorm::new(
|
||||||
|
ln_x_weight,
|
||||||
|
ln_x_bias,
|
||||||
|
hidden_size,
|
||||||
|
hidden_size / cfg.head_size,
|
||||||
|
1e-5,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let time_mix_x = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_x")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_w = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_w")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_key = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_key")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_value = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_value")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_receptance = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let n_attn_heads = cfg.hidden_size / cfg.head_size;
|
||||||
|
let time_decay = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_decay")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_faaaa = vb
|
||||||
|
.get((n_attn_heads, cfg.head_size), "time_faaaa")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_gate = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_gate")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_decay_w1 = vb
|
||||||
|
.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_decay_w2 = vb
|
||||||
|
.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_w1 = vb
|
||||||
|
.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_w2 = vb
|
||||||
|
.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
Ok(Self {
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
receptance,
|
||||||
|
gate,
|
||||||
|
output,
|
||||||
|
ln_x,
|
||||||
|
time_mix_x,
|
||||||
|
time_mix_w,
|
||||||
|
time_mix_key,
|
||||||
|
time_mix_value,
|
||||||
|
time_mix_receptance,
|
||||||
|
time_decay,
|
||||||
|
time_faaaa,
|
||||||
|
time_mix_gate,
|
||||||
|
time_decay_w1,
|
||||||
|
time_decay_w2,
|
||||||
|
time_mix_w1,
|
||||||
|
time_mix_w2,
|
||||||
|
layer_id,
|
||||||
|
n_attn_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let h = self.n_attn_heads;
|
||||||
|
let (b, t, s) = xs.dims3()?;
|
||||||
|
let s = s / h;
|
||||||
|
let (receptance, key, value, gate, w) = {
|
||||||
|
// extract key-value
|
||||||
|
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
|
||||||
|
let shifted = if shifted.rank() == 2 {
|
||||||
|
shifted.unsqueeze(1)?
|
||||||
|
} else {
|
||||||
|
shifted
|
||||||
|
};
|
||||||
|
|
||||||
|
let sx = (&shifted - xs)?;
|
||||||
|
let xxx = (xs + &sx * &self.time_mix_x)?;
|
||||||
|
let xxx = xxx
|
||||||
|
.broadcast_matmul(&self.time_mix_w1)?
|
||||||
|
.tanh()?
|
||||||
|
.reshape((b * t, 5, ()))?
|
||||||
|
.transpose(0, 1)?;
|
||||||
|
|
||||||
|
let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
|
||||||
|
|
||||||
|
let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
|
||||||
|
|
||||||
|
let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
|
||||||
|
let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
|
||||||
|
let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
|
||||||
|
let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
|
||||||
|
let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
|
||||||
|
|
||||||
|
let w = (&self.time_decay
|
||||||
|
+ xw.broadcast_matmul(&self.time_decay_w1)?
|
||||||
|
.tanh()?
|
||||||
|
.broadcast_matmul(&self.time_decay_w2)?)?
|
||||||
|
.reshape(((), 1, 1))?
|
||||||
|
.reshape((self.n_attn_heads, (), 1))?;
|
||||||
|
|
||||||
|
let key = self.key.forward(&xk)?;
|
||||||
|
let value = self.value.forward(&xv)?;
|
||||||
|
let receptance = self.receptance.forward(&xr)?;
|
||||||
|
let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
|
||||||
|
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
|
||||||
|
(receptance, key, value, gate, w)
|
||||||
|
};
|
||||||
|
|
||||||
|
// linear attention
|
||||||
|
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
|
||||||
|
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
|
||||||
|
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||||
|
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let w = w.exp()?.neg()?.exp()?;
|
||||||
|
|
||||||
|
let time_faaaa =
|
||||||
|
self.time_faaaa
|
||||||
|
.reshape(((), 1, 1))?
|
||||||
|
.reshape((self.n_attn_heads, (), 1))?;
|
||||||
|
|
||||||
|
let mut out: Vec<Tensor> = Vec::with_capacity(t);
|
||||||
|
for t_ in 0..t {
|
||||||
|
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
|
||||||
|
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
|
||||||
|
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
|
||||||
|
let at = kt.matmul(&vt)?;
|
||||||
|
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
|
||||||
|
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
|
||||||
|
state_ = (&at + w.broadcast_mul(&state_))?;
|
||||||
|
out.push(out_)
|
||||||
|
}
|
||||||
|
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
|
||||||
|
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
|
||||||
|
let out = (out * gate)?.apply(&self.output)?;
|
||||||
|
state.per_layer[self.layer_id].linear_attention = state_;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct FeedForward {
|
||||||
|
time_mix_key: Tensor,
|
||||||
|
time_mix_receptance: Tensor,
|
||||||
|
key: Linear,
|
||||||
|
receptance: Linear,
|
||||||
|
value: Linear,
|
||||||
|
layer_id: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FeedForward {
|
||||||
|
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let int_size = cfg
|
||||||
|
.intermediate_size
|
||||||
|
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
|
||||||
|
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
|
||||||
|
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
|
||||||
|
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
|
||||||
|
let time_mix_key = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_key")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
let time_mix_receptance = vb
|
||||||
|
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
|
||||||
|
.dequantize(vb.device())?;
|
||||||
|
Ok(Self {
|
||||||
|
key,
|
||||||
|
receptance,
|
||||||
|
value,
|
||||||
|
time_mix_key,
|
||||||
|
time_mix_receptance,
|
||||||
|
layer_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let shifted = state.per_layer[self.layer_id]
|
||||||
|
.feed_forward
|
||||||
|
.broadcast_sub(xs)?;
|
||||||
|
let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
|
||||||
|
let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
|
||||||
|
let key = key.apply(&self.key)?.relu()?.sqr()?;
|
||||||
|
let value = key.apply(&self.value)?;
|
||||||
|
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
|
||||||
|
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
|
||||||
|
let xs = (receptance * value)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Block {
|
||||||
|
pre_ln: Option<LayerNorm>,
|
||||||
|
ln1: LayerNorm,
|
||||||
|
ln2: LayerNorm,
|
||||||
|
attention: SelfAttention,
|
||||||
|
feed_forward: FeedForward,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block {
|
||||||
|
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
|
||||||
|
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
|
||||||
|
let pre_ln = if layer_id == 0 {
|
||||||
|
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
|
||||||
|
Some(ln)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
|
||||||
|
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
|
||||||
|
Ok(Self {
|
||||||
|
pre_ln,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
attention,
|
||||||
|
feed_forward,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let xs = match self.pre_ln.as_ref() {
|
||||||
|
None => xs.clone(),
|
||||||
|
Some(pre_ln) => xs.apply(pre_ln)?,
|
||||||
|
};
|
||||||
|
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
|
||||||
|
let xs = (xs + attention)?;
|
||||||
|
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
|
||||||
|
let xs = (xs + feed_forward)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embeddings: Embedding,
|
||||||
|
blocks: Vec<Block>,
|
||||||
|
ln_out: LayerNorm,
|
||||||
|
head: Linear,
|
||||||
|
rescale_every: usize,
|
||||||
|
layers_are_rescaled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("rwkv");
|
||||||
|
let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_b = vb_m.pp("blocks");
|
||||||
|
for block_index in 0..cfg.num_hidden_layers {
|
||||||
|
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
|
||||||
|
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
blocks,
|
||||||
|
ln_out,
|
||||||
|
head,
|
||||||
|
rescale_every: cfg.rescale_every,
|
||||||
|
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let (_b_size, _seq_len) = xs.dims2()?;
|
||||||
|
let mut xs = xs.apply(&self.embeddings)?;
|
||||||
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
|
xs = block.forward(&xs, state)?;
|
||||||
|
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
|
||||||
|
xs = (xs / 2.)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
|
||||||
|
state.pos += 1;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
295
candle-transformers/src/models/rwkv_v6.rs
Normal file
295
candle-transformers/src/models/rwkv_v6.rs
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
|
||||||
|
use candle::{IndexOp, Result, Tensor};
|
||||||
|
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||||
|
|
||||||
|
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct SelfAttention {
|
||||||
|
key: Linear,
|
||||||
|
receptance: Linear,
|
||||||
|
value: Linear,
|
||||||
|
gate: Linear,
|
||||||
|
output: Linear,
|
||||||
|
ln_x: candle_nn::GroupNorm,
|
||||||
|
time_mix_x: Tensor,
|
||||||
|
time_mix_w: Tensor,
|
||||||
|
time_mix_key: Tensor,
|
||||||
|
time_mix_value: Tensor,
|
||||||
|
time_mix_receptance: Tensor,
|
||||||
|
time_decay: Tensor,
|
||||||
|
time_faaaa: Tensor,
|
||||||
|
time_mix_gate: Tensor,
|
||||||
|
time_decay_w1: Tensor,
|
||||||
|
time_decay_w2: Tensor,
|
||||||
|
time_mix_w1: Tensor,
|
||||||
|
time_mix_w2: Tensor,
|
||||||
|
layer_id: usize,
|
||||||
|
n_attn_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SelfAttention {
|
||||||
|
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_size = cfg.hidden_size;
|
||||||
|
let attn_hidden_size = cfg.attention_hidden_size;
|
||||||
|
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
|
||||||
|
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
|
||||||
|
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
|
||||||
|
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
|
||||||
|
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
|
||||||
|
let ln_x = candle_nn::group_norm(
|
||||||
|
hidden_size / cfg.head_size,
|
||||||
|
hidden_size,
|
||||||
|
1e-5,
|
||||||
|
vb.pp("ln_x"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
|
||||||
|
let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
|
||||||
|
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
|
||||||
|
let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
|
||||||
|
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
|
||||||
|
let n_attn_heads = cfg.hidden_size / cfg.head_size;
|
||||||
|
let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
|
||||||
|
let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
|
||||||
|
let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
|
||||||
|
let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
|
||||||
|
let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
|
||||||
|
let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
|
||||||
|
let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
|
||||||
|
Ok(Self {
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
receptance,
|
||||||
|
gate,
|
||||||
|
output,
|
||||||
|
ln_x,
|
||||||
|
time_mix_x,
|
||||||
|
time_mix_w,
|
||||||
|
time_mix_key,
|
||||||
|
time_mix_value,
|
||||||
|
time_mix_receptance,
|
||||||
|
time_decay,
|
||||||
|
time_faaaa,
|
||||||
|
time_mix_gate,
|
||||||
|
time_decay_w1,
|
||||||
|
time_decay_w2,
|
||||||
|
time_mix_w1,
|
||||||
|
time_mix_w2,
|
||||||
|
layer_id,
|
||||||
|
n_attn_heads,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let h = self.n_attn_heads;
|
||||||
|
let (b, t, s) = xs.dims3()?;
|
||||||
|
let s = s / h;
|
||||||
|
let (receptance, key, value, gate, w) = {
|
||||||
|
// extract key-value
|
||||||
|
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
|
||||||
|
let shifted = if shifted.rank() == 2 {
|
||||||
|
shifted.unsqueeze(1)?
|
||||||
|
} else {
|
||||||
|
shifted
|
||||||
|
};
|
||||||
|
|
||||||
|
let sx = (&shifted - xs)?;
|
||||||
|
let xxx = (xs + &sx * &self.time_mix_x)?;
|
||||||
|
let xxx = xxx
|
||||||
|
.broadcast_matmul(&self.time_mix_w1)?
|
||||||
|
.tanh()?
|
||||||
|
.reshape((b * t, 5, ()))?
|
||||||
|
.transpose(0, 1)?;
|
||||||
|
|
||||||
|
let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
|
||||||
|
|
||||||
|
let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
|
||||||
|
|
||||||
|
let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
|
||||||
|
let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
|
||||||
|
let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
|
||||||
|
let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
|
||||||
|
let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
|
||||||
|
|
||||||
|
let w = (&self.time_decay
|
||||||
|
+ xw.broadcast_matmul(&self.time_decay_w1)?
|
||||||
|
.tanh()?
|
||||||
|
.broadcast_matmul(&self.time_decay_w2)?)?
|
||||||
|
.reshape(((), 1, 1))?
|
||||||
|
.reshape((self.n_attn_heads, (), 1))?;
|
||||||
|
|
||||||
|
let key = self.key.forward(&xk)?;
|
||||||
|
let value = self.value.forward(&xv)?;
|
||||||
|
let receptance = self.receptance.forward(&xr)?;
|
||||||
|
let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
|
||||||
|
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
|
||||||
|
(receptance, key, value, gate, w)
|
||||||
|
};
|
||||||
|
|
||||||
|
// linear attention
|
||||||
|
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
|
||||||
|
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
|
||||||
|
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||||
|
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let w = w.exp()?.neg()?.exp()?;
|
||||||
|
|
||||||
|
let time_faaaa =
|
||||||
|
self.time_faaaa
|
||||||
|
.reshape(((), 1, 1))?
|
||||||
|
.reshape((self.n_attn_heads, (), 1))?;
|
||||||
|
|
||||||
|
let mut out: Vec<Tensor> = Vec::with_capacity(t);
|
||||||
|
for t_ in 0..t {
|
||||||
|
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
|
||||||
|
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
|
||||||
|
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
|
||||||
|
let at = kt.matmul(&vt)?;
|
||||||
|
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
|
||||||
|
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
|
||||||
|
state_ = (&at + w.broadcast_mul(&state_))?;
|
||||||
|
out.push(out_)
|
||||||
|
}
|
||||||
|
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
|
||||||
|
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
|
||||||
|
let out = (out * gate)?.apply(&self.output)?;
|
||||||
|
state.per_layer[self.layer_id].linear_attention = state_;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct FeedForward {
|
||||||
|
time_mix_key: Tensor,
|
||||||
|
time_mix_receptance: Tensor,
|
||||||
|
key: Linear,
|
||||||
|
receptance: Linear,
|
||||||
|
value: Linear,
|
||||||
|
layer_id: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl FeedForward {
|
||||||
|
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let int_size = cfg
|
||||||
|
.intermediate_size
|
||||||
|
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
|
||||||
|
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
|
||||||
|
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
|
||||||
|
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
|
||||||
|
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
|
||||||
|
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
|
||||||
|
Ok(Self {
|
||||||
|
key,
|
||||||
|
receptance,
|
||||||
|
value,
|
||||||
|
time_mix_key,
|
||||||
|
time_mix_receptance,
|
||||||
|
layer_id,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let shifted = state.per_layer[self.layer_id]
|
||||||
|
.feed_forward
|
||||||
|
.broadcast_sub(xs)?;
|
||||||
|
let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
|
||||||
|
let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
|
||||||
|
let key = key.apply(&self.key)?.relu()?.sqr()?;
|
||||||
|
let value = key.apply(&self.value)?;
|
||||||
|
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
|
||||||
|
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
|
||||||
|
let xs = (receptance * value)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Block {
|
||||||
|
pre_ln: Option<LayerNorm>,
|
||||||
|
ln1: LayerNorm,
|
||||||
|
ln2: LayerNorm,
|
||||||
|
attention: SelfAttention,
|
||||||
|
feed_forward: FeedForward,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block {
|
||||||
|
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
|
||||||
|
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
|
||||||
|
let pre_ln = if layer_id == 0 {
|
||||||
|
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
|
||||||
|
Some(ln)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
|
||||||
|
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
|
||||||
|
Ok(Self {
|
||||||
|
pre_ln,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
attention,
|
||||||
|
feed_forward,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let xs = match self.pre_ln.as_ref() {
|
||||||
|
None => xs.clone(),
|
||||||
|
Some(pre_ln) => xs.apply(pre_ln)?,
|
||||||
|
};
|
||||||
|
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
|
||||||
|
let xs = (xs + attention)?;
|
||||||
|
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
|
||||||
|
let xs = (xs + feed_forward)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embeddings: Embedding,
|
||||||
|
blocks: Vec<Block>,
|
||||||
|
ln_out: LayerNorm,
|
||||||
|
head: Linear,
|
||||||
|
rescale_every: usize,
|
||||||
|
layers_are_rescaled: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("rwkv");
|
||||||
|
let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
|
||||||
|
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_b = vb_m.pp("blocks");
|
||||||
|
for block_index in 0..cfg.num_hidden_layers {
|
||||||
|
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
|
||||||
|
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
|
||||||
|
Ok(Self {
|
||||||
|
embeddings,
|
||||||
|
blocks,
|
||||||
|
ln_out,
|
||||||
|
head,
|
||||||
|
rescale_every: cfg.rescale_every,
|
||||||
|
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||||
|
let (_b_size, _seq_len) = xs.dims2()?;
|
||||||
|
let mut xs = xs.apply(&self.embeddings)?;
|
||||||
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
|
xs = block.forward(&xs, state)?;
|
||||||
|
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
|
||||||
|
xs = (xs / 2.)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
|
||||||
|
state.pos += 1;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user