Making multiprocess require flash-attn.

This commit is contained in:
Nicolas Patry
2023-07-28 07:52:24 +00:00
parent 50d8273ae4
commit 97181a77c0
3 changed files with 48 additions and 93 deletions

View File

@ -20,7 +20,7 @@ use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::api::sync::Api;
use std::io::Write;
use std::rc::Rc;
@ -83,10 +83,6 @@ Upon my target three fair-shining suns.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
num_shards: usize,
@ -113,15 +109,8 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
/// Use f32 computations rather than f16.
#[arg(long)]
use_f32: bool,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
v2: bool,
}
fn main() -> Result<()> {
@ -130,26 +119,22 @@ fn main() -> Result<()> {
let args = Args::parse();
let config = Config::config_7b();
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
let dtype = DType::F16;
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| {
if args.v2 {
"meta-llama/Llama-2-7b-hf".to_string()
} else {
"Narsil/amall-7b".to_string()
}
});
let model_id = args
.model_id
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
println!("loading the model weights from {model_id}");
let repo = Repo::new(model_id, RepoType::Model);
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
let api = api.model(model_id);
let tokenizer_filename = api.get("tokenizer.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(&repo, rfilename)?;
let filename = api.get(rfilename)?;
filenames.push(filename);
}
@ -203,7 +188,7 @@ fn main() -> Result<()> {
println!("Rank {rank:?} spawned");
let device = Device::new_cuda(i)?;
let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
let cache = model::Cache::new(&config, &device)?;
println!("building the model");
let handles = filenames
@ -233,11 +218,7 @@ fn main() -> Result<()> {
let mut index_pos = 0;
for index in 0..args.sample_len {
let start_gen = std::time::Instant::now();
let context_size = if cache.use_kv_cache && index > 0 {
1
} else {
tokens.len()
};
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, index_pos)?;