mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Multiprocess/multi-GPU support for llama 3. (#2092)
* Multiprocess/multi-GPU support for llama 3. * Modernize the mp example a bit.
This commit is contained in:
@ -2,7 +2,7 @@ use candle::backend::BackendStorage;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::{Embedding, Linear, Module, RmsNorm};
|
||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||
use half::f16;
|
||||
use half::{bf16, f16};
|
||||
use serde::Deserialize;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@ -58,15 +58,31 @@ impl CustomOp1 for AllReduce {
|
||||
use candle::cuda_backend::WrapErr;
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dev = s.device().clone();
|
||||
let s = s.as_cuda_slice::<f16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
match s.dtype() {
|
||||
DType::BF16 => {
|
||||
let s = s.as_cuda_slice::<bf16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
}
|
||||
DType::F16 => {
|
||||
let s = s.as_cuda_slice::<f16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
}
|
||||
dtype => candle::bail!("unsupported dtype {dtype:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,7 +177,6 @@ impl Cache {
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
Ok(Self {
|
||||
@ -197,16 +212,10 @@ struct CausalSelfAttention {
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
|
||||
let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
candle_nn::rotary_emb::rope(x, &cos, &sin)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
@ -232,13 +241,16 @@ impl CausalSelfAttention {
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
@ -278,16 +290,7 @@ impl CausalSelfAttention {
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_attention_heads / self.num_key_value_heads;
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
candle_transformers::utils::repeat_kv(x, n_rep)
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
|
Reference in New Issue
Block a user