make llama derive clone (#1648)

Co-authored-by: danielclough <danielclough@users.noreply.github.com>
This commit is contained in:
Daniel Clough
2024-02-04 02:56:03 -08:00
committed by GitHub
parent 5cdd84e0f6
commit 58cc896e69

View File

@@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
-#[derive(Deserialize)]
+#[derive(Debug, Clone, Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
@@ -40,6 +40,7 @@ impl LlamaConfig {
}
}
+#[derive(Debug, Clone)]
pub struct Config {
pub hidden_size: usize,
pub intermediate_size: usize,
@@ -82,7 +83,7 @@ impl Config {
}
}
-#[derive(Clone)]
+#[derive(Debug, Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
@@ -136,6 +137,7 @@ impl Cache {
}
}
+#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
@@ -154,6 +156,7 @@ impl RmsNorm {
}
}
+#[derive(Debug, Clone)]
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@@ -314,6 +317,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
+#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
@@ -344,6 +348,7 @@ impl Mlp {
}
}
+#[derive(Debug, Clone)]
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
@@ -383,6 +388,7 @@ impl Block {
}
}
+#[derive(Debug, Clone)]
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,