mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Adding support for codellama in examples.
Codellama requires bf16 for now (error to convert from bf16 to f16). Multiprocess demo not functional for it because flash-attn only supports f16 for now.
This commit is contained in:
@ -18,7 +18,7 @@ use clap::Parser;
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
mod model;
|
||||
@ -59,9 +59,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// Use f32 computations rather than f16.
|
||||
/// Use different dtype than f16
|
||||
#[arg(long)]
|
||||
use_f32: bool,
|
||||
dtype: Option<String>,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
@ -70,6 +70,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
v1: bool,
|
||||
|
||||
@ -97,7 +100,13 @@ fn main() -> Result<()> {
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||
let dtype = match args.dtype.as_deref() {
|
||||
Some("f16") => DType::F16,
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => panic!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, cache) = match args.npy {
|
||||
Some(filename) => {
|
||||
let config = if args.v1 {
|
||||
@ -120,7 +129,8 @@ fn main() -> Result<()> {
|
||||
}
|
||||
});
|
||||
println!("loading the model weights from {model_id}");
|
||||
let api = api.model(model_id);
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
|
||||
let tokenizer_filename = match &args.local_weights {
|
||||
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||
|
Reference in New Issue
Block a user