diff --git a/candle-examples/examples/csm/main.rs b/candle-examples/examples/csm/main.rs
new file mode 100644
index 00000000..a4bafb55
--- /dev/null
+++ b/candle-examples/examples/csm/main.rs
@@ -0,0 +1,156 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use clap::Parser;
+
+use candle_transformers::models::csm::{Config, Model};
+
+use candle::DType;
+use candle_nn::VarBuilder;
+use hf_hub::{api::sync::Api, Repo, RepoType};
+
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "1b")]
+    Csm1b,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    #[arg(long)]
+    prompt: String,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.7)]
+    temperature: f64,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+
+    /// The length of the sample to generate (in tokens).
+    #[arg(long, short = 'n', default_value_t = 10000)]
+    sample_len: usize,
+
+    /// The model size to use.
+    #[arg(long, default_value = "1b")]
+    which: Which,
+
+    #[arg(long)]
+    model_id: Option<String>,
+
+    #[arg(long, default_value = "main")]
+    revision: String,
+
+    #[arg(long)]
+    tokenizer: Option<String>,
+
+    #[arg(long)]
+    config: Option<String>,
+
+    #[arg(long)]
+    weights: Option<String>,
+
+    /// Penalty to be applied for repeating tokens, 1. means no penalty.
+    #[arg(long, default_value_t = 1.1)]
+    repeat_penalty: f32,
+
+    /// The context size to consider for the repeat penalty.
+    #[arg(long, default_value_t = 64)]
+    repeat_last_n: usize,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+    println!(
+        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
+        args.temperature, args.repeat_penalty, args.repeat_last_n
+    );
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let model_id = match args.model_id {
+        Some(model_id) => model_id,
+        None => {
+            let name = match args.which {
+                Which::Csm1b => "sesame/csm-1b",
+            };
+            name.to_string()
+        }
+    };
+    let repo = api.repo(Repo::with_revision(
+        model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let filenames = match args.weights {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => vec![repo.get("model.safetensors")?],
+    };
+    println!("retrieved the files in {:?}", start.elapsed());
+
+    let start = std::time::Instant::now();
+    let config: Config = match args.config {
+        Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
+        None => {
+            let config_file = repo.get("config.json")?;
+            serde_json::from_slice(&std::fs::read(config_file)?)?
+        }
+    };
+    let device = candle_examples::device(args.cpu)?;
+    let (_model, _device) = {
+        let dtype = DType::F32;
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
+        let model = Model::new(&config, vb)?;
+        (model, device)
+    };
+
+    println!("loaded the model in {:?}", start.elapsed());
+
+    Ok(())
+}
diff --git a/candle-transformers/src/models/csm.rs b/candle-transformers/src/models/csm.rs
index 8804fc2b..e40fc637 100644
--- a/candle-transformers/src/models/csm.rs
+++ b/candle-transformers/src/models/csm.rs
@@ -9,7 +9,7 @@
 ///
 use crate::models::encodec;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
+use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
 use std::sync::Arc;
 
 #[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
@@ -114,13 +114,17 @@ impl RotaryEmbedding {
         Ok((q_embed, k_embed))
     }
 }
+fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
+    let weight = vb.get((hidden_size,), "scale")?;
+    Ok(RmsNorm::new(weight, eps))
+}
 
 #[derive(Debug, Clone)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
     v_proj: Linear,
-    output_proj: Linear,
+    o_proj: Linear,
     rotary_emb: Arc<RotaryEmbedding>,
     kv_cache: Option<(Tensor, Tensor)>,
     num_heads: usize,
@@ -131,21 +135,24 @@ struct Attention {
 
 impl Attention {
     fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
+        let head_dim = cfg.embed_dim / cfg.num_heads;
+        let kv_dim = cfg.num_kv_heads * head_dim;
+
         let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
-        let k_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("k_proj"))?;
-        let v_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("v_proj"))?;
-        let output_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("out_proj"))?;
+        let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
+        let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
+        let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
         Ok(Self {
             q_proj,
             k_proj,
             v_proj,
-            output_proj,
+            o_proj,
             rotary_emb,
             kv_cache: None,
             num_heads: cfg.num_heads,
             num_kv_heads: cfg.num_kv_heads,
             num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
-            head_dim: cfg.embed_dim / cfg.num_heads,
+            head_dim,
         })
     }

@@ -205,7 +212,7 @@ impl Attention {
         attn_output
             .transpose(1, 2)?
             .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
-            .apply(&self.output_proj)
+            .apply(&self.o_proj)
     }
 
     fn clear_kv_cache(&mut self) {
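
A note on the attention change above: the old code sized `k_proj`/`v_proj` as `embed_dim x embed_dim`, which is only correct when `num_kv_heads == num_heads`. The fix projects keys and values to `kv_dim = num_kv_heads * head_dim`, the grouped-query attention layout where each key/value head serves a group of `num_kv_groups` query heads. A minimal sketch of the dimension arithmetic, using hypothetical sizes (2048/32/8 are illustrative, not the actual csm-1b config values):

```rust
// Sketch of the GQA dimension arithmetic from Attention::new.
// The concrete sizes below are hypothetical, not taken from sesame/csm-1b.
fn main() {
    let embed_dim = 2048; // model width (query path stays embed_dim -> embed_dim)
    let num_heads = 32; // query heads
    let num_kv_heads = 8; // key/value heads

    let head_dim = embed_dim / num_heads; // 64: per-head width, shared by Q and K/V
    let kv_dim = num_kv_heads * head_dim; // 512: output width of k_proj and v_proj

    // num_heads must be a multiple of num_kv_heads so each K/V head
    // serves an equal-sized group of query heads.
    assert_eq!(num_heads % num_kv_heads, 0);
    let num_kv_groups = num_heads / num_kv_heads; // 4 query heads per K/V head

    println!("head_dim={head_dim} kv_dim={kv_dim} num_kv_groups={num_kv_groups}");
}
```

With `num_kv_heads == num_heads` this degenerates to `kv_dim == embed_dim`, i.e. plain multi-head attention, which is why the original square projections happened to work for that case only.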