Putting back Send + Sync

Nicolas Patry
2023-07-26 10:22:40 +00:00
parent 7c7e6ba201
commit 25a2086e8f
3 changed files with 21 additions and 13 deletions

View File

@@ -103,7 +103,7 @@ pub enum Op {
 }
 
 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1 {
+pub trait CustomOp1: Send + Sync {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;
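
This hunk adds Send + Sync as supertrait bounds on CustomOp1, so every user-defined op must be thread-safe, and trait objects built from such ops can be shared across threads. Below is a minimal, self-contained sketch of what the bound buys; it is not candle's real trait (which also requires forward implementations), only an illustration of the supertrait pattern:

    use std::sync::Arc;
    use std::thread;

    // Stripped-down stand-in for the trait above: only `name` is kept, and the
    // `Send + Sync` supertrait bound is the part being illustrated.
    trait CustomOp1: Send + Sync {
        fn name(&self) -> &'static str;
    }

    // Hypothetical op, not part of the candle codebase.
    struct Square;

    impl CustomOp1 for Square {
        fn name(&self) -> &'static str {
            "square"
        }
    }

    fn main() {
        // Because Send and Sync are supertraits, `dyn CustomOp1` is itself
        // Send + Sync, so the Arc can move into the spawned thread. Without
        // the bound, this line would not compile.
        let op: Arc<dyn CustomOp1> = Arc::new(Square);
        let handle = thread::spawn(move || println!("running {}", op.name()));
        handle.join().unwrap();
    }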

View File

@@ -4,6 +4,7 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
+use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
@@ -23,13 +24,20 @@ impl TensorParallelColumnLinear {
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Arc<Comm>,
+    comm: Rc<Comm>,
 }
 
 struct AllReduce {
-    comm: Arc<Comm>,
+    comm: Rc<Comm>,
 }
 
+/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
+/// But for this example purposes, this will work
+unsafe impl Sync for AllReduce {}
+/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
+/// But for this example purposes, this will work
+unsafe impl Send for AllReduce {}
+
 impl CustomOp1 for AllReduce {
     fn name(&self) -> &'static str {
         "allreduce"
@@ -60,12 +68,12 @@ impl CustomOp1 for AllReduce {
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
         Self { linear, comm }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -75,14 +83,14 @@ impl TensorParallelRowLinear {
 }
 
 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -95,7 +103,7 @@ impl TensorParallelColumnLinear {
 }
 
 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -338,7 +346,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -387,7 +395,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -421,7 +429,7 @@ impl Block {
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -465,7 +473,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
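
Throughout this example model, the communicator is now an Rc<Comm> that each load function clones into the layer it builds; since each process drives a single GPU from a single thread, a non-atomic reference count is enough. A rough sketch of that loading pattern, with placeholder types standing in for Comm and the layers (not the real candle code):

    use std::rc::Rc;

    // Placeholder communicator; the real one wraps an NCCL handle.
    struct Comm {
        rank: usize,
        world_size: usize,
    }

    struct Layer {
        comm: Rc<Comm>,
    }

    struct Model {
        layers: Vec<Layer>,
    }

    impl Model {
        // Mirrors the diff's loading pattern: one cheap, non-atomic Rc clone per layer.
        fn load(n_layers: usize, comm: Rc<Comm>) -> Self {
            let layers = (0..n_layers).map(|_| Layer { comm: comm.clone() }).collect();
            Self { layers }
        }
    }

    fn main() {
        let comm = Rc::new(Comm { rank: 0, world_size: 2 });
        let model = Model::load(4, comm);
        println!(
            "{} layers, rank {} of {}",
            model.layers.len(),
            model.layers[0].comm.rank,
            model.layers[0].comm.world_size
        );
    }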

View File

@@ -11,7 +11,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
 }
 
 #[derive(Clone)]
-#[pyclass(name = "Tensor", unsendable)]
+#[pyclass(name = "Tensor")]
 struct PyTensor(Tensor);
 
 impl std::ops::Deref for PyTensor {
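
Dropping unsendable from the #[pyclass] attribute means pyo3 now requires the wrapped Tensor to be Send, which the restored Send + Sync bound on custom ops makes possible; in return, the Python object can be used from any Python thread instead of erroring when touched off the thread that created it. A minimal sketch of a sendable pyclass wrapper, using a hypothetical counter type rather than candle's Tensor:

    use pyo3::prelude::*;

    // Hypothetical wrapper type; candle-pyo3's PyTensor wraps a Tensor instead.
    // Without `unsendable`, pyo3 insists the inner data is `Send`, and the
    // resulting Python object may then be used from any Python thread.
    #[pyclass(name = "Counter")]
    struct PyCounter {
        value: u64,
    }

    #[pymethods]
    impl PyCounter {
        #[new]
        fn new(value: u64) -> Self {
            Self { value }
        }

        fn get(&self) -> u64 {
            self.value
        }
    }

    // Module registration, assuming a pyo3 0.19-era API as used at the time of this commit.
    #[pymodule]
    fn demo(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
        m.add_class::<PyCounter>()?;
        Ok(())
    }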