Putting back Send + Sync
@@ -103,7 +103,7 @@ pub enum Op {
 }

 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1 {
+pub trait CustomOp1: Send + Sync {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;

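For context (not part of the diff): adding `Send + Sync` as supertraits means every `CustomOp1` implementor must be thread-safe, and in return trait objects such as `Box<dyn CustomOp1>` or `Arc<dyn CustomOp1>` are automatically `Send + Sync`, so values that store custom ops can cross thread boundaries without per-type unsafe impls. A minimal standalone sketch of the pattern, using a hypothetical `NamedOp` trait rather than candle's actual API:

```rust
use std::sync::Arc;
use std::thread;

// Local trait mirroring the bound added above: requiring Send + Sync as
// supertraits makes every trait object of this trait thread-safe by
// construction.
trait NamedOp: Send + Sync {
    fn name(&self) -> &'static str;
}

struct Square;

impl NamedOp for Square {
    fn name(&self) -> &'static str {
        "square"
    }
}

fn main() {
    // Arc<dyn NamedOp> is Send + Sync because the trait demands it of every
    // implementor, so the op can be shared with another thread directly.
    let op: Arc<dyn NamedOp> = Arc::new(Square);
    let handle = thread::spawn({
        let op = Arc::clone(&op);
        move || println!("running op: {}", op.name())
    });
    handle.join().unwrap();
}
```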
@@ -4,6 +4,7 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
+use std::rc::Rc;
 use std::sync::{Arc, Mutex};

 use super::MAX_SEQ_LEN;
@@ -23,13 +24,20 @@ impl TensorParallelColumnLinear {

 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Arc<Comm>,
+    comm: Rc<Comm>,
 }

 struct AllReduce {
-    comm: Arc<Comm>,
+    comm: Rc<Comm>,
 }

+/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
+/// But for this example purposes, this will work
+unsafe impl Sync for AllReduce {}
+/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
+/// But for this example purposes, this will work
+unsafe impl Send for AllReduce {}
+
 impl CustomOp1 for AllReduce {
     fn name(&self) -> &'static str {
         "allreduce"
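Illustration (not candle code): why the unsafe impls are needed at all. Switching to `Rc<Comm>` strips `Send` and `Sync` from `AllReduce`, which would now conflict with the `CustomOp1: Send + Sync` bound, so the example opts back in manually and accepts the thread-safety caveat from the NCCL documentation linked above. A self-contained sketch with a made-up `FakeComm` type:

```rust
use std::rc::Rc;

// Hypothetical stand-in for the non-thread-safe NCCL communicator handle.
struct FakeComm {
    rank: usize,
}

// Rc<FakeComm> is neither Send nor Sync, so this struct loses both auto
// traits and would be rejected by a `CustomOp1: Send + Sync` style bound.
struct AllReduceSketch {
    comm: Rc<FakeComm>,
}

// SAFETY: a promise to the compiler that it cannot verify. This mirrors the
// unsafe impls in the diff and is only tolerable because each process drives
// a single communicator from one thread at a time.
unsafe impl Send for AllReduceSketch {}
unsafe impl Sync for AllReduceSketch {}

fn assert_send_sync<T: Send + Sync>(_: &T) {}

fn main() {
    let op = AllReduceSketch {
        comm: Rc::new(FakeComm { rank: 0 }),
    };
    // Without the unsafe impls above, this call would fail to compile.
    assert_send_sync(&op);
    println!("rank {}", op.comm.rank);
}
```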
@@ -60,12 +68,12 @@ impl CustomOp1 for AllReduce {
     }
 }

-fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }

 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
         Self { linear, comm }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
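A side note on the Arc -> Rc switch in these signatures (illustration only, not candle code): `Arc<T>` is `Send`/`Sync` only when `T` itself is, so wrapping a non-thread-safe communicator in `Arc` pays for atomic reference counting without gaining any thread safety. Each rank owns its communicator and never shares it across threads, so `Rc` is sufficient and cheaper. A sketch with a hypothetical `NotThreadSafe` handle:

```rust
use std::marker::PhantomData;
use std::rc::Rc;

// Hypothetical handle that, like nccl's Comm, is neither Send nor Sync
// (a raw-pointer marker opts out of both auto traits).
struct NotThreadSafe {
    _marker: PhantomData<*const ()>,
}

fn main() {
    let handle = Rc::new(NotThreadSafe { _marker: PhantomData });
    let another_owner = Rc::clone(&handle); // non-atomic, cheap refcount bump
    println!("owners: {}", Rc::strong_count(&another_owner));

    // Moving the handle to another thread would not compile even behind an
    // Arc, because Arc<T> is only Send/Sync when T is Send + Sync:
    // std::thread::spawn(move || drop(handle));
}
```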
@@ -75,14 +83,14 @@ impl TensorParallelRowLinear {
 }

 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }

-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -95,7 +103,7 @@ impl TensorParallelColumnLinear {
 }

 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -338,7 +346,7 @@ impl CausalSelfAttention {
         }
     }

-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -387,7 +395,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }

-    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -421,7 +429,7 @@ impl Block {
         Ok(x)
     }

-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -465,7 +473,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }

-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;

@@ -11,7 +11,7 @@ pub fn wrap_err(err: ::candle::Error) -> PyErr {
 }

 #[derive(Clone)]
-#[pyclass(name = "Tensor", unsendable)]
+#[pyclass(name = "Tensor")]
 struct PyTensor(Tensor);

 impl std::ops::Deref for PyTensor {
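Background on this last hunk (a sketch assuming pyo3's ~0.19 API, not candle-pyo3's actual code): `#[pyclass(unsendable)]` lets a non-`Send` type be exposed to Python but panics if the object is touched from a thread other than the one that created it. Dropping `unsendable` requires the wrapped type to be `Send`, which is what the `CustomOp1: Send + Sync` bound restores for Tensor. A minimal example with a stand-in tensor type:

```rust
use pyo3::prelude::*;

// Stand-in for candle's Tensor; it is Send because all of its fields are.
#[derive(Clone)]
struct FakeTensor {
    data: Vec<f32>,
}

// Without `unsendable`, PyO3 requires the wrapped type to be Send, and the
// resulting Python object can then be used from any Python thread.
#[derive(Clone)]
#[pyclass(name = "Tensor")]
struct PyTensorSketch(FakeTensor);

#[pymethods]
impl PyTensorSketch {
    #[new]
    fn new(data: Vec<f32>) -> Self {
        Self(FakeTensor { data })
    }

    fn sum(&self) -> f32 {
        self.0.data.iter().sum()
    }
}

// Hypothetical module name, used only for this sketch.
#[pymodule]
fn candle_sketch(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyTensorSketch>()?;
    Ok(())
}
```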