Mirror of https://github.com/huggingface/candle.git

Use subcommands in llama2. (#292)
@@ -6,7 +6,7 @@
 extern crate intel_mkl_src;
 
 mod model;
-use clap::{Parser, ValueEnum};
+use clap::{Parser, Subcommand};
 
 use anyhow::{Error as E, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -173,36 +173,19 @@ impl TransformerWeights {
     }
 }
 
-#[derive(ValueEnum, Debug, Clone)]
-enum Task {
-    Inference,
-    Evaluation,
-    Training,
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// The task to be performed, inference, training or evaluation.
-    #[clap(value_enum, default_value_t = Task::Inference)]
-    task: Task,
-
-    /// Run on CPU rather than on GPU.
+#[derive(Parser, Debug, Clone)]
+struct InferenceCmd {
+    /// The temperature used to generate samples.
     #[arg(long)]
-    cpu: bool,
+    temperature: Option<f64>,
+
+    #[arg(long, default_value = "")]
+    prompt: String,
 
     /// Config file in binary format.
     #[arg(long)]
     config: Option<String>,
 
-    /// Tokenizer config file.
-    #[arg(long)]
-    tokenizer: Option<String>,
-
-    /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
-
     #[arg(long, default_value = "karpathy/tinyllamas")]
     model_id: String,
 
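Note: the flat Args being removed accepted every flag for every task, so inference-only options such as --temperature were also accepted when training or evaluating. Splitting the options into per-task structs like InferenceCmd lets clap scope and document each flag under its task. A minimal standalone sketch of such a struct parsing on its own (not the commit's code):

    use clap::Parser;

    #[derive(Parser, Debug, Clone)]
    struct InferenceCmd {
        /// The temperature used to generate samples.
        #[arg(long)]
        temperature: Option<f64>,

        #[arg(long, default_value = "")]
        prompt: String,
    }

    fn main() {
        // Parses e.g. `prog --temperature 0.8 --prompt "Once upon"`.
        let cmd = InferenceCmd::parse();
        println!("{cmd:?}");
    }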
@@ -211,10 +194,10 @@ struct Args {
     /// https://huggingface.co/karpathy/tinyllamas/tree/main
     #[arg(long, default_value = "stories15M.bin")]
     which_model: String,
+}
 
-    #[arg(long, default_value = "")]
-    prompt: String,
-
+#[derive(Parser, Debug, Clone)]
+struct EvaluationCmd {
     /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
     /// script from llama2.c https://github.com/karpathy/llama2.c
     #[arg(long)]
@@ -222,10 +205,71 @@ struct Args {
 
     #[arg(long, default_value_t = 32)]
     batch_size: usize,
+
+    /// Config file in binary format.
+    #[arg(long)]
+    config: Option<String>,
+
+    #[arg(long, default_value = "karpathy/tinyllamas")]
+    model_id: String,
+
+    /// The model to be used when getting it from the hub. Possible
+    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
+    /// https://huggingface.co/karpathy/tinyllamas/tree/main
+    #[arg(long, default_value = "stories15M.bin")]
+    which_model: String,
 }
 
+#[derive(Subcommand, Debug, Clone)]
+enum Task {
+    Inference(InferenceCmd),
+    Evaluation(EvaluationCmd),
+    Training,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// The task to be performed, inference, training or evaluation.
+    #[command(subcommand)]
+    task: Task,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Tokenizer config file.
+    #[arg(long)]
+    tokenizer: Option<String>,
+}
+
+impl Args {
+    fn tokenizer(&self) -> Result<Tokenizer> {
+        let tokenizer_path = match &self.tokenizer {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let api = hf_hub::api::sync::Api::new()?;
+                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+                api.get("tokenizer.json")?
+            }
+        };
+        Tokenizer::from_file(tokenizer_path).map_err(E::msg)
+    }
+}
+
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
+    match &args.task {
+        Task::Inference(cmd) => run_inference(cmd, &args)?,
+        Task::Evaluation(cmd) => run_eval(cmd, &args)?,
+        Task::Training => todo!(),
+    }
+    Ok(())
+}
+
+fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
+    use std::io::BufRead;
+
     let config_path = match &args.config {
         Some(config) => std::path::PathBuf::from(config),
         None => {
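Note: this hunk is the heart of the rework: the per-task structs become payloads of a clap Subcommand enum, the shared flags (cpu, tokenizer) stay on the top-level parser, and main dispatches on the parsed variant. A self-contained sketch of the mechanism, trimmed down (not the commit's code):

    use clap::{Parser, Subcommand};

    #[derive(Parser, Debug, Clone)]
    struct InferenceCmd {
        #[arg(long)]
        temperature: Option<f64>,
    }

    #[derive(Subcommand, Debug, Clone)]
    enum Task {
        Inference(InferenceCmd),
        Training,
    }

    #[derive(Parser, Debug)]
    struct Args {
        #[command(subcommand)]
        task: Task,

        /// Shared flag, accepted before any subcommand.
        #[arg(long)]
        cpu: bool,
    }

    fn main() {
        // Accepts e.g. `prog --cpu inference --temperature 0.8`;
        // `prog training --temperature 0.8` is rejected by clap.
        let args = Args::parse();
        match &args.task {
            Task::Inference(cmd) => println!("inference, temperature {:?}", cmd.temperature),
            Task::Training => println!("training"),
        }
    }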
@@ -236,32 +280,9 @@ fn main() -> anyhow::Result<()> {
         }
     };
 
-    let tokenizer_path = match &args.tokenizer {
-        Some(config) => std::path::PathBuf::from(config),
-        None => {
-            let api = hf_hub::api::sync::Api::new()?;
-            let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
-            api.get("tokenizer.json")?
-        }
-    };
-    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
+    let tokenizer = common_args.tokenizer()?;
 
-    match args.task {
-        Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
-        Task::Training => todo!(),
-    }
-    Ok(())
-}
-
-fn run_eval_file(
-    path: std::path::PathBuf,
-    config_path: &std::path::PathBuf,
-    args: Args,
-) -> Result<()> {
-    use std::io::BufRead;
-
-    let device = candle_examples::device(args.cpu)?;
+    let device = candle_examples::device(common_args.cpu)?;
     let mut file = std::fs::File::open(config_path)?;
     let config = Config::from_reader(&mut file)?;
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
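Note: the tokenizer resolution deleted from main here is the code that moved into Args::tokenizer in the previous hunk. The underlying fetch-or-local pattern, as a standalone sketch against the same hf-hub sync API (the helper name is illustrative):

    use std::path::PathBuf;

    // Use an explicit path when one was given on the command line;
    // otherwise download the file from the Hub and return the cached copy.
    fn local_or_hub(explicit: &Option<String>, repo: &str, filename: &str) -> anyhow::Result<PathBuf> {
        match explicit {
            Some(path) => Ok(PathBuf::from(path)),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                Ok(api.model(repo.to_string()).get(filename)?)
            }
        }
    }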
@@ -269,50 +290,7 @@ fn run_eval_file(
     let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
 
-    let bytes = std::fs::read(path)?;
-    // Tokens are encoded as u16.
-    let mut tokens = vec![0u16; bytes.len() / 2];
-    std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
-    let tokens: Vec<u32> = tokens.into_iter().map(|u| u as u32).collect();
-    println!("dataset loaded: {} tokens", tokens.len());
-
-    let seq_len = model.config.seq_len;
-    let mut inputs = vec![];
-    let mut targets = vec![];
-    for start_idx in (0..tokens.len()).step_by(seq_len) {
-        if start_idx + seq_len + 1 > tokens.len() {
-            break;
-        }
-        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
-        let targets_ = Tensor::new(&tokens[1..], &device)?;
-        inputs.push(inputs_);
-        targets.push(targets_);
-        if inputs.len() >= args.batch_size {
-            let inp = Tensor::stack(&inputs, 0)?;
-            let tgt = Tensor::stack(&targets, 0)?;
-            let logits = model.forward(&inp, 0)?;
-            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-            println!("{}", loss.to_vec0::<f32>()?);
-            inputs.clear();
-            targets.clear();
-        }
-    }
-    Ok(())
-}
-
-fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
-    use std::io::BufRead;
-
-    let device = candle_examples::device(args.cpu)?;
-    let mut file = std::fs::File::open(config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
-
-    let tokens = match args.pretokenized_dir {
+    let tokens = match &args.pretokenized_dir {
         None => {
            let api = hf_hub::api::sync::Api::new()?;
            let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
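Note: run_eval_file is deleted because its body was folded into the new run_eval. The data format it handled is worth spelling out: a llama2.c pretokenized dataset is a flat array of little-endian u16 token ids, and evaluation slices it into windows of seq_len + 1 so that the inputs are tokens[..seq_len] and the shifted next-token targets are tokens[1..]. The loading step as a standalone sketch (mirrors the removed code):

    use byteorder::{LittleEndian, ReadBytesExt};

    fn read_pretokenized(path: &std::path::Path) -> anyhow::Result<Vec<u32>> {
        let bytes = std::fs::read(path)?;
        // Tokens are encoded as little-endian u16 values.
        let mut tokens = vec![0u16; bytes.len() / 2];
        std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
        // Widen to u32, the index type used when building the tensors.
        Ok(tokens.into_iter().map(u32::from).collect())
    }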
@@ -365,8 +343,20 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
     Ok(())
 }
 
-fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
-    let device = candle_examples::device(args.cpu)?;
+fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
+    let config_path = match &args.config {
+        Some(config) => std::path::PathBuf::from(config),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            println!("loading the model weights from {}", args.model_id);
+            let api = api.model(args.model_id.clone());
+            api.get(&args.which_model)?
+        }
+    };
+
+    let tokenizer = common_args.tokenizer()?;
+
+    let device = candle_examples::device(common_args.cpu)?;
 
     let mut file = std::fs::File::open(config_path)?;
     let config = Config::from_reader(&mut file)?;
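Note: run_inference now resolves the weights itself from config, model_id and which_model on InferenceCmd, so a run only needs the subcommand and its flags. A plausible invocation, assuming the usual cargo example layout (the example name is not shown in this diff):

    cargo run --example llama2-c --release -- inference --prompt "Once upon a time" --temperature 0.8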
@@ -381,7 +371,7 @@ fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
 
     print!("{}", args.prompt);
     let mut tokens = tokenizer
-        .encode(args.prompt, true)
+        .encode(args.prompt.clone(), true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
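Note: the added .clone() follows directly from the signature change: args is now the borrowed &InferenceCmd and Tokenizer::encode consumes its input, so the prompt String can no longer be moved out of the struct. The same constraint in a standalone sketch:

    struct Cmd {
        prompt: String,
    }

    fn encode_prompt(tokenizer: &tokenizers::Tokenizer, cmd: &Cmd) -> anyhow::Result<Vec<u32>> {
        let tokens = tokenizer
            // prompt sits behind a shared reference, so clone it
            // (or pass cmd.prompt.as_str()) instead of moving it.
            .encode(cmd.prompt.clone(), true)
            .map_err(anyhow::Error::msg)?
            .get_ids()
            .to_vec();
        Ok(tokens)
    }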