mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
More model cloning. (#1126)

* More model cloning.
* More cloning on quantized models.
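The change itself is mechanical: each `#[derive(Debug)]` in these models' module trees becomes `#[derive(Debug, Clone)]`. Since `derive(Clone)` only compiles when every field type is itself `Clone`, the derive has to be threaded all the way down before the top-level model can be cloned. A toy sketch of the pattern — the struct and field names below merely mirror the diff, they are not candle's real definitions:

// Toy reconstruction of the derive chain this commit sets up. Clone on a
// top-level model only compiles if every field type is Clone, so the
// derive must be added to the attention, MLP and layer structs as well.

#[derive(Debug, Clone)]
struct Attention {} // q_proj / k_proj / v_proj / o_proj in the real model

#[derive(Debug, Clone)]
struct Mlp {} // gate_proj / up_proj / down_proj in the real model

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
}

#[derive(Debug, Clone)]
pub struct Model {
    layers: Vec<DecoderLayer>,
}

fn main() {
    let model = Model {
        layers: vec![DecoderLayer {
            self_attn: Attention {},
            mlp: Mlp {},
        }],
    };
    // In candle, tensors are reference-counted, so a clone like this shares
    // weights; only per-run mutable state (e.g. KV caches) diverges.
    let _per_session = model.clone();
}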
candle-transformers/src/models/quantized_mistral.rs

@@ -6,7 +6,7 @@ use std::sync::Arc;
 
 pub use crate::models::mistral::Config;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -57,7 +57,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     gate_proj: Linear,
@@ -90,7 +90,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
@@ -200,7 +200,7 @@ impl Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
     mlp: MLP,
@@ -243,7 +243,7 @@ impl DecoderLayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: Embedding,
     layers: Vec<DecoderLayer>,
candle-transformers/src/models/quantized_mixformer.rs

@@ -7,7 +7,7 @@ pub use crate::models::mixformer::Config;
 
 const MAX_SEQ_LEN: usize = 4096;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Embedding {
     wte: crate::quantized_nn::Embedding,
 }
@@ -39,7 +39,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
     Ok(m)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -105,7 +105,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     fc1: Linear,
@@ -132,7 +132,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
     linear: Linear,
@@ -154,7 +154,7 @@ impl Module for CausalLMHead {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MHA {
     wqkv: Linear,
@@ -243,7 +243,7 @@ impl MHA {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct ParallelBlock {
     ln: candle_nn::LayerNorm,
     mixer: MHA,
@@ -278,7 +278,7 @@ impl ParallelBlock {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MixFormerSequentialForCausalLM {
     embedding: Embedding,
     blocks: Vec<ParallelBlock>,
candle-transformers/src/models/quantized_mpt.rs

@@ -7,7 +7,7 @@ use candle_nn::LayerNorm;
 
 pub use super::mpt::Config;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct GroupedQueryAttention {
     wqkv: Linear,
     out_proj: Linear,
@@ -101,7 +101,7 @@ impl GroupedQueryAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
     down_proj: Linear,
@@ -122,7 +122,7 @@ impl Module for Ffn {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MPTBlock {
     norm1: LayerNorm, // Do we need the low-precision variant?
     attn: GroupedQueryAttention,
@@ -155,7 +155,7 @@ impl MPTBlock {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     wte: Embedding,
     blocks: Vec<MPTBlock>,
candle-transformers/src/quantized_nn.rs

@@ -2,7 +2,7 @@ use crate::models::with_tracing::QMatMul;
 use crate::quantized_var_builder::VarBuilder;
 use candle::{Module, Result, Tensor};
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Embedding {
     inner: candle_nn::Embedding,
     span: tracing::Span,
@@ -28,7 +28,7 @@ impl Module for Embedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Linear {
     weight: QMatMul,
     bias: Option<Tensor>,
@@ -69,7 +69,7 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L
     Ok(Linear { weight, bias: None })
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
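A note on cost: these clones are shallow. candle's `Tensor` is reference-counted internally, and the quantized `Linear` above keeps its weights behind a `QMatMul` that shares the underlying quantized tensor, so `model.clone()` duplicates handles rather than weight data. What actually diverges between clones is per-run mutable state such as the attention KV caches, which is what makes a cheap `Clone` useful for serving several generation sessions from a single loaded model.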