mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Update the Phi model to use the updated architecture. (#1580)
* Update the Phi model to use the updated architecture. * Add more of the phi model. * Repeat KV + caching. * Apply the rotary embeddings. * Add support for the new phi model in the phi example. * Fix a couple glitches. * Fix a couple more glitches.
This commit is contained in:
@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
MixFormer(MixFormer),
|
||||
Phi(Phi),
|
||||
Quantized(QMixFormer),
|
||||
}
|
||||
|
||||
@ -84,6 +86,7 @@ impl TextGeneration {
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::MixFormer(m) => m.forward(&input)?,
|
||||
Model::Phi(m) => m.forward(&input)?,
|
||||
Model::Quantized(m) => m.forward(&input)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
@ -117,7 +120,7 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||
enum WhichModel {
|
||||
#[value(name = "1")]
|
||||
V1,
|
||||
@ -125,6 +128,9 @@ enum WhichModel {
|
||||
V1_5,
|
||||
#[value(name = "2")]
|
||||
V2,
|
||||
// TODO: Make this the default once it has been battle tested.
|
||||
#[value(name = "2-new")]
|
||||
V2New,
|
||||
PuffinPhiV2,
|
||||
PhiHermes,
|
||||
}
|
||||
@ -230,7 +236,7 @@ fn main() -> Result<()> {
|
||||
match args.model {
|
||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
WhichModel::V2 => "microsoft/phi-2".to_string(),
|
||||
WhichModel::V2 | WhichModel::V2New => "microsoft/phi-2".to_string(),
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
}
|
||||
@ -248,7 +254,9 @@ fn main() -> Result<()> {
|
||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||
WhichModel::V2 => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
|
||||
WhichModel::V2New | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"main".to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -257,7 +265,9 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2New => {
|
||||
repo.get("tokenizer.json")?
|
||||
}
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||
}
|
||||
@ -270,14 +280,14 @@ fn main() -> Result<()> {
|
||||
match args.model {
|
||||
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
|
||||
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
|
||||
WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::V2 | WhichModel::V2New => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::V2 => candle_examples::hub_load_safetensors(
|
||||
WhichModel::V2 | WhichModel::V2New => candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
)?,
|
||||
@ -291,25 +301,35 @@ fn main() -> Result<()> {
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = match args.model {
|
||||
let config = || match args.model {
|
||||
WhichModel::V1 => Config::v1(),
|
||||
WhichModel::V1_5 => Config::v1_5(),
|
||||
WhichModel::V2 => Config::v2(),
|
||||
WhichModel::V2 | WhichModel::V2New => Config::v2(),
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||
};
|
||||
let (model, device) = if args.quantized {
|
||||
let (model, device) = if args.model == WhichModel::V2New {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: PhiConfig = serde_json::from_str(&config)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let phi = Phi::new(&config, vb)?;
|
||||
(Model::Phi(phi), device)
|
||||
} else if args.quantized {
|
||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
|
||||
let config = config();
|
||||
let model = match args.model {
|
||||
WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
|
||||
WhichModel::V2 | WhichModel::V2New => QMixFormer::new_v2(&config, vb)?,
|
||||
_ => QMixFormer::new(&config, vb)?,
|
||||
};
|
||||
(Model::Quantized(model), Device::Cpu)
|
||||
} else {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let config = config();
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = match args.model {
|
||||
WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
|
||||
WhichModel::V2 | WhichModel::V2New => MixFormer::new_v2(&config, vb)?,
|
||||
_ => MixFormer::new(&config, vb)?,
|
||||
};
|
||||
(Model::MixFormer(model), device)
|
||||
@ -392,6 +412,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Phi(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Quantized(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
|
Reference in New Issue
Block a user