mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Adding support for codellama in examples.
Codellama requires bf16 for now (error to convert from bf16 to f16). Multiprocess demo not functional for it because flash-attn only supports f16 for now.
This commit is contained in:
@ -17,7 +17,7 @@ use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use cudarc::driver::safe::CudaDevice;
|
||||
use cudarc::nccl::safe::{Comm, Id};
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
use std::rc::Rc;
|
||||
|
||||
@ -108,6 +108,12 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -115,8 +121,13 @@ fn main() -> Result<()> {
|
||||
|
||||
let args = Args::parse();
|
||||
|
||||
let config = Config::config_7b();
|
||||
let dtype = DType::F16;
|
||||
let dtype = match args.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => panic!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
|
||||
let api = Api::new()?;
|
||||
|
||||
@ -124,7 +135,10 @@ fn main() -> Result<()> {
|
||||
.model_id
|
||||
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
|
||||
println!("loading the model weights from {model_id}");
|
||||
let api = api.model(model_id);
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
let config_filename = api.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let mut filenames = vec![];
|
||||
for rfilename in [
|
||||
@ -185,7 +199,7 @@ fn main() -> Result<()> {
|
||||
println!("Rank {rank:?} spawned");
|
||||
|
||||
let device = Device::new_cuda(i)?;
|
||||
let cache = model::Cache::new(&config, &device)?;
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
let handles = filenames
|
||||
|
Reference in New Issue
Block a user