From 902d0b91664df1e9074cc365de5eba6b578d6692 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 18 Oct 2023 21:55:46 +0100
Subject: [PATCH] More model cloning. (#1126)

* More model cloning.

* More cloning on quantized models.
---
 .../src/models/quantized_mistral.rs             | 10 +++++-----
 .../src/models/quantized_mixformer.rs           | 14 +++++++-------
 candle-transformers/src/models/quantized_mpt.rs |  8 ++++----
 candle-transformers/src/quantized_nn.rs         |  6 +++---
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/candle-transformers/src/models/quantized_mistral.rs b/candle-transformers/src/models/quantized_mistral.rs
index 00c80209..9e306c67 100644
--- a/candle-transformers/src/models/quantized_mistral.rs
+++ b/candle-transformers/src/models/quantized_mistral.rs
@@ -6,7 +6,7 @@ use std::sync::Arc;

 pub use crate::models::mistral::Config;

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -57,7 +57,7 @@ impl RotaryEmbedding {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     gate_proj: Linear,
@@ -90,7 +90,7 @@ impl Module for MLP {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Attention {
     q_proj: Linear,
     k_proj: Linear,
@@ -200,7 +200,7 @@ impl Attention {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct DecoderLayer {
     self_attn: Attention,
     mlp: MLP,
@@ -243,7 +243,7 @@ impl DecoderLayer {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     embed_tokens: Embedding,
     layers: Vec<DecoderLayer>,
diff --git a/candle-transformers/src/models/quantized_mixformer.rs b/candle-transformers/src/models/quantized_mixformer.rs
index 23eeb0ac..f11f2036 100644
--- a/candle-transformers/src/models/quantized_mixformer.rs
+++ b/candle-transformers/src/models/quantized_mixformer.rs
@@ -7,7 +7,7 @@ pub use crate::models::mixformer::Config;

 const MAX_SEQ_LEN: usize = 4096;

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Embedding {
     wte: crate::quantized_nn::Embedding,
 }
@@ -39,7 +39,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
     Ok(m)
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct RotaryEmbedding {
     sin: Tensor,
     cos: Tensor,
@@ -105,7 +105,7 @@ impl RotaryEmbedding {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
     fc1: Linear,
@@ -132,7 +132,7 @@ impl Module for MLP {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
     linear: Linear,
@@ -154,7 +154,7 @@ impl Module for CausalLMHead {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MHA {
     wqkv: Linear,
@@ -243,7 +243,7 @@ impl MHA {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct ParallelBlock {
     ln: candle_nn::LayerNorm,
     mixer: MHA,
@@ -278,7 +278,7 @@ impl ParallelBlock {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct MixFormerSequentialForCausalLM {
     embedding: Embedding,
     blocks: Vec<ParallelBlock>,
diff --git a/candle-transformers/src/models/quantized_mpt.rs b/candle-transformers/src/models/quantized_mpt.rs
index 7586e4c0..70a9e125 100644
--- a/candle-transformers/src/models/quantized_mpt.rs
+++ b/candle-transformers/src/models/quantized_mpt.rs
@@ -7,7 +7,7 @@ use candle_nn::LayerNorm;

 pub use super::mpt::Config;

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct GroupedQueryAttention {
     wqkv: Linear,
     out_proj: Linear,
@@ -101,7 +101,7 @@ impl GroupedQueryAttention {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct Ffn {
     up_proj: Linear,
     down_proj: Linear,
@@ -122,7 +122,7 @@ impl Module for Ffn {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 struct MPTBlock {
     norm1: LayerNorm, // Do we need the low-precision variant?
     attn: GroupedQueryAttention,
@@ -155,7 +155,7 @@ impl MPTBlock {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Model {
     wte: Embedding,
     blocks: Vec<MPTBlock>,
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index d71c3b60..2941c3f0 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -2,7 +2,7 @@ use crate::models::with_tracing::QMatMul;
 use crate::quantized_var_builder::VarBuilder;
 use candle::{Module, Result, Tensor};

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Embedding {
     inner: candle_nn::Embedding,
     span: tracing::Span,
@@ -28,7 +28,7 @@ impl Module for Embedding {
     }
 }

-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct Linear {
     weight: QMatMul,
     bias: Option<Tensor>,
@@ -69,7 +69,7 @@ pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear>
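
Note (not part of the patch): the sketch below illustrates what deriving Clone on these quantized model types enables. It assumes the quantized_mistral Model API from candle-transformers around the time of this change, with the crates imported as candle / candle_transformers; the run_prompt helper is hypothetical. Cloning is cheap because the quantized weights held by QMatMul are reference counted, while mutable per-clone state such as the KV caches stays independent.

// Illustrative sketch only -- not part of the patch. `run_prompt` is a
// hypothetical stand-in for a generation loop.
use candle_transformers::models::quantized_mistral::Model;

// With `Clone` derived, one loaded model can back several independent
// sessions: the quantized weights are shared (reference counted), while each
// clone keeps its own per-layer KV cache state.
fn serve_two_prompts(model: &Model) -> candle::Result<()> {
    let mut session_a = model.clone();
    let mut session_b = model.clone();
    run_prompt(&mut session_a, "What is the capital of France?")?;
    run_prompt(&mut session_b, "Write a haiku about autumn.")?;
    Ok(())
}

fn run_prompt(_model: &mut Model, _prompt: &str) -> candle::Result<()> {
    // Tokenize the prompt, call `forward(&input_ids, seqlen_offset)` step by
    // step, sample the next token, and detokenize -- elided here.
    Ok(())
}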