From e4c3a71f11c264f464c5c418a3bc810672f28119 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 21 Jan 2025 05:51:46 +0800 Subject: [PATCH] Fix GLM4 alignment issue (#2723) * Fix GLM4 alignment issue * Cleanups. --------- Co-authored-by: Laurent --- candle-book/Cargo.toml | 2 +- candle-examples/Cargo.toml | 2 +- candle-examples/examples/glm4/main.rs | 39 +++++++++++++++----------- candle-examples/src/lib.rs | 26 ++++++++++++++++- candle-transformers/src/models/glm4.rs | 7 +++-- 5 files changed, 54 insertions(+), 22 deletions(-) diff --git a/candle-book/Cargo.toml b/candle-book/Cargo.toml index dee55f20..f71645b4 100644 --- a/candle-book/Cargo.toml +++ b/candle-book/Cargo.toml @@ -25,7 +25,7 @@ cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } image = { workspace = true, optional = true } anyhow = { workspace = true } -tokio = "1.29.1" +tokio = "1.43.0" [dev-dependencies] byteorder = { workspace = true } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index df85302d..e679d01b 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -50,7 +50,7 @@ tracing = { workspace = true } tracing-chrome = { workspace = true } tracing-subscriber = { workspace = true } # Necessary to disambiguate with tokio in wasm examples which are 1.28.1 -tokio = "1.29.1" +tokio = "1.43.0" [build-dependencies] anyhow = { workspace = true } diff --git a/candle-examples/examples/glm4/main.rs b/candle-examples/examples/glm4/main.rs index 3fa948cb..c4a300cf 100644 --- a/candle-examples/examples/glm4/main.rs +++ b/candle-examples/examples/glm4/main.rs @@ -1,12 +1,10 @@ -use candle_transformers::models::glm4::*; -use clap::Parser; - use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; +use candle_transformers::models::glm4::*; +use clap::Parser; use hf_hub::{Repo, RepoType}; use tokenizers::Tokenizer; - struct TextGeneration { model: Model, device: Device, @@ -19,7 +17,8 @@ struct TextGeneration { impl TextGeneration { #[allow(clippy::too_many_arguments)] fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self { - let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); + let logits_processor = + LogitsProcessor::new(args.seed, Some(args.temperature), Some(args.top_p)); Self { model, tokenizer, @@ -125,12 +124,12 @@ struct Args { verbose: bool, /// The temperature used to generate samples. - #[arg(long)] - temperature: Option, + #[arg(long, default_value_t = 0.8)] + temperature: f64, /// Nucleus sampling probability cutoff. - #[arg(long)] - top_p: Option, + #[arg(long, default_value_t = 0.8)] + top_p: f64, /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] @@ -147,7 +146,7 @@ struct Args { revision: Option, #[arg(long)] - weight_file: Option, + weight_path: Option, #[arg(long)] tokenizer: Option, @@ -172,9 +171,7 @@ fn main() -> anyhow::Result<()> { ); println!( "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", - args.temperature.unwrap_or(0.6), - args.repeat_penalty, - args.repeat_last_n + args.temperature, args.repeat_penalty, args.repeat_last_n ); let start = std::time::Instant::now(); @@ -203,15 +200,23 @@ fn main() -> anyhow::Result<()> { .get("tokenizer.json") .map_err(anyhow::Error::msg)?, }; - let filenames = match args.weight_file.as_ref() { - Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], - None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + let config_filename = match &args.weight_path { + Some(path) => std::path::Path::new(path).join("config.json"), + _ => repo.get("config.json")?, }; + + let filenames = match &args.weight_path { + Some(path) => { + candle_examples::hub_load_local_safetensors(path, "model.safetensors.index.json")? + } + _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + println!("retrieved the files in {:?}", start.elapsed()); let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error"); let start = std::time::Instant::now(); - let config = Config::glm4(); + let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; let device = candle_examples::device(args.cpu)?; let dtype = if device.is_cuda() { DType::BF16 diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 5364bcb2..af49ab59 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -4,7 +4,6 @@ pub mod coco_classes; pub mod imagenet; pub mod token_output_stream; pub mod wav; - use candle::utils::{cuda_is_available, metal_is_available}; use candle::{Device, Result, Tensor}; @@ -147,3 +146,28 @@ pub fn hub_load_safetensors( .collect::>>()?; Ok(safetensors_files) } + +pub fn hub_load_local_safetensors>( + path: P, + json_file: &str, +) -> Result> { + let path = path.as_ref(); + let jsfile = std::fs::File::open(path.join(json_file))?; + let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?; + let weight_map = match json.get("weight_map") { + None => candle::bail!("no weight map in {json_file:?}"), + Some(serde_json::Value::Object(map)) => map, + Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), + }; + let mut safetensors_files = std::collections::HashSet::new(); + for value in weight_map.values() { + if let Some(file) = value.as_str() { + safetensors_files.insert(file); + } + } + let safetensors_files: Vec<_> = safetensors_files + .into_iter() + .map(|v| path.join(v)) + .collect(); + Ok(safetensors_files) +} diff --git a/candle-transformers/src/models/glm4.rs b/candle-transformers/src/models/glm4.rs index de6581d0..433872ee 100644 --- a/candle-transformers/src/models/glm4.rs +++ b/candle-transformers/src/models/glm4.rs @@ -8,7 +8,7 @@ use crate::models::with_tracing::{linear_b as linear, Linear}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::VarBuilder; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Deserialize, Default)] pub struct Config { pub num_layers: usize, pub padded_vocab_size: usize, @@ -29,6 +29,7 @@ pub struct Config { pub apply_query_key_layer_scaling: bool, pub attention_softmax_in_fp32: bool, pub fp32_residual_connection: bool, + pub rope_ratio: usize, } impl Config { @@ -53,6 +54,7 @@ impl Config { apply_query_key_layer_scaling: true, attention_softmax_in_fp32: true, fp32_residual_connection: false, + rope_ratio: 500, } } } @@ -66,9 +68,10 @@ impl RotaryEmbedding { fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result { let rotary_dim = cfg.kv_channels; let n_elem = rotary_dim / 2; + let base = 10_000f64 * cfg.rope_ratio as f64; let inv_freq: Vec<_> = (0..n_elem) .step_by(2) - .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32) + .map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32) .collect(); let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;