Adding support for codellama in examples.

Codellama requires bf16 for now (error to convert from bf16 to f16).
Multiprocess demo not functional for it because flash-attn only supports
f16 for now.
This commit is contained in:
Nicolas Patry
2023-08-25 09:56:11 +00:00
parent afc10a3232
commit 4826a4212e
4 changed files with 91 additions and 47 deletions

View File

@ -17,7 +17,7 @@ use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::api::sync::Api;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use std::rc::Rc;
@ -108,6 +108,12 @@ struct Args {
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
dtype: Option<String>,
}
fn main() -> Result<()> {
@ -115,8 +121,13 @@ fn main() -> Result<()> {
let args = Args::parse();
let config = Config::config_7b();
let dtype = DType::F16;
let dtype = match args.dtype.as_deref() {
Some("f16") => DType::F16,
Some("bf16") => DType::BF16,
Some("f32") => DType::F32,
Some(dtype) => panic!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let api = Api::new()?;
@ -124,7 +135,10 @@ fn main() -> Result<()> {
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
println!("loading the model weights from {model_id}");
let api = api.model(model_id);
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let config_filename = api.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
@ -185,7 +199,7 @@ fn main() -> Result<()> {
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let cache = model::Cache::new(&config, &device)?;
let cache = model::Cache::new(dtype, &config, &device)?;
println!("building the model");
let handles = filenames