Fix GLM4 alignment issue (#2723)

* Fix GLM4 alignment issue

* Cleanups.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Guoqing Bao
2025-01-21 05:51:46 +08:00
committed by GitHub
parent 17cbbe4286
commit e4c3a71f11
5 changed files with 54 additions and 22 deletions

View File

@ -25,7 +25,7 @@ cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
image = { workspace = true, optional = true } image = { workspace = true, optional = true }
anyhow = { workspace = true } anyhow = { workspace = true }
tokio = "1.29.1" tokio = "1.43.0"
[dev-dependencies] [dev-dependencies]
byteorder = { workspace = true } byteorder = { workspace = true }

View File

@ -50,7 +50,7 @@ tracing = { workspace = true }
tracing-chrome = { workspace = true } tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1" tokio = "1.43.0"
[build-dependencies] [build-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }

View File

@ -1,12 +1,10 @@
use candle_transformers::models::glm4::*;
use clap::Parser;
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::glm4::*;
use clap::Parser;
use hf_hub::{Repo, RepoType}; use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
struct TextGeneration { struct TextGeneration {
model: Model, model: Model,
device: Device, device: Device,
@ -19,7 +17,8 @@ struct TextGeneration {
impl TextGeneration { impl TextGeneration {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self { fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p); let logits_processor =
LogitsProcessor::new(args.seed, Some(args.temperature), Some(args.top_p));
Self { Self {
model, model,
tokenizer, tokenizer,
@ -125,12 +124,12 @@ struct Args {
verbose: bool, verbose: bool,
/// The temperature used to generate samples. /// The temperature used to generate samples.
#[arg(long)] #[arg(long, default_value_t = 0.8)]
temperature: Option<f64>, temperature: f64,
/// Nucleus sampling probability cutoff. /// Nucleus sampling probability cutoff.
#[arg(long)] #[arg(long, default_value_t = 0.8)]
top_p: Option<f64>, top_p: f64,
/// The seed to use when generating random samples. /// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)] #[arg(long, default_value_t = 299792458)]
@ -147,7 +146,7 @@ struct Args {
revision: Option<String>, revision: Option<String>,
#[arg(long)] #[arg(long)]
weight_file: Option<String>, weight_path: Option<String>,
#[arg(long)] #[arg(long)]
tokenizer: Option<String>, tokenizer: Option<String>,
@ -172,9 +171,7 @@ fn main() -> anyhow::Result<()> {
); );
println!( println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.6), args.temperature, args.repeat_penalty, args.repeat_last_n
args.repeat_penalty,
args.repeat_last_n
); );
let start = std::time::Instant::now(); let start = std::time::Instant::now();
@ -203,15 +200,23 @@ fn main() -> anyhow::Result<()> {
.get("tokenizer.json") .get("tokenizer.json")
.map_err(anyhow::Error::msg)?, .map_err(anyhow::Error::msg)?,
}; };
let filenames = match args.weight_file.as_ref() { let config_filename = match &args.weight_path {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], Some(path) => std::path::Path::new(path).join("config.json"),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, _ => repo.get("config.json")?,
}; };
let filenames = match &args.weight_path {
Some(path) => {
candle_examples::hub_load_local_safetensors(path, "model.safetensors.index.json")?
}
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error"); let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = Config::glm4(); let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() { let dtype = if device.is_cuda() {
DType::BF16 DType::BF16

View File

@ -4,7 +4,6 @@ pub mod coco_classes;
pub mod imagenet; pub mod imagenet;
pub mod token_output_stream; pub mod token_output_stream;
pub mod wav; pub mod wav;
use candle::utils::{cuda_is_available, metal_is_available}; use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor}; use candle::{Device, Result, Tensor};
@ -147,3 +146,28 @@ pub fn hub_load_safetensors(
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
Ok(safetensors_files) Ok(safetensors_files)
} }
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}

View File

@ -8,7 +8,7 @@ use crate::models::with_tracing::{linear_b as linear, Linear};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
#[derive(Debug, Clone)] #[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Config { pub struct Config {
pub num_layers: usize, pub num_layers: usize,
pub padded_vocab_size: usize, pub padded_vocab_size: usize,
@ -29,6 +29,7 @@ pub struct Config {
pub apply_query_key_layer_scaling: bool, pub apply_query_key_layer_scaling: bool,
pub attention_softmax_in_fp32: bool, pub attention_softmax_in_fp32: bool,
pub fp32_residual_connection: bool, pub fp32_residual_connection: bool,
pub rope_ratio: usize,
} }
impl Config { impl Config {
@ -53,6 +54,7 @@ impl Config {
apply_query_key_layer_scaling: true, apply_query_key_layer_scaling: true,
attention_softmax_in_fp32: true, attention_softmax_in_fp32: true,
fp32_residual_connection: false, fp32_residual_connection: false,
rope_ratio: 500,
} }
} }
} }
@ -66,9 +68,10 @@ impl RotaryEmbedding {
fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> { fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
let rotary_dim = cfg.kv_channels; let rotary_dim = cfg.kv_channels;
let n_elem = rotary_dim / 2; let n_elem = rotary_dim / 2;
let base = 10_000f64 * cfg.rope_ratio as f64;
let inv_freq: Vec<_> = (0..n_elem) let inv_freq: Vec<_> = (0..n_elem)
.step_by(2) .step_by(2)
.map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32) .map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32)
.collect(); .collect();
let inv_freq_len = inv_freq.len(); let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;