mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add support for phi-1.0 (#1093)
* Add support for phi-1.0 * Update the readme.
This commit is contained in:
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
@ -110,6 +110,14 @@ impl TextGeneration {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum WhichModel {
|
||||
#[value(name = "1")]
|
||||
V1,
|
||||
#[value(name = "1.5")]
|
||||
V1_5,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -140,11 +148,14 @@ struct Args {
|
||||
#[arg(long, short = 'n', default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "microsoft/phi-1_5")]
|
||||
model_id: String,
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "refs/pr/18")]
|
||||
revision: String,
|
||||
#[arg(long, default_value = "1.5")]
|
||||
model: WhichModel,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
@ -189,18 +200,42 @@ fn main() -> Result<()> {
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id.to_string(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let revision = match args.revision {
|
||||
Some(rev) => rev.to_string(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
"main".to_string()
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 => "refs/pr/2".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/18".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filename = match args.weight_file {
|
||||
Some(weight_file) => std::path::PathBuf::from(weight_file),
|
||||
None => {
|
||||
if args.quantized {
|
||||
api.model("lmz/candle-quantized-phi".to_string())
|
||||
.get("model-q4k.gguf")?
|
||||
match args.model {
|
||||
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
|
||||
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
|
||||
}
|
||||
} else {
|
||||
repo.get("model.safetensors")?
|
||||
}
|
||||
|
Reference in New Issue
Block a user