Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Add the phi-3 model. (#2120)
* Add the phi-3 model.
* Faster rope.
* Bugfix.
* Fix the detokenization.
@@ -7,11 +7,13 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::{Parser, ValueEnum};
 
+use candle_examples::token_output_stream::TokenOutputStream;
 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
 use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
+use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
 
-use candle::{DType, Device, Tensor};
+use candle::{DType, Device, IndexOp, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
@@ -20,13 +22,14 @@ use tokenizers::Tokenizer;
 enum Model {
     MixFormer(MixFormer),
     Phi(Phi),
+    Phi3(Phi3),
     Quantized(QMixFormer),
 }
 
 struct TextGeneration {
     model: Model,
     device: Device,
-    tokenizer: Tokenizer,
+    tokenizer: TokenOutputStream,
     logits_processor: LogitsProcessor,
     repeat_penalty: f32,
     repeat_last_n: usize,
@@ -49,7 +52,7 @@ impl TextGeneration {
         let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
-            tokenizer,
+            tokenizer: TokenOutputStream::new(tokenizer),
             logits_processor,
             repeat_penalty,
             repeat_last_n,
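Note on the detokenization fix: decoding each sampled id on its own (as the old loop below does) can print U+FFFD replacement glyphs for characters that span several tokens, which is why the tokenizer is now wrapped in a `TokenOutputStream`. A minimal sketch of the idea — an assumption about the approach, not the actual `candle_examples::token_output_stream` implementation:

```rust
// Sketch: keep all generated ids, re-decode the whole sequence, and emit
// only the newly finalized suffix, holding back incomplete characters.
use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

struct StreamDecoder {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    printed: usize, // bytes of decoded text already emitted
}

impl StreamDecoder {
    fn next_token(&mut self, id: u32) -> Result<Option<String>> {
        self.tokens.push(id);
        let text = self.tokenizer.decode(&self.tokens, true).map_err(E::msg)?;
        // Hold output back while the tail still decodes to U+FFFD,
        // i.e. an incomplete multi-token character.
        if text.len() > self.printed && !text.ends_with('\u{fffd}') {
            let new = text[self.printed..].to_string();
            self.printed = text.len();
            return Ok(Some(new));
        }
        Ok(None)
    }
}
```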
@@ -61,7 +64,11 @@ impl TextGeneration {
     fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
         use std::io::Write;
         println!("starting the inference loop");
-        let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
+        let tokens = self
+            .tokenizer
+            .tokenizer()
+            .encode(prompt, true)
+            .map_err(E::msg)?;
         if tokens.is_empty() {
             anyhow::bail!("Empty prompts are not supported in the phi model.")
         }
@@ -73,13 +80,14 @@ impl TextGeneration {
         }
         let mut tokens = tokens.get_ids().to_vec();
         let mut generated_tokens = 0usize;
-        let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
-            Some(token) => *token,
+        let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
+            Some(token) => token,
             None => anyhow::bail!("cannot find the endoftext token"),
         };
         print!("{prompt}");
         std::io::stdout().flush()?;
         let start_gen = std::time::Instant::now();
+        let mut pos = 0;
         for index in 0..sample_len {
             let context_size = if index > 0 { 1 } else { tokens.len() };
             let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
@@ -88,6 +96,7 @@ impl TextGeneration {
                 Model::MixFormer(m) => m.forward(&input)?,
                 Model::Phi(m) => m.forward(&input)?,
                 Model::Quantized(m) => m.forward(&input)?,
+                Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
             };
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
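The Phi3 arm is the only one that indexes its logits. A hedged shape sketch, assuming the phi3 forward pass returns logits for the final position only, shaped (batch, 1, vocab_size):

```rust
// `i((.., 0, ..))` selects the singleton sequence dimension, yielding
// (batch, vocab_size) — the same shape the other arms produce directly.
use candle::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let (batch, vocab) = (1, 32064); // 32064 is illustrative for phi-3
    let logits = Tensor::zeros((batch, 1, vocab), DType::F32, &Device::Cpu)?;
    let last = logits.i((.., 0, ..))?; // -> (batch, vocab_size)
    assert_eq!(last.dims(), &[batch, vocab]);
    Ok(())
}
```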
@@ -107,9 +116,11 @@ impl TextGeneration {
             if next_token == eos_token {
                 break;
             }
-            let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
-            print!("{token}");
-            std::io::stdout().flush()?;
+            if let Some(t) = self.tokenizer.next_token(next_token)? {
+                print!("{t}");
+                std::io::stdout().flush()?;
+            }
+            pos += context_size;
         }
         let dt = start_gen.elapsed();
         println!(
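The new `pos` counter records how many positions the KV cache already holds: the first step feeds the whole prompt at offset 0, each later step feeds one freshly sampled token at the current offset. A small sketch of that bookkeeping (the model call itself is elided):

```rust
// `pos` is the number of positions already materialized in the KV cache,
// so each forward call only processes the tokens that follow it.
fn offsets(prompt_len: usize, sample_len: usize) -> Vec<(usize, usize)> {
    let mut schedule = Vec::new(); // (pos, context_size) per step
    let mut pos = 0;
    for index in 0..sample_len {
        let context_size = if index > 0 { 1 } else { prompt_len };
        schedule.push((pos, context_size));
        pos += context_size;
    }
    schedule
}

// offsets(4, 3) == [(0, 4), (4, 1), (5, 1)]: the full prompt first,
// then one newly sampled token at a time.
```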
@@ -128,6 +139,8 @@ enum WhichModel {
     V1_5,
     #[value(name = "2")]
     V2,
+    #[value(name = "3")]
+    V3,
     #[value(name = "2-old")]
     V2Old,
     PuffinPhiV2,
@@ -236,6 +249,7 @@ fn main() -> Result<()> {
             WhichModel::V1 => "microsoft/phi-1".to_string(),
             WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
             WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
+            WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 "lmz/candle-quantized-phi".to_string()
             }
@@ -253,9 +267,10 @@ fn main() -> Result<()> {
                 WhichModel::V1 => "refs/pr/8".to_string(),
                 WhichModel::V1_5 => "refs/pr/73".to_string(),
                 WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
-                WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
-                    "main".to_string()
-                }
+                WhichModel::V2
+                | WhichModel::V3
+                | WhichModel::PuffinPhiV2
+                | WhichModel::PhiHermes => "main".to_string(),
             }
         }
     }
@@ -264,9 +279,11 @@ fn main() -> Result<()> {
     let tokenizer_filename = match args.tokenizer {
         Some(file) => std::path::PathBuf::from(file),
         None => match args.model {
-            WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
-                repo.get("tokenizer.json")?
-            }
+            WhichModel::V1
+            | WhichModel::V1_5
+            | WhichModel::V2
+            | WhichModel::V2Old
+            | WhichModel::V3 => repo.get("tokenizer.json")?,
             WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
                 repo.get("tokenizer-puffin-phi-v2.json")?
             }
@@ -282,14 +299,19 @@ fn main() -> Result<()> {
             WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
             WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
             WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
+            WhichModel::V3 => anyhow::bail!(
+                "use the quantized or quantized-phi examples for quantized phi-v3"
+            ),
         }
     } else {
         match args.model {
             WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
-            WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
-                &repo,
-                "model.safetensors.index.json",
-            )?,
+            WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
+                candle_examples::hub_load_safetensors(
+                    &repo,
+                    "model.safetensors.index.json",
+                )?
+            }
             WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
             WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
         }
@@ -306,6 +328,9 @@ fn main() -> Result<()> {
         WhichModel::V2 | WhichModel::V2Old => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
+        WhichModel::V3 => {
+            panic!("use the quantized or quantized-phi examples for quantized phi-v3")
+        }
     };
     let device = candle_examples::device(args.cpu)?;
     let model = if args.quantized {
@@ -320,7 +345,12 @@ fn main() -> Result<()> {
         };
         Model::Quantized(model)
     } else {
-        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
+        let dtype = if args.model == WhichModel::V3 && device.is_cuda() {
+            DType::BF16
+        } else {
+            DType::F32
+        };
+        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
         match args.model {
             WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
                 let config_filename = repo.get("config.json")?;
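The var-builder dtype is now chosen per model and device rather than hardcoded to F32: phi-3 weights load as BF16 on CUDA, halving the memory traffic, while the CPU path and the older models stay in F32. A minimal sketch of the selection, assuming only a candle `Device` is at hand:

```rust
use candle::{DType, Device};

// BF16 for phi-3 on CUDA, F32 everywhere else.
fn weight_dtype(is_phi3: bool, device: &Device) -> DType {
    if is_phi3 && device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    }
}
```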
@@ -329,6 +359,13 @@ fn main() -> Result<()> {
                 let phi = Phi::new(&config, vb)?;
                 Model::Phi(phi)
             }
+            WhichModel::V3 => {
+                let config_filename = repo.get("config.json")?;
+                let config = std::fs::read_to_string(config_filename)?;
+                let config: Phi3Config = serde_json::from_str(&config)?;
+                let phi3 = Phi3::new(&config, vb)?;
+                Model::Phi3(phi3)
+            }
             WhichModel::V2Old => {
                 let config = config();
                 Model::MixFormer(MixFormer::new_v2(&config, vb)?)
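Unlike the V2Old path with its hardcoded `config()` constructor, phi-3 deserializes its hyperparameters straight from the Hub's config.json via serde. A hedged sketch of that pattern — the struct and fields below are an illustrative subset, not the real `Phi3Config` schema:

```rust
use serde::Deserialize;

// Assumed: Phi3Config derives serde::Deserialize; these fields appear in the
// published Phi-3 config.json, but the real struct carries many more.
#[derive(Debug, Deserialize)]
struct Phi3ConfigSketch {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
}

fn load(path: &std::path::Path) -> anyhow::Result<Phi3ConfigSketch> {
    let raw = std::fs::read_to_string(path)?;
    Ok(serde_json::from_str(&raw)?)
}
```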
@@ -421,6 +458,10 @@ fn mmlu<P: AsRef<std::path::Path>>(
                 m.clear_kv_cache();
                 m.forward(&input)?
             }
+            Model::Phi3(m) => {
+                m.clear_kv_cache();
+                m.forward(&input, 0)?
+            }
             Model::Quantized(m) => {
                 m.clear_kv_cache();
                 m.forward(&input)?