This commit is contained in:
Nicolas Patry
2023-08-16 10:41:00 +02:00
parent 76804730c6
commit 33c882ea74

View File

@ -1,8 +1,8 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use serde::Deserialize;
use super::MAX_SEQ_LEN;
@ -17,9 +17,9 @@ pub struct LlamaConfig {
pub rms_norm_eps: f64,
}
impl LlamaConfig{
pub fn into_config(&self, use_flash_attn: bool) -> Config{
Config{
impl LlamaConfig {
pub fn into_config(self, use_flash_attn: bool) -> Config {
Config {
hidden_size: self.hidden_size,
intermediate_size: self.intermediate_size,
vocab_size: self.vocab_size,
@ -27,7 +27,7 @@ impl LlamaConfig{
num_attention_heads: self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads,
rms_norm_eps: self.rms_norm_eps,
use_flash_attn
use_flash_attn,
}
}
}