More model cloning. (#1126)

* More model cloning.

* More cloning on quantized models.
Laurent Mazare committed 2023-10-18 21:55:46 +01:00 (committed by GitHub)
parent 185b54a33b
commit 902d0b9166
4 changed files with 19 additions and 19 deletions
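
The change is mechanical and identical across all four files: every #[derive(Debug)] on a model struct becomes #[derive(Debug, Clone)]. As a rough stand-alone illustration (not code from this commit), the sketch below applies the same derive to a small struct of tensors; candle tensors keep their storage behind a reference count, so cloning such a struct should be a cheap, storage-sharing copy rather than a deep copy of the weights. The crate and helper names used here (candle_core, Tensor::zeros) are assumptions for the sketch only.

// Minimal sketch: a tensor-holding struct that, like the model structs in this
// commit, derives Clone so a whole loaded model can be duplicated.
use candle_core::{DType, Device, Result, Tensor};

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let rope = RotaryEmbedding {
        sin: Tensor::zeros((32, 64), DType::F32, &dev)?,
        cos: Tensor::zeros((32, 64), DType::F32, &dev)?,
    };
    // The derive duplicates the struct; the underlying tensor buffers are
    // shared via reference counting rather than copied.
    let rope2 = rope.clone();
    println!("{:?} {:?}", rope.sin.shape(), rope2.sin.shape());
    Ok(())
}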


@@ -6,7 +6,7 @@ use std::sync::Arc;
 
 pub use crate::models::mistral::Config;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -57,7 +57,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     gate_proj: Linear,
@@ -90,7 +90,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
@@ -200,7 +200,7 @@ impl Attention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
     mlp: MLP,
@@ -243,7 +243,7 @@ impl DecoderLayer {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: Embedding,
     layers: Vec<DecoderLayer>,


@@ -7,7 +7,7 @@ pub use crate::models::mixformer::Config;
 
 const MAX_SEQ_LEN: usize = 4096;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Embedding {
     wte: crate::quantized_nn::Embedding,
 }
@@ -39,7 +39,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
     Ok(m)
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -105,7 +105,7 @@ impl RotaryEmbedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     fc1: Linear,
@@ -132,7 +132,7 @@ impl Module for MLP {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
     linear: Linear,
@@ -154,7 +154,7 @@ impl Module for CausalLMHead {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MHA {
     wqkv: Linear,
@@ -243,7 +243,7 @@ impl MHA {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct ParallelBlock {
     ln: candle_nn::LayerNorm,
     mixer: MHA,
@@ -278,7 +278,7 @@ impl ParallelBlock {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MixFormerSequentialForCausalLM {
     embedding: Embedding,
     blocks: Vec<ParallelBlock>,


@@ -7,7 +7,7 @@ use candle_nn::LayerNorm;
 
 pub use super::mpt::Config;
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct GroupedQueryAttention {
     wqkv: Linear,
     out_proj: Linear,
@@ -101,7 +101,7 @@ impl GroupedQueryAttention {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
     down_proj: Linear,
@@ -122,7 +122,7 @@ impl Module for Ffn {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MPTBlock {
     norm1: LayerNorm, // Do we need the low-precision variant?
     attn: GroupedQueryAttention,
@@ -155,7 +155,7 @@ impl MPTBlock {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     wte: Embedding,
     blocks: Vec<MPTBlock>,


@@ -2,7 +2,7 @@ use crate::models::with_tracing::QMatMul;
 use crate::quantized_var_builder::VarBuilder;
 use candle::{Module, Result, Tensor};
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Embedding {
     inner: candle_nn::Embedding,
     span: tracing::Span,
@@ -28,7 +28,7 @@ impl Module for Embedding {
     }
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Linear {
     weight: QMatMul,
     bias: Option<Tensor>,
@@ -69,7 +69,7 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<L
     Ok(Linear { weight, bias: None })
 }
 
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,