diff --git a/candle-examples/examples/llama/model.rs b/candle-examples/examples/llama/model.rs index 940c980c..751b5902 100644 --- a/candle-examples/examples/llama/model.rs +++ b/candle-examples/examples/llama/model.rs @@ -13,7 +13,7 @@ pub struct LlamaConfig { pub vocab_size: usize, pub num_hidden_layers: usize, pub num_attention_heads: usize, - pub num_key_value_heads: usize, + pub num_key_value_heads: Option, pub rms_norm_eps: f64, } @@ -25,7 +25,7 @@ impl LlamaConfig { vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, num_attention_heads: self.num_attention_heads, - num_key_value_heads: self.num_key_value_heads, + num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads), rms_norm_eps: self.rms_norm_eps, use_flash_attn, }