diff --git a/README.md b/README.md
index 9bfa30d8..5c65ef68 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,9 @@ We also provide a some command line based examples using state of the art models
   experts 8x7b general LLM with better performance than a Llama 2 70B model with
   much faster inference.
 - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
+- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
+- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
+  performance.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
 - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual (English/Chinese) general LLMs
   with 6b and 34b parameters.
@@ -193,6 +196,8 @@ If you have an addition to this list, please submit a pull request.
     - Replit-code-v1.5-3B.
     - Bert.
     - Yi-6B and Yi-34B.
+    - Qwen1.5.
+    - RWKV.
   - Quantized LLMs.
     - Llama 7b, 13b, 70b, as well as the chat and code variants.
     - Mistral 7b, and 7b instruct.
@@ -210,7 +215,8 @@ If you have an addition to this list, please submit a pull request.
     - BLIP.
     - TrOCR.
   - Computer Vision Models.
-    - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
+    - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
+      ConvNeXTv2.
     - yolo-v3, yolo-v8.
     - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
diff --git a/candle-examples/examples/rwkv/main.rs b/candle-examples/examples/rwkv/main.rs
new file mode 100644
index 00000000..7fd1f76c
--- /dev/null
+++ b/candle-examples/examples/rwkv/main.rs
@@ -0,0 +1,290 @@
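+//! RWKV v5 text-generation example.
+//!
+//! The models and command line flags are defined in `Args` below; a typical
+//! invocation would look something like:
+//! `cargo run --example rwkv --release -- --which world1b5 --prompt "The smallest prime is "`
+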
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{Error as E, Result};
+use clap::{Parser, ValueEnum};
+
+use candle_transformers::models::rwkv_v5::{Config, Model, State};
+
+use candle::{DType, Device, Tensor};
+use candle_examples::token_output_stream::TokenOutputStream;
+use candle_nn::VarBuilder;
+use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+use tokenizers::Tokenizer;
+
+struct TextGeneration {
+    model: Model,
+    config: Config,
+    device: Device,
+    tokenizer: TokenOutputStream,
+    logits_processor: LogitsProcessor,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
+impl TextGeneration {
+    #[allow(clippy::too_many_arguments)]
+    fn new(
+        model: Model,
+        config: Config,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temp: Option<f64>,
+        top_p: Option<f64>,
+        repeat_penalty: f32,
+        repeat_last_n: usize,
+        device: &Device,
+    ) -> Self {
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        Self {
+            model,
+            config,
+            tokenizer: TokenOutputStream::new(tokenizer),
+            logits_processor,
+            repeat_penalty,
+            repeat_last_n,
+            device: device.clone(),
+        }
+    }
+
+    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
+        use std::io::Write;
+        self.tokenizer.clear();
+        let mut tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+        let mut generated_tokens = 0usize;
+        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
+            Some(token) => token,
+            None => anyhow::bail!("cannot find the token"),
+        };
+        let mut state = State::new(1, &self.config, &self.device)?;
+        let mut next_logits = None;
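+        // RWKV is a recurrent model: the prompt is fed one token at a time to
+        // build up the recurrent `State`, keeping only the logits of the last
+        // prompt token to seed the sampling loop below.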
+        for &t in tokens.iter() {
+            let input = Tensor::new(&[[t]], &self.device)?;
+            let logits = self.model.forward(&input, &mut state)?;
+            next_logits = Some(logits);
+            if let Some(t) = self.tokenizer.next_token(t)? {
+                print!("{t}")
+            }
+        }
+        std::io::stdout().flush()?;
+
+        let start_gen = std::time::Instant::now();
+        for _ in 0..sample_len {
+            let logits = match next_logits.as_ref() {
+                Some(logits) => logits,
+                None => anyhow::bail!("cannot work on an empty prompt"),
+            };
+            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
+            let logits = if self.repeat_penalty == 1. {
+                logits
+            } else {
+                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
+                candle_transformers::utils::apply_repeat_penalty(
+                    &logits,
+                    self.repeat_penalty,
+                    &tokens[start_at..],
+                )?
+            };
+            let next_token = self.logits_processor.sample(&logits)?;
+            tokens.push(next_token);
+            generated_tokens += 1;
+            if next_token == eos_token {
+                break;
+            }
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+
+            let input = Tensor::new(&[[next_token]], &self.device)?;
+            next_logits = Some(self.model.forward(&input, &mut state)?)
+        }
+        let dt = start_gen.elapsed();
+        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
+            print!("{rest}");
+        }
+        std::io::stdout().flush()?;
+        println!(
+            "\n{generated_tokens} tokens generated ({:.2} token/s)",
+            generated_tokens as f64 / dt.as_secs_f64(),
+        );
+        Ok(())
+    }
+}
+
+#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
+enum Which {
+    Eagle7b,
+    World1b5,
+    World3b,
+}
+
+impl std::fmt::Display for Which {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl Which {
+    fn model_id(&self) -> &'static str {
+        match self {
+            Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
+            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
+            Self::World3b => "RWKV/rwkv-5-world-3b",
+        }
+    }
+
+    fn revision(&self) -> &'static str {
+        match self {
+            Self::Eagle7b => "refs/pr/1",
+            Self::World1b5 | Self::World3b => "refs/pr/2",
+        }
+    }
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 5000)]
+    sample_len: usize,
+
+    #[arg(long, default_value = "world1b5")]
+    which: Which,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    weight_files: Option<String>,
+
+    #[arg(long)]
+    config_file: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature.unwrap_or(0.),
+        args.repeat_penalty,
+        args.repeat_last_n
+    );
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id
+            .unwrap_or_else(|| args.which.model_id().to_string()),
+        RepoType::Model,
+        args.revision
+            .unwrap_or_else(|| args.which.revision().to_string()),
+    ));
+    let tokenizer_filename = match args.tokenizer_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => api
+            // TODO: Use the appropriate tokenizer here.
+            .model("EleutherAI/gpt-neox-20b".to_string())
+            .get("tokenizer.json")?,
+    };
+    let config_filename = match args.config_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("config.json")?,
+    };
+    let filenames = match args.weight_files {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => {
+            vec![repo.get("model.safetensors")?]
+        }
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    let start = std::time::Instant::now();
+    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+    let device = candle_examples::device(args.cpu)?;
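+    // Memory-map the safetensors weights; this is unsafe because the mapping
+    // assumes the underlying files are not modified while the model is in use.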
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+    let model = Model::new(&config, vb)?;
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut pipeline = TextGeneration::new(
+        model,
+        config,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        args.repeat_penalty,
+        args.repeat_last_n,
+        &device,
+    );
+    pipeline.run(&args.prompt, args.sample_len)?;
+    Ok(())
+}
diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index 7a920cb8..f8126394 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -1,13 +1,12 @@
 use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct LlamaConfig {
     pub hidden_size: usize,
     pub intermediate_size: usize,
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 769fd650..8eab4744 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -34,6 +34,7 @@ pub mod quantized_t5;
 pub mod qwen2;
 pub mod repvgg;
 pub mod resnet;
+pub mod rwkv_v5;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
diff --git a/candle-transformers/src/models/rwkv_v5.rs b/candle-transformers/src/models/rwkv_v5.rs
new file mode 100644
index 00000000..d8ea7b20
--- /dev/null
+++ b/candle-transformers/src/models/rwkv_v5.rs
@@ -0,0 +1,317 @@
+use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
+use candle::{DType, Device, IndexOp, Result, Tensor};
+use candle_nn::{embedding, Embedding, Module, VarBuilder};
+
+fn default_num_attention_heads() -> usize {
+    64
+}
+
+// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
+#[derive(Debug, Clone, serde::Deserialize)]
+pub struct Config {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub attention_hidden_size: usize,
+    #[serde(default = "default_num_attention_heads")]
+    pub num_attention_heads: usize,
+    pub head_size: usize,
+    pub intermediate_size: Option<usize>,
+    pub layer_norm_epsilon: f64,
+    pub rescale_every: usize,
+}
+
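+// Recurrent state carried across forward calls, one entry per layer: the last
+// token's hidden state used for the attention token-shift (`extract_key_value`),
+// the per-head linear-attention state (one (head_size, head_size) matrix per
+// head), and the last token's hidden state for the feed-forward token-shift.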
+struct StatePerLayer {
+    extract_key_value: Tensor,
+    linear_attention: Tensor,
+    feed_forward: Tensor,
+}
+
+pub struct State {
+    per_layer: Vec<StatePerLayer>,
+    pos: usize,
+}
+
+impl State {
+    pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
+        let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
+        // Certainly a weird convention but taken from modeling_rwkv5.py
+        let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
+        for _layer_idx in 0..cfg.num_hidden_layers {
+            let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+            let linear_attention = Tensor::zeros(
+                (
+                    batch_size,
+                    num_attention_heads,
+                    cfg.hidden_size / num_attention_heads,
+                    cfg.hidden_size / num_attention_heads,
+                ),
+                DType::F32,
+                dev,
+            )?;
+            let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
+            per_layer.push(StatePerLayer {
+                extract_key_value,
+                linear_attention,
+                feed_forward,
+            });
+        }
+        Ok(Self { per_layer, pos: 0 })
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SelfAttention {
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    gate: Linear,
+    output: Linear,
+    ln_x: candle_nn::GroupNorm,
+    time_mix_key: Tensor,
+    time_mix_value: Tensor,
+    time_mix_receptance: Tensor,
+    time_decay: Tensor,
+    time_faaaa: Tensor,
+    time_mix_gate: Tensor,
+    layer_id: usize,
+    n_attn_heads: usize,
+}
+
+impl SelfAttention {
+    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_size = cfg.hidden_size;
+        let attn_hidden_size = cfg.attention_hidden_size;
+        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
+        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
+        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
+        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
+        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
+        let ln_x = candle_nn::group_norm(
+            hidden_size / cfg.head_size,
+            hidden_size,
+            1e-5,
+            vb.pp("ln_x"),
+        )?;
+        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
+        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+        let n_attn_heads = cfg.hidden_size / cfg.head_size;
+        let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
+        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
+        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
+        Ok(Self {
+            key,
+            value,
+            receptance,
+            gate,
+            output,
+            ln_x,
+            time_mix_key,
+            time_mix_value,
+            time_mix_receptance,
+            time_decay,
+            time_faaaa,
+            time_mix_gate,
+            layer_id,
+            n_attn_heads,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let h = self.time_decay.dim(0)?;
+        let (b, t, s) = xs.dims3()?;
+        let s = s / h;
+        let (receptance, key, value, gate) = {
+            // extract key-value
+            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
+            let shifted = if shifted.rank() == 2 {
+                shifted.unsqueeze(1)?
+            } else {
+                shifted
+            };
+            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
+            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
+            let receptance = ((xs * &self.time_mix_receptance)?
+                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
+            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
+
+            let key = self.key.forward(&key)?;
+            let value = self.value.forward(&value)?;
+            let receptance = self.receptance.forward(&receptance)?;
+            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
+            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
+            (receptance, key, value, gate)
+        };
+        // linear attention
+        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
+        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
+        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
+        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
+
+        let time_decay = self
+            .time_decay
+            .exp()?
+            .neg()?
+            .exp()?
+            .reshape(((), 1, 1))?
+            .reshape((self.n_attn_heads, (), 1))?;
+        let time_faaaa =
+            self.time_faaaa
+                .reshape(((), 1, 1))?
+                .reshape((self.n_attn_heads, (), 1))?;
+
+        let mut out: Vec<Tensor> = Vec::with_capacity(t);
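+        // RWKV v5 linear-attention recurrence, applied per head and per step:
+        //   a_t = k_t v_t^T
+        //   out_t = r_t (time_faaaa * a_t + state)
+        //   state = a_t + time_decay * state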
+        for t_ in 0..t {
+            //
+            let rt = receptance.i((.., .., t_..t_ + 1))?;
+            let kt = key.i((.., .., .., t_..t_ + 1))?;
+            let vt = value.i((.., .., t_..t_ + 1))?;
+            let at = kt.matmul(&vt)?;
+            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
+            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
+            state_ = (&at + time_decay.broadcast_mul(&state_))?;
+            out.push(out_)
+        }
+        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
+        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
+        let out = (out * gate)?.apply(&self.output)?;
+        state.per_layer[self.layer_id].linear_attention = state_;
+        Ok(out)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct FeedForward {
+    time_mix_key: Tensor,
+    time_mix_receptance: Tensor,
+    key: Linear,
+    receptance: Linear,
+    value: Linear,
+    layer_id: usize,
+}
+
+impl FeedForward {
+    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let int_size = cfg
+            .intermediate_size
+            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
+        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
+        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
+        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
+        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
+        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
+        Ok(Self {
+            key,
+            receptance,
+            value,
+            time_mix_key,
+            time_mix_receptance,
+            layer_id,
+        })
+    }
+
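+    // Channel mixing: the input is interpolated with the previous token (token
+    // shift), the key path uses a squared-ReLU activation, and a sigmoid
+    // receptance gates the output.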
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let shifted = &state.per_layer[self.layer_id].feed_forward;
+        let key = (xs.broadcast_mul(&self.time_mix_key)?
+            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
+        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
+        let key = key.apply(&self.key)?.relu()?.sqr()?;
+        let value = key.apply(&self.value)?;
+        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
+        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
+        let xs = (receptance * value)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Block {
+    pre_ln: Option<LayerNorm>,
+    ln1: LayerNorm,
+    ln2: LayerNorm,
+    attention: SelfAttention,
+    feed_forward: FeedForward,
+}
+
+impl Block {
+    pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
+        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
+        let pre_ln = if layer_id == 0 {
+            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
+            Some(ln)
+        } else {
+            None
+        };
+        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
+        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
+        Ok(Self {
+            pre_ln,
+            ln1,
+            ln2,
+            attention,
+            feed_forward,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let xs = match self.pre_ln.as_ref() {
+            None => xs.clone(),
+            Some(pre_ln) => xs.apply(pre_ln)?,
+        };
+        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
+        let xs = (xs + attention)?;
+        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
+        let xs = (xs + feed_forward)?;
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embeddings: Embedding,
+    blocks: Vec<Block>,
+    ln_out: LayerNorm,
+    head: Linear,
+    rescale_every: usize,
+    layers_are_rescaled: bool,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("rwkv");
+        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
+        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_b = vb_m.pp("blocks");
+        for block_index in 0..cfg.num_hidden_layers {
+            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
+            blocks.push(block)
+        }
+        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
+        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
+        Ok(Self {
+            embeddings,
+            blocks,
+            ln_out,
+            head,
+            rescale_every: cfg.rescale_every,
+            layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
+        let (_b_size, _seq_len) = xs.dims2()?;
+        let mut xs = xs.apply(&self.embeddings)?;
+        for (block_idx, block) in self.blocks.iter().enumerate() {
+            xs = block.forward(&xs, state)?;
+            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
+                xs = (xs / 2.)?
+            }
+        }
+        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
+        state.pos += 1;
+        Ok(xs)
+    }
+}