diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index e3d2550e..23939a1f 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -18,7 +18,7 @@ use clap::Parser;
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::api::sync::Api;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;

 mod model;
@@ -59,9 +59,9 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,

-    /// Use f32 computations rather than f16.
+    /// Use a dtype other than the default f16 (f32 or bf16).
     #[arg(long)]
-    use_f32: bool,
+    dtype: Option<String>,

     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
@@ -70,6 +70,9 @@ struct Args {
     #[arg(long)]
     model_id: Option<String>,

+    #[arg(long)]
+    revision: Option<String>,
+
     #[arg(long)]
     v1: bool,

@@ -97,7 +100,13 @@ fn main() -> Result<()> {
     };
     let device = candle_examples::device(args.cpu)?;
-    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let dtype = match args.dtype.as_deref() {
+        Some("f16") => DType::F16,
+        Some("bf16") => DType::BF16,
+        Some("f32") => DType::F32,
+        Some(dtype) => panic!("Unsupported dtype {dtype}"),
+        None => DType::F16,
+    };

     let (llama, tokenizer_filename, cache) = match args.npy {
         Some(filename) => {
             let config = if args.v1 {
@@ -120,7 +129,8 @@ fn main() -> Result<()> {
                 }
             });
             println!("loading the model weights from {model_id}");
-            let api = api.model(model_id);
+            let revision = args.revision.unwrap_or("main".to_string());
+            let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));

             let tokenizer_filename = match &args.local_weights {
                 Some(path) => (path.to_owned() + "tokenizer.json").into(),
diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs
index 561c2939..275856e0 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-examples/examples/llama/model.rs
@@ -15,6 +15,12 @@ pub struct LlamaConfig {
     pub num_attention_heads: usize,
     pub num_key_value_heads: Option<usize>,
     pub rms_norm_eps: f64,
+    #[serde(default = "default_rope")]
+    pub rope_theta: f32,
+}
+
+fn default_rope() -> f32 {
+    10_000.0
 }

 impl LlamaConfig {
@@ -27,6 +33,7 @@ impl LlamaConfig {
             num_attention_heads: self.num_attention_heads,
             num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
             rms_norm_eps: self.rms_norm_eps,
+            rope_theta: self.rope_theta,
             use_flash_attn,
         }
     }
@@ -41,6 +48,7 @@ pub struct Config {
     pub num_key_value_heads: usize,
     pub use_flash_attn: bool,
     pub rms_norm_eps: f64,
+    pub rope_theta: f32,
 }

 impl Config {
@@ -54,6 +62,7 @@ impl Config {
             num_key_value_heads: 32,
             use_flash_attn,
             rms_norm_eps: 1e-6,
+            rope_theta: 10_000.0,
         }
     }

@@ -67,6 +76,7 @@ impl Config {
             num_key_value_heads: 32,
             use_flash_attn,
             rms_norm_eps: 1e-5,
+            rope_theta: 10_000.0,
         }
     }
 }
@@ -103,7 +113,7 @@ impl Cache {
         let n_elem = config.hidden_size / config.num_attention_heads;
         let theta: Vec<_> = (0..n_elem)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
+            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
             .collect();
         let theta = Tensor::new(theta.as_slice(), device)?;
         let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
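(Note: the model.rs hunk above replaces the hardcoded 10_000 RoPE base with the configurable `config.rope_theta`. The snippet below is a standalone illustration of that inverse-frequency computation, not part of the patch; the helper name `rope_inv_freqs` and the example values are made up here.)

```rust
// Hypothetical standalone helper mirroring the Cache logic in the diff:
// for even i in 0..head_dim, the inverse frequency is 1 / rope_theta^(i / head_dim).
fn rope_inv_freqs(hidden_size: usize, num_attention_heads: usize, rope_theta: f32) -> Vec<f32> {
    let n_elem = hidden_size / num_attention_heads; // per-head dimension
    (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / n_elem as f32))
        .collect()
}

fn main() {
    // The default base of 10_000 gives the usual Llama frequencies; a larger
    // rope_theta read from config.json decays more slowly along the head dimension.
    let default_freqs = rope_inv_freqs(4096, 32, 10_000.0);
    let large_base_freqs = rope_inv_freqs(4096, 32, 1_000_000.0);
    println!("{:?}", &default_freqs[..4]);
    println!("{:?}", &large_base_freqs[..4]);
}
```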
diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs
index d6d0d14e..063070b3 100644
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -17,7 +17,7 @@ use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
-use hf_hub::api::sync::Api;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;
 use std::rc::Rc;

@@ -108,6 +108,12 @@ struct Args {
     #[arg(long)]
     model_id: Option<String>,
+
+    #[arg(long)]
+    revision: Option<String>,
+
+    #[arg(long)]
+    dtype: Option<String>,
 }

 fn main() -> Result<()> {
@@ -115,8 +121,13 @@ fn main() -> Result<()> {

     let args = Args::parse();

-    let config = Config::config_7b();
-    let dtype = DType::F16;
+    let dtype = match args.dtype.as_deref() {
+        Some("f16") => DType::F16,
+        Some("bf16") => DType::BF16,
+        Some("f32") => DType::F32,
+        Some(dtype) => panic!("Unsupported dtype {dtype}"),
+        None => DType::F16,
+    };

     let api = Api::new()?;

@@ -124,7 +135,10 @@ fn main() -> Result<()> {
         .model_id
         .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
     println!("loading the model weights from {model_id}");
-    let api = api.model(model_id);
+    let revision = args.revision.unwrap_or("main".to_string());
+    let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
+    let config_filename = api.get("config.json")?;
+    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let tokenizer_filename = api.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
@@ -185,7 +199,7 @@ fn main() -> Result<()> {
         println!("Rank {rank:?} spawned");

         let device = Device::new_cuda(i)?;
-        let cache = model::Cache::new(&config, &device)?;
+        let cache = model::Cache::new(dtype, &config, &device)?;

         println!("building the model");
         let handles = filenames
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index 1e7cafa2..b146b42d 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -3,6 +3,7 @@ use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shap
 use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
+use serde::Deserialize;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};

@@ -110,26 +111,34 @@ impl TensorParallelRowLinear {
     }
 }

+#[derive(Deserialize)]
 pub struct Config {
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub vocab_size: usize,
-    pub n_layer: usize,
-    pub n_head: usize,
-    pub n_embd: usize,
-    pub n_key_value_head: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: usize,
+    pub rms_norm_eps: f64,
+    #[serde(default = "default_rope")]
+    pub rope_theta: f32,
+}
+
+fn default_rope() -> f32 {
+    10_000.0
 }

 impl Config {
     pub fn config_7b() -> Self {
         Self {
-            hidden_size: 4096,
             intermediate_size: 11008,
             vocab_size: 32000,
-            n_layer: 32,
-            n_head: 32,
-            n_embd: 4096,
-            n_key_value_head: 32,
+            num_hidden_layers: 32,
+            num_attention_heads: 32,
+            hidden_size: 4096,
+            num_key_value_heads: 32,
+            rms_norm_eps: 1e-5,
+            rope_theta: 10_000.0,
         }
     }
 }
@@ -143,12 +152,12 @@ pub struct Cache {
 }

 impl Cache {
-    pub fn new(config: &Config, device: &Device) -> Result<Self> {
+    pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
-        let n_elem = config.n_embd / config.n_head;
+        let n_elem = config.hidden_size / config.num_attention_heads;
         let theta: Vec<_> = (0..n_elem)
             .step_by(2)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
+            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
             .collect();
         let theta = Tensor::new(theta.as_slice(), device)?;
         let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
@@ -158,10 +167,10 @@ impl Cache {
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
         let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
-        let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
-        let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
+        let cos = idx_theta.cos()?.to_dtype(dtype)?;
+        let sin = idx_theta.sin()?.to_dtype(dtype)?;
         Ok(Self {
-            kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
+            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
             cos,
             sin,
         })
@@ -185,21 +194,21 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 struct CausalSelfAttention {
     qkv_proj: TensorParallelColumnLinear,
     o_proj: TensorParallelRowLinear,
-    n_head: usize,
-    n_key_value_head: usize,
+    num_attention_heads: usize,
+    num_key_value_heads: usize,
     head_dim: usize,
     cache: Cache,
 }

 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
-        let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
+        let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
-        let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
-        let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
-        let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
+        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
+        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
+        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
+        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
         let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
         let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
         Ok(rope)
@@ -209,30 +218,31 @@ impl CausalSelfAttention {
         let (b_sz, seq_len, _) = x.shape().dims3()?;

         let qkv = self.qkv_proj.forward(x)?;
-        let n_embd = self.n_head * self.head_dim;
+        let hidden_size = self.num_attention_heads * self.head_dim;

-        let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
+        let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?;
         let k = qkv.i((
             ..,
             ..,
-            self.n_head * self.head_dim
-                ..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
+            self.num_attention_heads * self.head_dim
+                ..self.num_attention_heads * self.head_dim
+                    + self.num_key_value_heads * self.head_dim,
         ))?;
         let v = qkv.i((
             ..,
             ..,
-            self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
+            self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim..,
         ))?;
         // todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());

         let q = q
-            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
             .transpose(1, 2)?;
         let k = k
-            .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
             .transpose(1, 2)?;
         let mut v = v
-            .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
+            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
             .transpose(1, 2)?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
@@ -266,13 +276,13 @@ impl CausalSelfAttention {
             let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
                 .transpose(1, 2)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }

     fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
-        let n_rep = self.n_head / self.n_key_value_head;
+        let n_rep = self.num_attention_heads / self.num_key_value_heads;
         if n_rep == 1 {
             Ok(x)
         } else {
@@ -295,9 +305,9 @@ impl CausalSelfAttention {
         Ok(Self {
             qkv_proj,
             o_proj,
-            n_head: cfg.n_head / comm.world_size(),
-            n_key_value_head: cfg.n_key_value_head / comm.world_size(),
-            head_dim: cfg.hidden_size / cfg.n_head,
+            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
+            num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
+            head_dim: cfg.hidden_size / cfg.num_attention_heads,
             cache: cache.clone(),
         })
     }
@@ -409,7 +419,7 @@ impl Llama {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
-        let blocks: Vec<_> = (0..cfg.n_layer)
+        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
             .map(|i| {
                 Block::load(
                     vb.pp(&format!("model.layers.{i}")),
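(Note: with the multiprocess example now reading `Config` from the hub's config.json via serde_json, the `#[serde(default = "default_rope")]` attribute is what keeps older configs that lack a `rope_theta` field loadable. Below is a minimal standalone sketch of that behavior; the trimmed-down `Config` struct and the JSON literals are illustrative only, assuming the serde and serde_json crates already used in the diff.)

```rust
use serde::Deserialize;

// Trimmed-down, hypothetical stand-in for the Config struct in the diff; only the
// fields needed to show the serde default are kept.
#[derive(Deserialize, Debug)]
struct Config {
    hidden_size: usize,
    num_attention_heads: usize,
    #[serde(default = "default_rope")]
    rope_theta: f32,
}

fn default_rope() -> f32 {
    10_000.0
}

fn main() -> Result<(), serde_json::Error> {
    // Older config.json files omit rope_theta entirely; the default fills it in.
    let legacy: Config =
        serde_json::from_str(r#"{"hidden_size": 4096, "num_attention_heads": 32}"#)?;
    assert_eq!(legacy.rope_theta, 10_000.0);

    // Newer configs can override the base explicitly.
    let custom: Config = serde_json::from_str(
        r#"{"hidden_size": 4096, "num_attention_heads": 32, "rope_theta": 1000000.0}"#,
    )?;
    assert_eq!(custom.rope_theta, 1_000_000.0);
    Ok(())
}
```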