diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 83b382cd..525383b2 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -103,7 +103,7 @@ pub enum Op { } /// Unary ops that can be defined in user-land. -pub trait CustomOp1 { +pub trait CustomOp1: Send + Sync { // Box does not support const yet, so use a function to get the name. fn name(&self) -> &'static str; diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs index becaa879..bcf6ed2b 100644 --- a/candle-examples/examples/llama_multiprocess/model.rs +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -4,6 +4,7 @@ use candle_nn::{Embedding, Linear, VarBuilder}; use cudarc::nccl::safe::{Comm, ReduceOp}; use half::f16; use std::collections::HashMap; +use std::rc::Rc; use std::sync::{Arc, Mutex}; use super::MAX_SEQ_LEN; @@ -23,13 +24,20 @@ impl TensorParallelColumnLinear { struct TensorParallelRowLinear { linear: Linear, - comm: Arc, + comm: Rc, } struct AllReduce { - comm: Arc, + comm: Rc, } +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Sync for AllReduce {} +/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html +/// But for this example purposes, this will work +unsafe impl Send for AllReduce {} + impl CustomOp1 for AllReduce { fn name(&self) -> &'static str { "allreduce" @@ -60,12 +68,12 @@ impl CustomOp1 for AllReduce { } } -fn all_reduce_sum(x: &Tensor, comm: &Arc) -> Result { +fn all_reduce_sum(x: &Tensor, comm: &Rc) -> Result { x.custom_op1(AllReduce { comm: comm.clone() }) } impl TensorParallelRowLinear { - fn new(linear: Linear, comm: Arc) -> Self { + fn new(linear: Linear, comm: Rc) -> Self { Self { linear, comm } } fn forward(&self, x: &Tensor) -> Result { @@ -75,14 +83,14 @@ impl TensorParallelRowLinear { } impl TensorParallelColumnLinear { - fn load(vb: VarBuilder, comm: Arc) -> Result { + fn load(vb: VarBuilder, comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weight = vb.get_sharded("weight", 0, rank, size)?; Ok(Self::new(Linear::new(weight, None))) } - fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc) -> Result { + fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weights: Vec<_> = prefixes @@ -95,7 +103,7 @@ impl TensorParallelColumnLinear { } impl TensorParallelRowLinear { - fn load(vb: VarBuilder, comm: Arc) -> Result { + fn load(vb: VarBuilder, comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); let weight = vb.get_sharded("weight", 1, rank, size)?; @@ -338,7 +346,7 @@ impl CausalSelfAttention { } } - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { let qkv_proj = TensorParallelColumnLinear::load_multi( vb.clone(), &["q_proj", "k_proj", "v_proj"], @@ -387,7 +395,7 @@ impl Mlp { self.c_proj.forward(&x) } - fn load(vb: VarBuilder, _cfg: &Config, comm: Arc) -> Result { + fn load(vb: VarBuilder, _cfg: &Config, comm: Rc) -> Result { let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?; let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?; let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?; @@ -421,7 +429,7 @@ impl Block { Ok(x) } - fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { + fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?; let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?; let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?; @@ -465,7 +473,7 @@ impl Llama { logits.to_dtype(DType::F32) } - pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc) -> Result { + pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc) -> Result { let wte = embedding(cfg, vb.pp("model.embed_tokens"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?; diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 6e206688..136f8a4f 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -11,7 +11,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr { } #[derive(Clone)] -#[pyclass(name = "Tensor", unsendable)] +#[pyclass(name = "Tensor")] struct PyTensor(Tensor); impl std::ops::Deref for PyTensor {