mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Adding support for codellama in examples.
Codellama requires bf16 for now (error to convert from bf16 to f16). Multiprocess demo not functional for it because flash-attn only supports f16 for now.
This commit is contained in:
@ -15,6 +15,12 @@ pub struct LlamaConfig {
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
@ -27,6 +33,7 @@ impl LlamaConfig {
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
}
|
||||
}
|
||||
@ -41,6 +48,7 @@ pub struct Config {
|
||||
pub num_key_value_heads: usize,
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -54,6 +62,7 @@ impl Config {
|
||||
num_key_value_heads: 32,
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-6,
|
||||
rope_theta: 10_000.0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,6 +76,7 @@ impl Config {
|
||||
num_key_value_heads: 32,
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10_000.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -103,7 +113,7 @@ impl Cache {
|
||||
let n_elem = config.hidden_size / config.num_attention_heads;
|
||||
let theta: Vec<_> = (0..n_elem)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
|
Reference in New Issue
Block a user