Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 19:47:12 +00:00)
Making multiprocess require flash-attn.
@@ -20,7 +20,7 @@ use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 use std::io::Write;
 use std::rc::Rc;
 
@@ -83,10 +83,6 @@ Upon my target three fair-shining suns.
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
     #[arg(long)]
     num_shards: usize,
 
@@ -113,15 +109,8 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
-    /// Use f32 computations rather than f16.
-    #[arg(long)]
-    use_f32: bool,
-
     #[arg(long)]
     model_id: Option<String>,
-
-    #[arg(long)]
-    v2: bool,
 }
 
 fn main() -> Result<()> {
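After these removals, the flags that remain in this region are ordinary clap derive arguments. As a point of reference, here is a minimal self-contained sketch of that pattern using only the fields visible in these hunks (the real Args struct has further fields that this diff does not show):

// Sketch of the clap derive pattern used by the example (clap with the
// "derive" feature). With #[arg(long)], `num_shards: usize` becomes a
// required --num-shards <NUM_SHARDS> flag, and the Option<String> fields
// become optional --prompt / --model-id flags.
use clap::Parser;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    #[arg(long)]
    num_shards: usize,

    #[arg(long)]
    prompt: Option<String>,

    #[arg(long)]
    model_id: Option<String>,
}

fn main() {
    // Parse the command line into the struct and show what was received.
    let args = Args::parse();
    println!("{args:?}");
}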
@@ -130,26 +119,22 @@ fn main() -> Result<()> {
     let args = Args::parse();
 
     let config = Config::config_7b();
-    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let dtype = DType::F16;
 
     let api = Api::new()?;
 
-    let model_id = args.model_id.unwrap_or_else(|| {
-        if args.v2 {
-            "meta-llama/Llama-2-7b-hf".to_string()
-        } else {
-            "Narsil/amall-7b".to_string()
-        }
-    });
+    let model_id = args
+        .model_id
+        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }
 
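The download path above switches from building a Repo by hand to the repo-scoped handle returned by Api::model. A minimal sketch of that hf-hub sync usage, with the default model id taken from this diff and error handling collapsed into ?:

// Sketch of the hf-hub sync API the new code relies on: Api::model returns
// a repo-scoped handle, and get downloads (or reuses a cached copy of) a
// single file, returning its local path.
use hf_hub::api::sync::Api;

fn fetch_tokenizer() -> Result<(), Box<dyn std::error::Error>> {
    let api = Api::new()?;
    let repo = api.model("meta-llama/Llama-2-7b-hf".to_string());
    let tokenizer = repo.get("tokenizer.json")?; // -> std::path::PathBuf
    println!("tokenizer cached at {}", tokenizer.display());
    Ok(())
}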
@@ -203,7 +188,7 @@ fn main() -> Result<()> {
     println!("Rank {rank:?} spawned");
 
     let device = Device::new_cuda(i)?;
-    let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
+    let cache = model::Cache::new(&config, &device)?;
 
     println!("building the model");
     let handles = filenames
@@ -233,11 +218,7 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
-        } else {
-            tokens.len()
-        };
+        let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, index_pos)?;
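With the KV cache now unconditionally enabled, only the first step feeds the whole prompt; every later step feeds just the newest token. A minimal sketch of that decode pattern, assuming candle_core as the core crate name and using forward / next_token as placeholder closures standing in for Llama::forward and the LogitsProcessor sampling step (neither of which appears in this hunk):

// Sketch of KV-cached incremental decoding: the first forward pass feeds the
// whole prompt to populate the cache; every later pass feeds only the newest
// token, with index_pos tracking its offset into the cached keys/values.
use candle_core::{Device, Result, Tensor};

fn decode_loop(
    mut tokens: Vec<u32>,
    sample_len: usize,
    device: &Device,
    // Placeholder for the model's forward pass: (input, index_pos) -> logits.
    mut forward: impl FnMut(&Tensor, usize) -> Result<Tensor>,
    // Placeholder for the sampling step: logits -> next token id.
    mut next_token: impl FnMut(&Tensor) -> Result<u32>,
) -> Result<Vec<u32>> {
    let mut index_pos = 0;
    for index in 0..sample_len {
        // Whole prompt on step 0; a single token afterwards, because the
        // cache already holds keys/values for everything before it.
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
        let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
        let logits = forward(&input, index_pos)?;
        index_pos += ctxt.len(); // advance the position offset into the cache
        let next = next_token(&logits)?;
        tokens.push(next);
    }
    Ok(tokens)
}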