Mirror of https://github.com/huggingface/candle.git
Use subcommands in llama2. (#292)
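For orientation, here is a minimal self-contained sketch (editorial, not part of the commit) of the clap subcommand pattern the diff below moves to: flags shared by every task stay on Args, per-task flags move into structs such as InferenceCmd, and a Task subcommand enum selects the handler. The field lists are trimmed here; the full definitions are in the diff itself.

use clap::{Parser, Subcommand};

/// Flags that only make sense for the `inference` task.
#[derive(Parser, Debug, Clone)]
struct InferenceCmd {
    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    #[arg(long, default_value = "")]
    prompt: String,
}

/// One variant per task; each variant carries its own flag struct.
#[derive(Subcommand, Debug, Clone)]
enum Task {
    Inference(InferenceCmd),
    Training,
}

/// Flags shared by every task stay on the top-level parser.
#[derive(Parser, Debug)]
struct Args {
    #[command(subcommand)]
    task: Task,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
}

fn main() {
    // e.g. `<binary> --cpu inference --prompt "Once upon a time"`
    let args = Args::parse();
    match &args.task {
        Task::Inference(cmd) => println!("inference, prompt = {:?}", cmd.prompt),
        Task::Training => println!("training is not implemented"),
    }
}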
@@ -6,7 +6,7 @@
 extern crate intel_mkl_src;
 
 mod model;
-use clap::{Parser, ValueEnum};
+use clap::{Parser, Subcommand};
 
 use anyhow::{Error as E, Result};
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
@@ -173,36 +173,19 @@ impl TransformerWeights {
     }
 }
 
-#[derive(ValueEnum, Debug, Clone)]
-enum Task {
-    Inference,
-    Evaluation,
-    Training,
-}
-
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    /// The task to be performed, inference, training or evaluation.
-    #[clap(value_enum, default_value_t = Task::Inference)]
-    task: Task,
-
-    /// Run on CPU rather than on GPU.
+#[derive(Parser, Debug, Clone)]
+struct InferenceCmd {
+    /// The temperature used to generate samples.
     #[arg(long)]
-    cpu: bool,
+    temperature: Option<f64>,
+
+    #[arg(long, default_value = "")]
+    prompt: String,
 
     /// Config file in binary format.
     #[arg(long)]
     config: Option<String>,
 
-    /// Tokenizer config file.
-    #[arg(long)]
-    tokenizer: Option<String>,
-
-    /// The temperature used to generate samples.
-    #[arg(long)]
-    temperature: Option<f64>,
-
     #[arg(long, default_value = "karpathy/tinyllamas")]
     model_id: String,
 
@@ -211,10 +194,10 @@ struct Args {
     /// https://huggingface.co/karpathy/tinyllamas/tree/main
     #[arg(long, default_value = "stories15M.bin")]
     which_model: String,
+}
 
-    #[arg(long, default_value = "")]
-    prompt: String,
-
+#[derive(Parser, Debug, Clone)]
+struct EvaluationCmd {
     /// A directory with the pre-tokenized dataset in the format generated by the tinystories.py
     /// script from llama2.c https://github.com/karpathy/llama2.c
     #[arg(long)]
@@ -222,10 +205,71 @@ struct Args {
 
     #[arg(long, default_value_t = 32)]
     batch_size: usize,
+
+    /// Config file in binary format.
+    #[arg(long)]
+    config: Option<String>,
+
+    #[arg(long, default_value = "karpathy/tinyllamas")]
+    model_id: String,
+
+    /// The model to be used when getting it from the hub. Possible
+    /// values are 'stories15M.bin', 'stories42M.bin', see more at:
+    /// https://huggingface.co/karpathy/tinyllamas/tree/main
+    #[arg(long, default_value = "stories15M.bin")]
+    which_model: String,
+}
+
+#[derive(Subcommand, Debug, Clone)]
+enum Task {
+    Inference(InferenceCmd),
+    Evaluation(EvaluationCmd),
+    Training,
+}
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// The task to be performed, inference, training or evaluation.
+    #[command(subcommand)]
+    task: Task,
+
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Tokenizer config file.
+    #[arg(long)]
+    tokenizer: Option<String>,
+}
+
+impl Args {
+    fn tokenizer(&self) -> Result<Tokenizer> {
+        let tokenizer_path = match &self.tokenizer {
+            Some(config) => std::path::PathBuf::from(config),
+            None => {
+                let api = hf_hub::api::sync::Api::new()?;
+                let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
+                api.get("tokenizer.json")?
+            }
+        };
+        Tokenizer::from_file(tokenizer_path).map_err(E::msg)
+    }
 }
 
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
+    match &args.task {
+        Task::Inference(cmd) => run_inference(cmd, &args)?,
+        Task::Evaluation(cmd) => run_eval(cmd, &args)?,
+        Task::Training => todo!(),
+    }
+    Ok(())
+}
+
+fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
+    use std::io::BufRead;
+
     let config_path = match &args.config {
         Some(config) => std::path::PathBuf::from(config),
         None => {
@@ -236,32 +280,9 @@ fn main() -> anyhow::Result<()> {
         }
     };
 
-    let tokenizer_path = match &args.tokenizer {
-        Some(config) => std::path::PathBuf::from(config),
-        None => {
-            let api = hf_hub::api::sync::Api::new()?;
-            let api = api.model("hf-internal-testing/llama-tokenizer".to_string());
-            api.get("tokenizer.json")?
-        }
-    };
-    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
+    let tokenizer = common_args.tokenizer()?;
 
-    match args.task {
-        Task::Inference => run_inference(tokenizer, &config_path, args)?,
-        Task::Evaluation => run_eval(tokenizer, &config_path, args)?,
-        Task::Training => todo!(),
-    }
-    Ok(())
-}
-
-fn run_eval_file(
-    path: std::path::PathBuf,
-    config_path: &std::path::PathBuf,
-    args: Args,
-) -> Result<()> {
-    use std::io::BufRead;
-
-    let device = candle_examples::device(args.cpu)?;
+    let device = candle_examples::device(common_args.cpu)?;
     let mut file = std::fs::File::open(config_path)?;
     let config = Config::from_reader(&mut file)?;
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
@@ -269,50 +290,7 @@ fn run_eval_file(
     let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
     let model = Llama::load(vb, &cache, config)?;
 
-    let bytes = std::fs::read(path)?;
-    // Tokens are encoded as u16.
-    let mut tokens = vec![0u16; bytes.len() / 2];
-    std::io::Cursor::new(bytes).read_u16_into::<LittleEndian>(&mut tokens)?;
-    let tokens: Vec<u32> = tokens.into_iter().map(|u| u as u32).collect();
-    println!("dataset loaded: {} tokens", tokens.len());
-
-    let seq_len = model.config.seq_len;
-    let mut inputs = vec![];
-    let mut targets = vec![];
-    for start_idx in (0..tokens.len()).step_by(seq_len) {
-        if start_idx + seq_len + 1 > tokens.len() {
-            break;
-        }
-        let tokens = &tokens[start_idx..start_idx + seq_len + 1];
-        let inputs_ = Tensor::new(&tokens[..seq_len], &device)?;
-        let targets_ = Tensor::new(&tokens[1..], &device)?;
-        inputs.push(inputs_);
-        targets.push(targets_);
-        if inputs.len() >= args.batch_size {
-            let inp = Tensor::stack(&inputs, 0)?;
-            let tgt = Tensor::stack(&targets, 0)?;
-            let logits = model.forward(&inp, 0)?;
-            let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
-            println!("{}", loss.to_vec0::<f32>()?);
-            inputs.clear();
-            targets.clear();
-        }
-    }
-    Ok(())
-}
-
-fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
-    use std::io::BufRead;
-
-    let device = candle_examples::device(args.cpu)?;
-    let mut file = std::fs::File::open(config_path)?;
-    let config = Config::from_reader(&mut file)?;
-    let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
-    let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
-
-    let tokens = match args.pretokenized_dir {
+    let tokens = match &args.pretokenized_dir {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
             let model_id = "roneneldan/TinyStories"; // TODO: Make this configurable.
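For context, the run_eval_file body removed in the hunk above evaluates the model by next-token prediction: it walks the pre-tokenized stream in seq_len-sized windows whose targets are the inputs shifted by one token, stacks the windows into batches, and prints a cross-entropy loss per batch. A minimal sketch of just the windowing step, independent of candle (the helper name is illustrative, not from the commit):

/// Slice a token stream into (inputs, targets) pairs for next-token
/// prediction: each window is seq_len + 1 tokens long, the inputs are the
/// first seq_len tokens, and the targets are the same window shifted by one.
fn next_token_pairs(tokens: &[u32], seq_len: usize) -> Vec<(Vec<u32>, Vec<u32>)> {
    let mut pairs = Vec::new();
    for start_idx in (0..tokens.len()).step_by(seq_len) {
        // Stop once there is no room for seq_len inputs plus one extra target.
        if start_idx + seq_len + 1 > tokens.len() {
            break;
        }
        let window = &tokens[start_idx..start_idx + seq_len + 1];
        pairs.push((window[..seq_len].to_vec(), window[1..].to_vec()));
    }
    pairs
}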
@@ -365,8 +343,20 @@ fn run_eval(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
     Ok(())
 }
 
-fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args) -> Result<()> {
-    let device = candle_examples::device(args.cpu)?;
+fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
+    let config_path = match &args.config {
+        Some(config) => std::path::PathBuf::from(config),
+        None => {
+            let api = hf_hub::api::sync::Api::new()?;
+            println!("loading the model weights from {}", args.model_id);
+            let api = api.model(args.model_id.clone());
+            api.get(&args.which_model)?
+        }
+    };
+
+    let tokenizer = common_args.tokenizer()?;
+
+    let device = candle_examples::device(common_args.cpu)?;
 
     let mut file = std::fs::File::open(config_path)?;
     let config = Config::from_reader(&mut file)?;
@@ -381,7 +371,7 @@ fn run_inference(tokenizer: Tokenizer, config_path: &std::path::PathBuf, args: Args)
 
     print!("{}", args.prompt);
     let mut tokens = tokenizer
-        .encode(args.prompt, true)
+        .encode(args.prompt.clone(), true)
         .map_err(E::msg)?
         .get_ids()
        .to_vec();