fix: fix the codegeex4 model examples and transformers model (#2738)

* Update main.rs

* Update codegeex4_9b.rs

* Get things to compile.

* Add some default for when rope_ratio is missing.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
Author: 唐璜
Date: 2025-01-26 00:41:12 +08:00
Committed by: GitHub
Parent: 3164a19a5d
Commit: 333d94a19a
3 changed files with 60 additions and 40 deletions
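
The heart of the fix is serde's field-level default: when rope_ratio is absent from an older config.json, deserialization now falls back to the value returned by the function named in #[serde(default = "...")] instead of failing. Below is a minimal, self-contained sketch of that mechanism; the Config here is a toy stand-in for the real one, and it assumes serde (with the derive feature) and serde_json as dependencies:

use serde::Deserialize;

fn default_one() -> usize {
    1
}

#[derive(Debug, Deserialize)]
struct Config {
    num_layers: usize,
    // Falls back to default_one() when the key is missing from the JSON.
    #[serde(default = "default_one")]
    rope_ratio: usize,
}

fn main() -> serde_json::Result<()> {
    // An older config that omits rope_ratio now deserializes instead of erroring.
    let old: Config = serde_json::from_str(r#"{ "num_layers": 40 }"#)?;
    assert_eq!(old.rope_ratio, 1);

    // A config that does carry the field keeps its value.
    let new: Config = serde_json::from_str(r#"{ "num_layers": 40, "rope_ratio": 500 }"#)?;
    assert_eq!(new.rope_ratio, 500);
    Ok(())
}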

candle-examples/examples/codegeex4-9b/main.rs

@@ -1,9 +1,8 @@
-use candle_transformers::models::codegeex4_9b::*;
-use clap::Parser;
 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
+use candle_transformers::models::codegeex4_9b::*;
+use clap::Parser;
 use hf_hub::{Repo, RepoType};
 use tokenizers::Tokenizer;
@@ -14,7 +13,7 @@ struct TextGeneration {
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
-    verbose_prompt: bool,
+    verbose: bool,
     dtype: DType,
 }
@@ -24,22 +23,22 @@ impl TextGeneration {
         model: Model,
         tokenizer: Tokenizer,
         seed: u64,
-        temp: Option<f64>,
-        top_p: Option<f64>,
+        temp: f64,
+        top_p: f64,
         repeat_penalty: f32,
         repeat_last_n: usize,
-        verbose_prompt: bool,
+        verbose: bool,
         device: &Device,
         dtype: DType,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
+        let logits_processor = LogitsProcessor::new(seed, Some(temp), Some(top_p));
         Self {
             model,
             tokenizer,
             logits_processor,
             repeat_penalty,
             repeat_last_n,
-            verbose_prompt,
+            verbose,
             device: device.clone(),
             dtype,
         }
@@ -52,7 +51,7 @@ impl TextGeneration {
         if tokens.is_empty() {
             panic!("Empty prompts are not supported in the chatglm model.")
         }
-        if self.verbose_prompt {
+        if self.verbose {
             for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
                 let token = token.replace('▁', " ").replace("<0x0A>", "\n");
                 println!("{id:7} -> '{token}'");
@@ -101,7 +100,7 @@ impl TextGeneration {
                 .tokenizer
                 .decode(&[next_token], true)
                 .expect("Token error");
-            if self.verbose_prompt {
+            if self.verbose {
                 println!(
                     "[Count: {}] [Raw Token: {}] [Decode Token: {}]",
                     count, next_token, token
@@ -126,34 +125,35 @@ impl TextGeneration {
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(name = "cache", short, long, default_value = ".")]
-    cache_path: String,
+    #[arg(name = "cache", short)]
+    cache_path: Option<String>,
 
+    /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
 
     /// Display the token for the specified prompt.
-    #[arg(long)]
-    verbose_prompt: bool,
-
     #[arg(long)]
     prompt: String,
 
-    /// The temperature used to generate samples.
+    /// Display the tokens for the specified prompt and outputs.
     #[arg(long)]
-    temperature: Option<f64>,
+    verbose: bool,
+
+    /// The temperature used to generate samples.
+    #[arg(long, default_value_t = 0.95)]
+    temperature: f64,
 
     /// Nucleus sampling probability cutoff.
-    #[arg(long)]
-    top_p: Option<f64>,
+    #[arg(long, default_value_t = 0.8)]
+    top_p: f64,
 
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 5000)]
+    #[arg(long, short = 'n', default_value_t = 8192)]
     sample_len: usize,
 
     #[arg(long)]
@@ -163,20 +163,19 @@ struct Args {
     revision: Option<String>,
 
     #[arg(long)]
-    weight_file: Option<String>,
+    weight_path: Option<String>,
 
     #[arg(long)]
     tokenizer: Option<String>,
 
     /// Penalty to be applied for repeating tokens, 1. means no penalty.
-    #[arg(long, default_value_t = 1.1)]
+    #[arg(long, default_value_t = 1.2)]
     repeat_penalty: f32,
 
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
 }
 
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     println!(
@@ -188,17 +187,18 @@ fn main() -> anyhow::Result<()> {
     );
     println!(
         "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
-        args.temperature.unwrap_or(0.95),
-        args.repeat_penalty,
-        args.repeat_last_n
+        args.temperature, args.repeat_penalty, args.repeat_last_n
     );
 
     let start = std::time::Instant::now();
-    println!("cache path {}", args.cache_path);
-    let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
-        .build()
-        .map_err(anyhow::Error::msg)?;
+    let api = match args.cache_path.as_ref() {
+        None => hf_hub::api::sync::Api::new()?,
+        Some(path) => {
+            hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(path.to_string().into()))
+                .build()
+                .map_err(anyhow::Error::msg)?
+        }
+    };
     let model_id = match args.model_id {
         Some(model_id) => model_id.to_string(),
         None => "THUDM/codegeex4-all-9b".to_string(),
@@ -215,15 +215,22 @@ fn main() -> anyhow::Result<()> {
             .get("tokenizer.json")
             .map_err(anyhow::Error::msg)?,
     };
-    let filenames = match args.weight_file {
-        Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
-        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
+    let config_filename = match &args.weight_path {
+        Some(path) => std::path::Path::new(path).join("config.json"),
+        None => repo.get("config.json")?,
+    };
+    let filenames = match &args.weight_path {
+        Some(path) => {
+            candle_examples::hub_load_local_safetensors(path, "model.safetensors.index.json")?
+        }
+        _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
     };
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");
 
     let start = std::time::Instant::now();
-    let config = Config::codegeex4();
+    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
     let device = candle_examples::device(args.cpu)?;
     let dtype = if device.is_cuda() {
         DType::BF16
@@ -243,7 +250,7 @@ fn main() -> anyhow::Result<()> {
         args.top_p,
         args.repeat_penalty,
         args.repeat_last_n,
-        args.verbose_prompt,
+        args.verbose,
         &device,
         dtype,
     );
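
Worth noting in the example above: temperature and top_p are now plain f64 flags carrying clap defaults rather than Option<f64>, so the Some(...) wrapping moves to the LogitsProcessor::new call site. A minimal sketch of that pattern under clap 4's derive API (sample is a hypothetical stand-in for the generation call):

use clap::Parser;

#[derive(Parser, Debug)]
struct Args {
    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.95)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long, default_value_t = 0.8)]
    top_p: f64,
}

// Hypothetical consumer with an Option-taking signature, like LogitsProcessor::new.
fn sample(temp: Option<f64>, top_p: Option<f64>) {
    println!("temp: {temp:?}, top_p: {top_p:?}");
}

fn main() {
    let args = Args::parse();
    // The flags always carry a value now, so wrap them once at the call site.
    sample(Some(args.temperature), Some(args.top_p));
}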

candle-transformers/src/models/codegeex4_9b.rs

@@ -10,7 +10,11 @@ use crate::models::with_tracing::{linear_b as linear, Linear};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
-#[derive(Debug, Clone)]
+fn default_one() -> usize {
+    1
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
 pub struct Config {
     pub num_layers: usize,
     pub padded_vocab_size: usize,
@@ -31,6 +35,8 @@ pub struct Config {
     pub apply_query_key_layer_scaling: bool,
     pub attention_softmax_in_fp32: bool,
     pub fp32_residual_connection: bool,
+    #[serde(default = "default_one")]
+    pub rope_ratio: usize,
 }
 
 impl Config {
@@ -55,6 +61,7 @@ impl Config {
             apply_query_key_layer_scaling: true,
             attention_softmax_in_fp32: true,
             fp32_residual_connection: false,
+            rope_ratio: 500,
         }
     }
 }
@@ -68,9 +75,10 @@ impl RotaryEmbedding {
     fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
         let rotary_dim = cfg.kv_channels;
         let n_elem = rotary_dim / 2;
+        let base = 10_000f64 * cfg.rope_ratio as f64;
         let inv_freq: Vec<_> = (0..n_elem)
             .step_by(2)
-            .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
+            .map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32)
             .collect();
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
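
The new rope_ratio field simply scales the rotary-embedding base: base = 10_000 × rope_ratio, so the codegeex4 preset of 500 lowers every inverse frequency, stretching the rotary wavelengths in the usual base-scaling fashion for longer contexts. A standalone sketch of the computation from the diff, in plain Rust without candle tensors:

// Mirrors the diff: inv_freq[k] = 1 / base^(i / n_elem) for even i = 2k.
fn inv_freq(kv_channels: usize, rope_ratio: usize) -> Vec<f32> {
    let n_elem = kv_channels / 2;
    let base = 10_000f64 * rope_ratio as f64;
    (0..n_elem)
        .step_by(2)
        .map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32)
        .collect()
}

fn main() {
    // rope_ratio = 1 reproduces the old fixed base of 10_000.
    let default = inv_freq(128, 1);
    // rope_ratio = 500 (the codegeex4 preset) lowers every nonzero-index frequency.
    let scaled = inv_freq(128, 500);
    // i = 0 gives 1.0 for any base; compare the next entry instead.
    assert!(scaled[1] < default[1]);
    println!("first entries: {:?} vs {:?}", &default[..2], &scaled[..2]);
}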

candle-transformers/src/models/glm4.rs

@@ -8,6 +8,10 @@ use crate::models::with_tracing::{linear_b as linear, Linear};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;
 
+fn default_one() -> usize {
+    1
+}
+
 #[derive(Debug, Clone, serde::Deserialize, Default)]
 pub struct Config {
     pub num_layers: usize,
@@ -29,6 +33,7 @@ pub struct Config {
     pub apply_query_key_layer_scaling: bool,
     pub attention_softmax_in_fp32: bool,
     pub fp32_residual_connection: bool,
+    #[serde(default = "default_one")]
     pub rope_ratio: usize,
 }