diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 720a4441..1dd507ff 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -123,6 +123,8 @@ enum WhichModel { V1, #[value(name = "1.5")] V1_5, + #[value(name = "2")] + V2, PuffinPhiV2, PhiHermes, } @@ -225,6 +227,7 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "microsoft/phi-1".to_string(), WhichModel::V1_5 => "microsoft/phi-1_5".to_string(), + WhichModel::V2 => "microsoft/phi-2".to_string(), WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { "lmz/candle-quantized-phi".to_string() } @@ -241,7 +244,9 @@ fn main() -> Result<()> { match args.model { WhichModel::V1 => "refs/pr/2".to_string(), WhichModel::V1_5 => "refs/pr/18".to_string(), - WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(), + WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { + "main".to_string() + } } } } @@ -250,27 +255,32 @@ fn main() -> Result<()> { let tokenizer_filename = match args.tokenizer { Some(file) => std::path::PathBuf::from(file), None => match args.model { - WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?, + WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?, WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => { repo.get("tokenizer-puffin-phi-v2.json")? 
} }, }; - let filename = match args.weight_file { - Some(weight_file) => std::path::PathBuf::from(weight_file), + let filenames = match args.weight_file { + Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], None => { if args.quantized { match args.model { - WhichModel::V1 => repo.get("model-v1-q4k.gguf")?, - WhichModel::V1_5 => repo.get("model-q4k.gguf")?, - WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?, - WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?, + WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?], + WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?], + WhichModel::V2 => anyhow::bail!("phi-2 is not supported in quantized mode"), + WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?], + WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?], } } else { match args.model { - WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?, - WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?, - WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?, + WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?], + WhichModel::V2 => vec![ + repo.get("model-00001-of-00002.safetensors")?, + repo.get("model-00002-of-00002.safetensors")?, + ], + WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?], + WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?], } } } @@ -282,17 +292,21 @@ fn main() -> Result<()> { let config = match args.model { WhichModel::V1 => Config::v1(), WhichModel::V1_5 => Config::v1_5(), + WhichModel::V2 => Config::v2(), WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(), WhichModel::PhiHermes => Config::phi_hermes_1_3b(), }; let (model, device) = if args.quantized { - let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?; + let vb = 
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?; let model = QMixFormer::new(&config, vb)?; (Model::Quantized(model), Device::Cpu) } else { let device = candle_examples::device(args.cpu)?; - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; - let model = MixFormer::new(&config, vb)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; + let model = match args.model { + WhichModel::V2 => MixFormer::new_v2(&config, vb)?, + _ => MixFormer::new(&config, vb)?, + }; (Model::MixFormer(model), device) }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs index e822ca14..b0e2fb88 100644 --- a/candle-transformers/src/models/mixformer.rs +++ b/candle-transformers/src/models/mixformer.rs @@ -57,6 +57,22 @@ impl Config { } } + pub fn v2() -> Self { + Self { + vocab_size: 51200, + n_positions: 2048, + n_embd: 2560, + n_layer: 32, + n_inner: None, + n_head: 32, + rotary_dim: usize::min(32, 2560 / 32), + activation_function: Activation::Gelu, + layer_norm_epsilon: 1e-5, + tie_word_embeddings: false, + pad_vocab_size_multiple: 64, + } + } + // https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json pub fn puffin_phi_v2() -> Self { Self { @@ -372,6 +388,24 @@ pub struct MixFormerSequentialForCausalLM { } impl MixFormerSequentialForCausalLM { + pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> { + let vb_head = vb.pp("lm_head"); + let vb = vb.pp("transformer"); + let embedding = Embedding::new(cfg, vb.pp("embd"))?; + let mut blocks = Vec::new(); + for i in 0..cfg.n_layer { + let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?; + blocks.push(block) + } + let head = CausalLMHead::new(cfg, vb_head)?; + Ok(Self { + embedding, + blocks, + head, + span: tracing::span!(tracing::Level::TRACE, "mixformer"), + }) + } + pub fn new(cfg: &Config, vb: 
VarBuilder) -> Result<Self> { let vb = vb.pp("layers"); let embedding = Embedding::new(cfg, vb.pp(0))?;