mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
make llama derive clone (#1648)
Co-authored-by: danielclough <danielclough@users.noreply.github.com>
This commit is contained in:
@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
|
|||||||
|
|
||||||
pub const MAX_SEQ_LEN: usize = 4096;
|
pub const MAX_SEQ_LEN: usize = 4096;
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct LlamaConfig {
|
pub struct LlamaConfig {
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
@ -40,6 +40,7 @@ impl LlamaConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
@ -82,7 +83,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Cache {
|
pub struct Cache {
|
||||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||||
pub use_kv_cache: bool,
|
pub use_kv_cache: bool,
|
||||||
@ -136,6 +137,7 @@ impl Cache {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
inner: candle_nn::RmsNorm,
|
inner: candle_nn::RmsNorm,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@ -154,6 +156,7 @@ impl RmsNorm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct CausalSelfAttention {
|
struct CausalSelfAttention {
|
||||||
q_proj: Linear,
|
q_proj: Linear,
|
||||||
k_proj: Linear,
|
k_proj: Linear,
|
||||||
@ -314,6 +317,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct Mlp {
|
struct Mlp {
|
||||||
c_fc1: Linear,
|
c_fc1: Linear,
|
||||||
c_fc2: Linear,
|
c_fc2: Linear,
|
||||||
@ -344,6 +348,7 @@ impl Mlp {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct Block {
|
struct Block {
|
||||||
rms_1: RmsNorm,
|
rms_1: RmsNorm,
|
||||||
attn: CausalSelfAttention,
|
attn: CausalSelfAttention,
|
||||||
@ -383,6 +388,7 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct Llama {
|
pub struct Llama {
|
||||||
wte: Embedding,
|
wte: Embedding,
|
||||||
blocks: Vec<Block>,
|
blocks: Vec<Block>,
|
||||||
|
Reference in New Issue
Block a user