mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some code to load the model.
This commit is contained in:
156
candle-examples/examples/csm/main.rs
Normal file
156
candle-examples/examples/csm/main.rs
Normal file
@ -0,0 +1,156 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::csm::{Config, Model};
|
||||
|
||||
use candle::DType;
|
||||
use candle_nn::VarBuilder;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "1b")]
|
||||
Csm1b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.7)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "1b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weights: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => {
|
||||
let name = match args.which {
|
||||
Which::Csm1b => "sesame/csm-1b",
|
||||
};
|
||||
name.to_string()
|
||||
}
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let filenames = match args.weights {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => vec![repo.get("model.safetensors")?],
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = match args.config {
|
||||
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||
None => {
|
||||
let config_file = repo.get("config.json")?;
|
||||
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||
}
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let (_model, _device) = {
|
||||
let dtype = DType::F32;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
(model, device)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
Ok(())
|
||||
}
|
@ -9,7 +9,7 @@
|
||||
///
|
||||
use crate::models::encodec;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
|
||||
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -114,13 +114,17 @@ impl RotaryEmbedding {
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
|
||||
let weight = vb.get((hidden_size,), "scale")?;
|
||||
Ok(RmsNorm::new(weight, eps))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
output_proj: Linear,
|
||||
o_proj: Linear,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
num_heads: usize,
|
||||
@ -131,21 +135,24 @@ struct Attention {
|
||||
|
||||
impl Attention {
|
||||
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||
let kv_dim = cfg.num_kv_heads * head_dim;
|
||||
|
||||
let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("v_proj"))?;
|
||||
let output_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("out_proj"))?;
|
||||
let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
output_proj,
|
||||
o_proj,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
num_heads: cfg.num_heads,
|
||||
num_kv_heads: cfg.num_kv_heads,
|
||||
num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
|
||||
head_dim: cfg.embed_dim / cfg.num_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
@ -205,7 +212,7 @@ impl Attention {
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.num_heads * self.head_dim))?
|
||||
.apply(&self.output_proj)
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
|
Reference in New Issue
Block a user