Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Rename the candle crate to candle-core (#301)

* Rename to candle-core.
* More candle-core renaming.
@@ -106,7 +106,7 @@ impl TensorParallelRowLinear {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
-        Ok(Self::new(Linear::new(weight, None), comm.clone()))
+        Ok(Self::new(Linear::new(weight, None), comm))
     }
 }

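Why the trailing clone can go away: `comm` is an `Rc<Comm>`, so each `clone()` is only a reference-count bump, and the final consumer of the handle can take it by move instead of cloning. A minimal sketch of that pattern, using a stand-in `Comm` type rather than candle's actual communicator:

    use std::rc::Rc;

    struct Comm; // stand-in for the communicator the real code shares

    // Takes ownership of one Rc handle and reports how many handles exist.
    fn consume(comm: Rc<Comm>) -> usize {
        Rc::strong_count(&comm)
    }

    fn main() {
        let comm = Rc::new(Comm);
        let with_clone = consume(comm.clone()); // earlier users must clone: 2 handles
        let at_last_use = consume(comm);        // final user moves it: 1 handle
        println!("{with_clone} {at_last_use}");
    }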
@@ -296,8 +296,8 @@ impl CausalSelfAttention {
         let k = k.transpose(1, 2)?;
         let v = v.transpose(1, 2)?;
         let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
-        let y =
-            candle_flash_attn::flash_attn(q, k, v, softmax_scale, seq_len > 1)?.transpose(1, 2)?;
+        let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
+            .transpose(1, 2)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.o_proj.forward(&y)?;
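This hunk switches the `flash_attn` call to pass `&q, &k, &v`: the function borrows its tensors rather than consuming them, so the caller keeps ownership after the call (the final `seq_len > 1` argument is the causal-masking flag, enabled only when more than one query position is processed). A hedged before/after sketch with a stand-in `Tensor` type, not the real candle_flash_attn API:

    struct Tensor(Vec<f32>); // stand-in for candle's Tensor

    // Before the change the call passed q, k, v by value; borrowing instead
    // leaves ownership with the caller, which can reuse the tensors later.
    fn flash_attn(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Tensor {
        // Placeholder body: a real kernel computes softmax(q·kᵀ · scale)·v.
        Tensor(
            q.0.iter()
                .zip(&k.0)
                .zip(&v.0)
                .map(|((q, k), v)| q * k * softmax_scale * v)
                .collect(),
        )
    }

    fn main() {
        let (q, k, v) = (Tensor(vec![1.0]), Tensor(vec![2.0]), Tensor(vec![0.5]));
        let _y = flash_attn(&q, &k, &v, 0.125);
        let _y2 = flash_attn(&q, &k, &v, 0.125); // q, k, v are still usable here
    }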
@@ -363,7 +363,7 @@ impl Mlp {
     fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
-        let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
+        let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
         Ok(Self::new(c_fc1, c_fc2, c_proj))
     }
 }
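The `down_proj` change is the same last-use move sketched above: `gate_proj` and `up_proj` still need `comm` afterwards and must clone the `Rc`, while the final `down_proj` load can take it by value. `Block::load` in the next hunk applies the identical fix to its final `Mlp::load` call.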
@@ -396,7 +396,7 @@ impl Block {

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
-        let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
+        let mlp = Mlp::load(vb.pp("mlp"), cfg, comm)?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
             RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;