Use subcommands in llama2. (#292)

This commit is contained in:
Laurent Mazare
2023-08-01 05:57:41 +01:00
committed by GitHub
parent 1a07ff8d17
commit fa98ca0c35

View File

@ -6,7 +6,7 @@
extern crate intel_mkl_src;
mod model;
use clap::{Parser, ValueEnum};
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@ -173,36 +173,19 @@ impl TransformerWeights {
}
}
#[derive(ValueEnum, Debug, Clone)]
enum Task {
Inference,
Evaluation,
Training,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The task to be performed, inference, training or evaluation.
#[clap(value_enum, default_value_t = Task::Inference)]
task: Task,
/// Run on CPU rather than on GPU.
#[derive(Parser, Debug, Clone)]
struct InferenceCmd {
/// The temperature used to generate samples.
#[arg(long)]
cpu: bool,
temperature: Option<f64>,
#[arg(long, default_value = "")]
prompt: String,
/// Config file in binary format.
#[arg(long)]
config: Option<String>,
/// Tokenizer config file.
#[arg(long)]
tokenizer: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
@ -211,10 +194,10 @@ struct Args {
/// https://huggingface.co/karpathy/tinyllamas/tree/main
#[arg(long, default_value = "stories15M.bin")]
which_model: String,
}
#[arg(long, default_value = "")]
prompt: String,
#[derive(Parser, Debug, Clone)]
struct EvaluationCmd {
/// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
/// script from llama2.c https://github.com/karpathy/llama2.c
#[arg(long)]
@ -222,10 +205,71 @@ struct Args {
#[arg(long, default_value_t = 32)]
batch_size: usize,
/// Config file in binary format.
#[arg(long)]
config: Option<String>,
#[arg(long, default_value = "karpathy/tinyllamas")]
model_id: String,
/// The model to be used when getting it from the hub. Possible
/// values are 'stories15M.bin', 'stories42M.bin', see more at:
/// https://huggingface.co/karpathy/tinyllamas/tree/main
#[arg(long, default_value = "stories15M.bin")]
which_model: String,
}
#[derive(Subcommand, Debug, Clone)]
enum Task {
Inference(InferenceCmd),
Evaluation(EvaluationCmd),
Training,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The task to be performed, inference, training or evaluation.
#[command(subcommand)]
task: Task,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Tokenizer config file.
#[arg(long)]
tokenizer: Option<String>,
}
impl Args {
fn tokenizer(&self) -> Result<Tokenizer> {
let tokenizer_path = match &self.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
api.get("tokenizer.json")?
}
};
Tokenizer::from_file(tokenizer_path).map_err(E::msg)
}
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
match &args.task {
Task::Inference(cmd) => run_inference(cmd, &args)?,
Task::Evaluation(cmd) => run_eval(cmd, &args)?,
Task::Training => todo!(),
}
Ok(())
}
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
use std::io::BufRead;
let config_path = match &args.config {
Some(config) => std::path::PathBuf::from(config),
None => {
@ -236,32 +280,9 @@ fn main() -> anyhow::Result<()> {
}
};
let tokenizer_path = match &args.tokenizer {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
api.get("tokenizer.json")?
}
};
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
let tokenizer = common_args.tokenizer()?;
match args.task {
Task::Inference => run_inference(tokenizer, &config_path, args)?,
Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
Task::Training => todo!(),
}
Ok(())
}
fn run_eval_file(
path: std::path::PathBuf,
config_path: &std::path::PathBuf,
args: Args,
) -> Result<()> {
use std::io::BufRead;
let device = candle_examples::device(args.cpu)?;
let device = candle_examples::device(common_args.cpu)?;
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
@ -269,50 +290,7 @@ fn run_eval_file(
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let bytes = std::fs::read(path)?;
// Tokens are encoded as u16.
let mut tokens = vec![0u16; bytes.len() / 2];
std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
let tokens: Vec<u32> = tokens.into_iter().map(|u| u as u32).collect();
println!("dataset loaded: {} tokens", tokens.len());
let seq_len = model.config.seq_len;
let mut inputs = vec![];
let mut targets = vec![];
for start_idx in (0..tokens.len()).step_by(seq_len) {
if start_idx + seq_len + 1 > tokens.len() {
break;
}
let tokens = &tokens[start_idx..start_idx + seq_len + 1];
let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
let targets_ = Tensor::new(&tokens[1..], &device)?;
inputs.push(inputs_);
targets.push(targets_);
if inputs.len() >= args.batch_size {
let inp = Tensor::stack(&inputs, 0)?;
let tgt = Tensor::stack(&targets, 0)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
inputs.clear();
targets.clear();
}
}
Ok(())
}
fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
use std::io::BufRead;
let device = candle_examples::device(args.cpu)?;
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let tokens = match args.pretokenized_dir {
let tokens = match &args.pretokenized_dir {
None => {
let api = hf_hub::api::sync::Api::new()?;
let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
@ -365,8 +343,20 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
Ok(())
}
fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
let device = candle_examples::device(args.cpu)?;
fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let config_path = match &args.config {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
println!("loading the model weights from {}", args.model_id);
let api = api.model(args.model_id.clone());
api.get(&args.which_model)?
}
};
let tokenizer = common_args.tokenizer()?;
let device = candle_examples::device(common_args.cpu)?;
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
@ -381,7 +371,7 @@ fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: A
print!("{}", args.prompt);
let mut tokens = tokenizer
.encode(args.prompt, true)
.encode(args.prompt.clone(), true)
.map_err(E::msg)?
.get_ids()
.to_vec();