mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support both llama v1 and llama v2. (#272)
This commit is contained in:
@ -127,7 +127,11 @@ fn main() -> Result<()> {
|
|||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let config = Config::config_7b(args.use_flash_attn);
|
let config = if args.v1 {
|
||||||
|
Config::config_7b_v1(args.use_flash_attn)
|
||||||
|
} else {
|
||||||
|
Config::config_7b_v2(args.use_flash_attn)
|
||||||
|
};
|
||||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||||
let (llama, tokenizer_filename) = match args.npy {
|
let (llama, tokenizer_filename) = match args.npy {
|
||||||
|
@ -18,7 +18,21 @@ pub struct Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn config_7b(use_flash_attn: bool) -> Self {
|
pub fn config_7b_v1(use_flash_attn: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
hidden_size: 4096,
|
||||||
|
intermediate_size: 11008,
|
||||||
|
vocab_size: 32000,
|
||||||
|
n_layer: 32,
|
||||||
|
n_head: 32,
|
||||||
|
n_embd: 4096,
|
||||||
|
n_key_value_head: 32,
|
||||||
|
use_flash_attn,
|
||||||
|
rms_norm_eps: 1e-6,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config_7b_v2(use_flash_attn: bool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
hidden_size: 4096,
|
hidden_size: 4096,
|
||||||
intermediate_size: 11008,
|
intermediate_size: 11008,
|
||||||
|
Reference in New Issue
Block a user