diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs
index 771baa03..a2717170 100644
--- a/candle-examples/examples/rwkv/main.rs
+++ b/candle-examples/examples/rwkv/main.rs
@@ -7,8 +7,10 @@ extern crate accelerate_src;
 use anyhow::Result;
 use clap::{Parser, ValueEnum};
 
-use candle_transformers::models::quantized_rwkv_v5::Model as Q;
-use candle_transformers::models::rwkv_v5::{Config, Model as M, State, Tokenizer};
+use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
+use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
+use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
+use candle_transformers::models::rwkv_v6::Model as M6;
 
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
@@ -16,15 +18,19 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 
 enum Model {
-    M(M),
-    Q(Q),
+    M5(M5),
+    Q5(Q5),
+    M6(M6),
+    Q6(Q6),
 }
 
 impl Model {
     fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
         match self {
-            Self::M(m) => m.forward(xs, state),
-            Self::Q(m) => m.forward(xs, state),
+            Self::M5(m) => m.forward(xs, state),
+            Self::Q5(m) => m.forward(xs, state),
+            Self::M6(m) => m.forward(xs, state),
+            Self::Q6(m) => m.forward(xs, state),
         }
     }
 }
@@ -118,6 +124,7 @@ enum Which {
     Eagle7b,
     World1b5,
     World3b,
+    World6_1b6,
 }
 
 impl std::fmt::Display for Which {
@@ -132,6 +139,7 @@ impl Which {
             Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
             Self::World1b5 => "RWKV/rwkv-5-world-1b5",
             Self::World3b => "RWKV/rwkv-5-world-3b",
+            Self::World6_1b6 => "paperfun/rwkv",
         }
     }
 
@@ -139,6 +147,7 @@ impl Which {
         match self {
             Self::Eagle7b => "refs/pr/1",
             Self::World1b5 | Self::World3b => "refs/pr/2",
+            Self::World6_1b6 => "main",
         }
     }
 }
@@ -255,14 +264,25 @@ fn main() -> Result<()> {
             .collect::<Vec<_>>(),
         None => {
             if args.quantized {
-                let file = match args.which {
-                    Which::World1b5 => "world1b5-q4k.gguf",
-                    Which::World3b => "world3b-q4k.gguf",
-                    Which::Eagle7b => "eagle7b-q4k.gguf",
-                };
-                vec![api.model("lmz/candle-rwkv".to_string()).get(file)?]
+                vec![match args.which {
+                    Which::World1b5 => api
+                        .model("lmz/candle-rwkv".to_string())
+                        .get("world1b5-q4k.gguf")?,
+                    Which::World3b => api
+                        .model("lmz/candle-rwkv".to_string())
+                        .get("world3b-q4k.gguf")?,
+                    Which::Eagle7b => api
+                        .model("lmz/candle-rwkv".to_string())
+                        .get("eagle7b-q4k.gguf")?,
+                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
+                }]
             } else {
-                vec![repo.get("model.safetensors")?]
+                vec![match args.which {
+                    Which::World1b5 | Which::World3b | Which::Eagle7b => {
+                        repo.get("model.safetensors")?
+                    }
+                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
+                }]
             }
         }
     };
@@ -276,10 +296,16 @@ fn main() -> Result<()> {
         let filename = &filenames[0];
         let vb =
             candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
-        Model::Q(Q::new(&config, vb)?)
+        match args.which {
+            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
+            Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
+        }
     } else {
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
-        Model::M(M::new(&config, vb)?)
+        match args.which {
+            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
+            Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
+        }
     };
 
     println!("loaded the model in {:?}", start.elapsed());
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 22527daa..6615c0ac 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -32,12 +32,14 @@ pub mod quantized_mistral;
 pub mod quantized_mixformer;
 pub mod quantized_mpt;
 pub mod quantized_rwkv_v5;
+pub mod quantized_rwkv_v6;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
 pub mod qwen2;
 pub mod repvgg;
 pub mod resnet;
 pub mod rwkv_v5;
+pub mod rwkv_v6;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
diff --git a/candle-transformers/src/models/quantized_rwkv_v6.rs b/candle-transformers/src/models/quantized_rwkv_v6.rs
new file mode 100644
index 00000000..81150c3e
--- /dev/null
+++ b/candle-transformers/src/models/quantized_rwkv_v6.rs
@@ -0,0 +1,332 @@
+use crate::{
+    quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
+    quantized_var_builder::VarBuilder,
+};
+use candle::{IndexOp, Result, Tensor};
+use candle_nn::{GroupNorm, LayerNorm, Module};
+
+pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    gate: Linear,
+    output: Linear,
+    ln_x: candle_nn::GroupNorm,
+    time_mix_x: Tensor,
+    time_mix_w: Tensor,
+    time_mix_key: Tensor,
+    time_mix_value: Tensor,
+    time_mix_receptance: Tensor,
+    time_decay: Tensor,
+    time_faaaa: Tensor,
+    time_mix_gate: Tensor,
+    time_decay_w1: Tensor,
+    time_decay_w2: Tensor,
+    time_mix_w1: Tensor,
+    time_mix_w2: Tensor,
+    layer_id: usize,
+    n_attn_heads: usize,
+}
+
+impl SelfAttention {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size;
+        let attn_hidden_size = cfg.attention_hidden_size;
+        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+
+        let vb_x = vb.pp("ln_x");
+        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
+        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
+
+        let ln_x = GroupNorm::new(
+            ln_x_weight,
+            ln_x_bias,
+            hidden_size,
+            hidden_size / cfg.head_size,
+            1e-5,
+        )?;
+
+        let time_mix_x = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_x")?
+            .dequantize(vb.device())?;
+        let time_mix_w = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_w")?
+            .dequantize(vb.device())?;
+        let time_mix_key = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_key")?
+            .dequantize(vb.device())?;
+        let time_mix_value = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_value")?
+            .dequantize(vb.device())?;
+        let time_mix_receptance = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
+            .dequantize(vb.device())?;
+        let n_attn_heads = cfg.hidden_size / cfg.head_size;
+        let time_decay = vb
+            .get((1, 1, cfg.hidden_size), "time_decay")?
+            .dequantize(vb.device())?;
+        let time_faaaa = vb
+            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
+            .dequantize(vb.device())?;
+        let time_mix_gate = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
+            .dequantize(vb.device())?;
+        let time_decay_w1 = vb
+            .get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
+            .dequantize(vb.device())?;
+        let time_decay_w2 = vb
+            .get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
+            .dequantize(vb.device())?;
+        let time_mix_w1 = vb
+            .get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
+            .dequantize(vb.device())?;
+        let time_mix_w2 = vb
+            .get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
+            .dequantize(vb.device())?;
+        Ok(Self {
+            key,
+            value,
+            receptance,
+            gate,
+            output,
+            ln_x,
+            time_mix_x,
+            time_mix_w,
+            time_mix_key,
+            time_mix_value,
+            time_mix_receptance,
+            time_decay,
+            time_faaaa,
+            time_mix_gate,
+            time_decay_w1,
+            time_decay_w2,
+            time_mix_w1,
+            time_mix_w2,
+            layer_id,
+            n_attn_heads,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let h = self.n_attn_heads;
+        let (b, t, s) = xs.dims3()?;
+        let s = s / h;
+        let (receptance, key, value, gate, w) = {
+            // extract key-value
+            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+            let shifted = if shifted.rank() == 2 {
+                shifted.unsqueeze(1)?
+            } else {
+                shifted
+            };
+
+            let sx = (&shifted - xs)?;
+            let xxx = (xs + &sx * &self.time_mix_x)?;
+            let xxx = xxx
+                .broadcast_matmul(&self.time_mix_w1)?
+                .tanh()?
+                .reshape((b * t, 5, ()))?
+                .transpose(0, 1)?;
+
+            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
+
+            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
+
+            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
+            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
+            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
+            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
+            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
+
+            let w = (&self.time_decay
+                + xw.broadcast_matmul(&self.time_decay_w1)?
+                    .tanh()?
+                    .broadcast_matmul(&self.time_decay_w2)?)?
+            .reshape(((), 1, 1))?
+            .reshape((self.n_attn_heads, (), 1))?;
+
+            let key = self.key.forward(&xk)?;
+            let value = self.value.forward(&xv)?;
+            let receptance = self.receptance.forward(&xr)?;
+            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
+            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+            (receptance, key, value, gate, w)
+        };
+
+        // linear attention
+        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
+        let w = w.exp()?.neg()?.exp()?;
+
+        let time_faaaa =
+            self.time_faaaa
+                .reshape(((), 1, 1))?
+                .reshape((self.n_attn_heads, (), 1))?;
+
+        let mut out: Vec<Tensor> = Vec::with_capacity(t);
+        for t_ in 0..t {
+            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
+            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let at = kt.matmul(&vt)?;
+            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+            state_ = (&at + w.broadcast_mul(&state_))?;
+            out.push(out_)
+        }
+        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+        let out = (out * gate)?.apply(&self.output)?;
+        state.per_layer[self.layer_id].linear_attention = state_;
+        Ok(out)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct FeedForward {
+    time_mix_key: Tensor,
+    time_mix_receptance: Tensor,
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    layer_id: usize,
+}
+
+impl FeedForward {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let int_size = cfg
+            .intermediate_size
+            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+        let time_mix_key = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_key")?
+            .dequantize(vb.device())?;
+        let time_mix_receptance = vb
+            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
+            .dequantize(vb.device())?;
+        Ok(Self {
+            key,
+            receptance,
+            value,
+            time_mix_key,
+            time_mix_receptance,
+            layer_id,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let shifted = state.per_layer[self.layer_id]
+            .feed_forward
+            .broadcast_sub(xs)?;
+        let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
+        let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
+        let key = key.apply(&self.key)?.relu()?.sqr()?;
+        let value = key.apply(&self.value)?;
+        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+        let xs = (receptance * value)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    pre_ln: Option<LayerNorm>,
+    ln1: LayerNorm,
+    ln2: LayerNorm,
+    attention: SelfAttention,
+    feed_forward: FeedForward,
+}
+
+impl Block {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
+        let pre_ln = if layer_id == 0 {
+            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+            Some(ln)
+        } else {
+            None
+        };
+        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+        Ok(Self {
+            pre_ln,
+            ln1,
+            ln2,
+            attention,
+            feed_forward,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let xs = match self.pre_ln.as_ref() {
+            None => xs.clone(),
+            Some(pre_ln) => xs.apply(pre_ln)?,
+        };
+        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+        let xs = (xs + attention)?;
+        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+        let xs = (xs + feed_forward)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embeddings: Embedding,
+    blocks: Vec<Block>,
+    ln_out: LayerNorm,
+    head: Linear,
+    rescale_every: usize,
+    layers_are_rescaled: bool,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("rwkv");
+        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_b = vb_m.pp("blocks");
+        for block_index in 0..cfg.num_hidden_layers {
+            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+            blocks.push(block)
+        }
+        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            blocks,
+            ln_out,
+            head,
+            rescale_every: cfg.rescale_every,
+            layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let (_b_size, _seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.embeddings)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            xs = block.forward(&xs, state)?;
+            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+                xs = (xs / 2.)?
+            }
+        }
+        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+        state.pos += 1;
+        Ok(xs)
+    }
+}
diff --git a/candle-transformers/src/models/rwkv_v6.rs b/candle-transformers/src/models/rwkv_v6.rs
new file mode 100644
index 00000000..457c351e
--- /dev/null
+++ b/candle-transformers/src/models/rwkv_v6.rs
@@ -0,0 +1,295 @@
+use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
+use candle::{IndexOp, Result, Tensor};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
+
+pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    gate: Linear,
+    output: Linear,
+    ln_x: candle_nn::GroupNorm,
+    time_mix_x: Tensor,
+    time_mix_w: Tensor,
+    time_mix_key: Tensor,
+    time_mix_value: Tensor,
+    time_mix_receptance: Tensor,
+    time_decay: Tensor,
+    time_faaaa: Tensor,
+    time_mix_gate: Tensor,
+    time_decay_w1: Tensor,
+    time_decay_w2: Tensor,
+    time_mix_w1: Tensor,
+    time_mix_w2: Tensor,
+    layer_id: usize,
+    n_attn_heads: usize,
+}
+
+impl SelfAttention {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size;
+        let attn_hidden_size = cfg.attention_hidden_size;
+        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+        let ln_x = candle_nn::group_norm(
+            hidden_size / cfg.head_size,
+            hidden_size,
+            1e-5,
+            vb.pp("ln_x"),
+        )?;
+
+        let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
+        let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
+        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
+        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+        let n_attn_heads = cfg.hidden_size / cfg.head_size;
+        let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
+        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
+        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
+        let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
+        let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
+        let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
+        let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
+        Ok(Self {
+            key,
+            value,
+            receptance,
+            gate,
+            output,
+            ln_x,
+            time_mix_x,
+            time_mix_w,
+            time_mix_key,
+            time_mix_value,
+            time_mix_receptance,
+            time_decay,
+            time_faaaa,
+            time_mix_gate,
+            time_decay_w1,
+            time_decay_w2,
+            time_mix_w1,
+            time_mix_w2,
+            layer_id,
+            n_attn_heads,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let h = self.n_attn_heads;
+        let (b, t, s) = xs.dims3()?;
+        let s = s / h;
+        let (receptance, key, value, gate, w) = {
+            // extract key-value
+            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+            let shifted = if shifted.rank() == 2 {
+                shifted.unsqueeze(1)?
+            } else {
+                shifted
+            };
+
+            let sx = (&shifted - xs)?;
+            let xxx = (xs + &sx * &self.time_mix_x)?;
+            let xxx = xxx
+                .broadcast_matmul(&self.time_mix_w1)?
+                .tanh()?
+                .reshape((b * t, 5, ()))?
+                .transpose(0, 1)?;
+
+            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
+
+            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
+
+            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
+            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
+            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
+            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
+            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
+
+            let w = (&self.time_decay
+                + xw.broadcast_matmul(&self.time_decay_w1)?
+                    .tanh()?
+                    .broadcast_matmul(&self.time_decay_w2)?)?
+            .reshape(((), 1, 1))?
+            .reshape((self.n_attn_heads, (), 1))?;
+
+            let key = self.key.forward(&xk)?;
+            let value = self.value.forward(&xv)?;
+            let receptance = self.receptance.forward(&xr)?;
+            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
+            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+            (receptance, key, value, gate, w)
+        };
+
+        // linear attention
+        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
+        let w = w.exp()?.neg()?.exp()?;
+
+        let time_faaaa =
+            self.time_faaaa
+                .reshape(((), 1, 1))?
+                .reshape((self.n_attn_heads, (), 1))?;
+
+        let mut out: Vec<Tensor> = Vec::with_capacity(t);
+        for t_ in 0..t {
+            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
+            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
+            let at = kt.matmul(&vt)?;
+            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+            state_ = (&at + w.broadcast_mul(&state_))?;
+            out.push(out_)
+        }
+        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+        let out = (out * gate)?.apply(&self.output)?;
+        state.per_layer[self.layer_id].linear_attention = state_;
+        Ok(out)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct FeedForward {
+    time_mix_key: Tensor,
+    time_mix_receptance: Tensor,
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    layer_id: usize,
+}
+
+impl FeedForward {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let int_size = cfg
+            .intermediate_size
+            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+        Ok(Self {
+            key,
+            receptance,
+            value,
+            time_mix_key,
+            time_mix_receptance,
+            layer_id,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let shifted = state.per_layer[self.layer_id]
+            .feed_forward
+            .broadcast_sub(xs)?;
+        let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
+        let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
+        let key = key.apply(&self.key)?.relu()?.sqr()?;
+        let value = key.apply(&self.value)?;
+        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+        let xs = (receptance * value)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    pre_ln: Option<LayerNorm>,
+    ln1: LayerNorm,
+    ln2: LayerNorm,
+    attention: SelfAttention,
+    feed_forward: FeedForward,
+}
+
+impl Block {
+    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
+        let pre_ln = if layer_id == 0 {
+            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+            Some(ln)
+        } else {
+            None
+        };
+        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+        Ok(Self {
+            pre_ln,
+            ln1,
+            ln2,
+            attention,
+            feed_forward,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let xs = match self.pre_ln.as_ref() {
+            None => xs.clone(),
+            Some(pre_ln) => xs.apply(pre_ln)?,
+        };
+        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+        let xs = (xs + attention)?;
+        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+        let xs = (xs + feed_forward)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embeddings: Embedding,
+    blocks: Vec<Block>,
+    ln_out: LayerNorm,
+    head: Linear,
+    rescale_every: usize,
+    layers_are_rescaled: bool,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("rwkv");
+        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_b = vb_m.pp("blocks");
+        for block_index in 0..cfg.num_hidden_layers {
+            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+            blocks.push(block)
+        }
+        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            blocks,
+            ln_out,
+            head,
+            rescale_every: cfg.rescale_every,
+            layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let (_b_size, _seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.embeddings)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            xs = block.forward(&xs, state)?;
+            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+                xs = (xs / 2.)?
+            }
+        }
+        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+        state.pos += 1;
+        Ok(xs)
+    }
+}
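Example invocation exercising the new v6 code path (a sketch, not part of the patch: the `--which` value assumes clap's default kebab-case rename of the `World6_1b6` variant, and the weight/GGUF filenames are the ones wired up in main.rs above):

    cargo run --example rwkv --release -- --which world6-1b6 --prompt "The smallest prime is "

    cargo run --example rwkv --release -- --which world6-1b6 --quantized --prompt "The smallest prime is "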