mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Make some model cloneable. (#1125)
This commit is contained in:
@ -74,7 +74,7 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct Embedding {
|
||||
wte: E,
|
||||
}
|
||||
@ -106,7 +106,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
@ -172,7 +172,7 @@ impl RotaryEmbedding {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
fc1: Linear,
|
||||
@ -199,7 +199,7 @@ impl Module for MLP {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct CausalLMHead {
|
||||
ln: candle_nn::LayerNorm,
|
||||
linear: Linear,
|
||||
@ -221,7 +221,7 @@ impl Module for CausalLMHead {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MHA {
|
||||
wqkv: Linear,
|
||||
@ -310,7 +310,7 @@ impl MHA {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct ParallelBlock {
|
||||
ln: candle_nn::LayerNorm,
|
||||
mixer: MHA,
|
||||
@ -345,7 +345,7 @@ impl ParallelBlock {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MixFormerSequentialForCausalLM {
|
||||
embedding: Embedding,
|
||||
blocks: Vec<ParallelBlock>,
|
||||
|
Reference in New Issue
Block a user