mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
fix: fix the codegeex4 model examples and transformers model (#2738)
* Update main.rs * Update codegeex4_9b.rs * Get things to compile. * Add some default for when rope_ratio is missing. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -10,7 +10,11 @@ use crate::models::with_tracing::{linear_b as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
fn default_one() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, Default)]
|
||||
pub struct Config {
|
||||
pub num_layers: usize,
|
||||
pub padded_vocab_size: usize,
|
||||
@ -31,6 +35,8 @@ pub struct Config {
|
||||
pub apply_query_key_layer_scaling: bool,
|
||||
pub attention_softmax_in_fp32: bool,
|
||||
pub fp32_residual_connection: bool,
|
||||
#[serde(default = "default_one")]
|
||||
pub rope_ratio: usize,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -55,6 +61,7 @@ impl Config {
|
||||
apply_query_key_layer_scaling: true,
|
||||
attention_softmax_in_fp32: true,
|
||||
fp32_residual_connection: false,
|
||||
rope_ratio: 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -68,9 +75,10 @@ impl RotaryEmbedding {
|
||||
fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let rotary_dim = cfg.kv_channels;
|
||||
let n_elem = rotary_dim / 2;
|
||||
let base = 10_000f64 * cfg.rope_ratio as f64;
|
||||
let inv_freq: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
|
||||
.map(|i| 1f32 / base.powf(i as f64 / n_elem as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
|
@ -8,6 +8,10 @@ use crate::models::with_tracing::{linear_b as linear, Linear};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
fn default_one() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize, Default)]
|
||||
pub struct Config {
|
||||
pub num_layers: usize,
|
||||
@ -29,6 +33,7 @@ pub struct Config {
|
||||
pub apply_query_key_layer_scaling: bool,
|
||||
pub attention_softmax_in_fp32: bool,
|
||||
pub fp32_residual_connection: bool,
|
||||
#[serde(default = "default_one")]
|
||||
pub rope_ratio: usize,
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user