Adding support for codellama in examples.

Codellama requires bf16 for now (error to convert from bf16 to f16).
Multiprocess demo not functional for it because flash-attn only supports
f16 for now.
This commit is contained in:
Nicolas Patry
2023-08-25 09:56:11 +00:00
parent afc10a3232
commit 4826a4212e
4 changed files with 91 additions and 47 deletions

View File

@ -18,7 +18,7 @@ use clap::Parser;
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::api::sync::Api;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
mod model;
@ -59,9 +59,9 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
/// Use f32 computations rather than f16.
/// Use different dtype than f16
#[arg(long)]
use_f32: bool,
dtype: Option<String>,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
@ -70,6 +70,9 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
v1: bool,
@ -97,7 +100,13 @@ fn main() -> Result<()> {
};
let device = candle_examples::device(args.cpu)?;
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => panic!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, cache) = match args.npy {
Some(filename) => {
let config = if args.v1 {
@ -120,7 +129,8 @@ fn main() -> Result<()> {
}
});
println!("loading the model weights from {model_id}");
let api = api.model(model_id);
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match &args.local_weights {
Some(path) => (path.to_owned() + "tokenizer.json").into(),