From 58cc896e692936e36f3c68cf33ce949a7298bd4d Mon Sep 17 00:00:00 2001
From: Daniel Clough <9276072+danielclough@users.noreply.github.com>
Date: Sun, 4 Feb 2024 02:56:03 -0800
Subject: [PATCH] make llama derive clone (#1648)

Co-authored-by: danielclough <danielclough@users.noreply.github.com>
---
 candle-transformers/src/models/llama.rs | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs
index f003866a..7a920cb8 100644
--- a/candle-transformers/src/models/llama.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
 
 pub const MAX_SEQ_LEN: usize = 4096;
 
-#[derive(Deserialize)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct LlamaConfig {
     pub hidden_size: usize,
     pub intermediate_size: usize,
@@ -40,6 +40,7 @@ impl LlamaConfig {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct Config {
     pub hidden_size: usize,
     pub intermediate_size: usize,
@@ -82,7 +83,7 @@ impl Config {
     }
 }
 
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
     pub use_kv_cache: bool,
@@ -136,6 +137,7 @@ impl Cache {
     }
 }
 
+#[derive(Debug, Clone)]
 struct RmsNorm {
     inner: candle_nn::RmsNorm,
     span: tracing::Span,
@@ -154,6 +156,7 @@ impl RmsNorm {
     }
 }
 
+#[derive(Debug, Clone)]
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -314,6 +317,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     Ok(m)
 }
 
+#[derive(Debug, Clone)]
 struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
@@ -344,6 +348,7 @@ impl Mlp {
     }
 }
 
+#[derive(Debug, Clone)]
 struct Block {
     rms_1: RmsNorm,
     attn: CausalSelfAttention,
@@ -383,6 +388,7 @@ impl Block {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,