Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Merge pull request #262 from LaurentMazare/update_multiprocess
Making multiprocess require flash-attn.
@@ -46,4 +46,4 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]

 [[example]]
 name = "llama_multiprocess"
-required-features = ["cuda", "nccl"]
+required-features = ["cuda", "nccl", "flash-attn"]
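Note: with `flash-attn` added to `required-features`, cargo refuses to build this example unless all three features are enabled, e.g. `cargo run --example llama_multiprocess --release --features "cuda,nccl,flash-attn"` (illustrative invocation; adjust flags to your setup).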
@@ -20,7 +20,7 @@ use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 use std::io::Write;
 use std::rc::Rc;

@@ -83,10 +83,6 @@ Upon my target three fair-shining suns.
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
-    /// Run on CPU rather than on GPU.
-    #[arg(long)]
-    cpu: bool,
-
     #[arg(long)]
     num_shards: usize,

@@ -113,15 +109,8 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,

-    /// Use f32 computations rather than f16.
-    #[arg(long)]
-    use_f32: bool,
-
     #[arg(long)]
     model_id: Option<String>,
-
-    #[arg(long)]
-    v2: bool,
 }

 fn main() -> Result<()> {
@@ -130,26 +119,22 @@ fn main() -> Result<()> {
     let args = Args::parse();

     let config = Config::config_7b();
-    let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
+    let dtype = DType::F16;

     let api = Api::new()?;

-    let model_id = args.model_id.unwrap_or_else(|| {
-        if args.v2 {
-            "meta-llama/Llama-2-7b-hf".to_string()
-        } else {
-            "Narsil/amall-7b".to_string()
-        }
-    });
+    let model_id = args
+        .model_id
+        .unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }

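The download path now goes through `hf_hub`'s `Api::model` handle instead of a manually constructed `Repo`. A minimal sketch of the new flow, assuming the `hf-hub` crate with its sync API and `anyhow` for error handling (model id and file names taken from the diff above; the helper name is just for illustration):

```rust
use hf_hub::api::sync::Api;
use std::path::PathBuf;

// Sketch only: fetch the tokenizer and the two sharded safetensors files
// for the default model id used by the example.
fn fetch_model_files() -> anyhow::Result<(PathBuf, Vec<PathBuf>)> {
    let api = Api::new()?;
    let repo = api.model("meta-llama/Llama-2-7b-hf".to_string());
    let tokenizer = repo.get("tokenizer.json")?;
    let mut weights = vec![];
    for name in [
        "model-00001-of-00002.safetensors",
        "model-00002-of-00002.safetensors",
    ] {
        weights.push(repo.get(name)?);
    }
    Ok((tokenizer, weights))
}
```

`get` returns the locally cached path, so repeated runs reuse the download.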
@@ -203,7 +188,7 @@ fn main() -> Result<()> {
     println!("Rank {rank:?} spawned");

     let device = Device::new_cuda(i)?;
-    let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
+    let cache = model::Cache::new(&config, &device)?;

     println!("building the model");
     let handles = filenames
@@ -233,11 +218,7 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
-        } else {
-            tokens.len()
-        };
+        let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let logits = llama.forward(&input, index_pos)?;
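Because the KV cache is now always on, the generation loop only distinguishes the prompt pass (`index == 0`, full context) from later steps, which feed the single most recent token and rely on the cached keys and values.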
@@ -3,7 +3,6 @@ use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shap
 use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
-use std::collections::HashMap;
 use std::rc::Rc;
 use std::sync::{Arc, Mutex};

@@ -137,17 +136,14 @@ impl Config {

 #[derive(Clone)]
 pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
-    pub use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
     kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
     cos: Tensor,
     sin: Tensor,
-    device: Device,
 }

 impl Cache {
-    pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
+    pub fn new(config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
         let n_elem = config.n_embd / config.n_head;
         let theta: Vec<_> = (0..n_elem)
@@ -162,31 +158,14 @@ impl Cache {
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
         let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
-        let cos = idx_theta.cos()?;
-        let sin = idx_theta.sin()?;
+        let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
+        let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
         Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
-            use_kv_cache,
             kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
-            device: device.clone(),
             cos,
             sin,
         })
     }
-
-    fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
-            Ok(mask.clone())
-        } else {
-            let mask: Vec<_> = (0..t)
-                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
-                .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
-            Ok(mask)
-        }
-    }
 }

 fn silu(xs: &Tensor) -> Result<Tensor> {
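With flash-attn applying causal masking internally, the cached per-length masks, the `mask` helper, and the `device` field it needed are dropped from `Cache`. For reference, for t = 3 the removed helper produced the upper-triangular u8 mask [[0, 1, 1], [0, 0, 1], [0, 0, 0]], i.e. position i may not attend to any j > i; the same constraint is now requested via the boolean flag passed to `flash_attn` in the attention hunk below.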
@@ -260,7 +239,6 @@ impl CausalSelfAttention {
     }

     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
-        let x_dtype = x.dtype();
         let (b_sz, seq_len, _) = x.shape().dims3()?;

         let qkv = self.qkv_proj.forward(x)?;
@@ -282,51 +260,46 @@ impl CausalSelfAttention {

         let q = q
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let k = k
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;
         let mut v = v
             .reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
-            .transpose(1, 2)?
-            .to_dtype(DType::F32)?;
+            .transpose(1, 2)?;

         let q = self.apply_rotary_emb(&q, index_pos)?;
         let mut k = self.apply_rotary_emb(&k, index_pos)?;

-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
-                let k_seq_len = k.dims()[1];
-                if k_seq_len > MAX_SEQ_LEN {
-                    k = k
-                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
-                        .contiguous()?
-                }
-                let v_seq_len = v.dims()[1];
-                if v_seq_len > 2 * MAX_SEQ_LEN {
-                    v = v
-                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
-                        .contiguous()?
-                }
-            }
-            cache[block_idx] = Some((k.clone(), v.clone()))
-        }
+        let mut cache = self.cache.kvs.lock().unwrap();
+        if let Some((cache_k, cache_v)) = &cache[block_idx] {
+            k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+            v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
+            let k_seq_len = k.dims()[1];
+            if k_seq_len > MAX_SEQ_LEN {
+                k = k
+                    .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .contiguous()?
+            }
+            let v_seq_len = v.dims()[1];
+            if v_seq_len > 2 * MAX_SEQ_LEN {
+                v = v
+                    .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .contiguous()?
+            }
+        }
+        cache[block_idx] = Some((k.clone(), v.clone()));

         let k = self.repeat_kv(k)?;
         let v = self.repeat_kv(v)?;
-        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
+        let q = q.transpose(1, 2)?;
+        let k = k.transpose(1, 2)?;
+        let v = v.transpose(1, 2)?;
+        let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
+        let y =
+            candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
-        let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
-        let y = y.to_dtype(x_dtype)?;
         let y = self.o_proj.forward(&y)?;
         Ok(y)
     }
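For reference, the `flash_attn` call fuses what the removed lines spelled out step by step; the final `seq_len > 1` argument presumably toggles causal masking, which only matters when more than one query position is processed at once. A rough equivalent in plain candle ops, reconstructed from the removed code (sketch only; `q`, `k`, `v` are `(batch, heads, seq, head_dim)` f32 tensors, matching the removed `to_dtype(DType::F32)` conversions, and `mask` is an upper-triangular u8 mask like the one the removed `Cache::mask` built):

```rust
use candle::{Result, Tensor, D};

// Sketch: naive scaled-dot-product attention mirroring the removed code path.
// flash_attn performs the same computation in a single fused CUDA kernel and
// applies the causal mask internally instead of materializing it.
fn naive_attention(q: &Tensor, k: &Tensor, v: &Tensor, mask: &Tensor) -> Result<Tensor> {
    let head_dim = q.dims()[3];
    // Scaled attention scores: (batch, heads, seq, seq).
    let att = (q.matmul(&k.t()?)? / (head_dim as f64).sqrt())?;
    // Set masked (future) positions to -inf before the softmax.
    let neg_inf = Tensor::new(f32::NEG_INFINITY, att.device())?.broadcast_as(att.shape())?;
    let att = mask.broadcast_as(att.shape())?.where_cond(&neg_inf, &att)?;
    let att = candle_nn::ops::softmax(&att, D::Minus1)?;
    // Weighted sum of the values: back to (batch, heads, seq, head_dim).
    att.matmul(&v.contiguous()?)
}
```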
@@ -363,13 +336,6 @@ impl CausalSelfAttention {
     }
 }

-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
-}
-
 struct Mlp {
     c_fc1: TensorParallelColumnLinear,
     c_fc2: TensorParallelColumnLinear,