mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Make some model cloneable. (#1125)
This commit is contained in:
@@ -40,7 +40,7 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
-#[derive(Debug)]
|
||||
+#[derive(Debug, Clone)]
|
||||
struct GroupedQueryAttention {
|
||||
wqkv: Linear,
|
||||
out_proj: Linear,
|
||||
@@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
|
||||
}
|
||||
}
|
||||
|
||||
-#[derive(Debug)]
|
||||
+#[derive(Debug, Clone)]
|
||||
struct Ffn {
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
@@ -169,7 +169,7 @@ impl Module for Ffn {
|
||||
}
|
||||
}
|
||||
|
||||
-#[derive(Debug)]
|
||||
+#[derive(Debug, Clone)]
|
||||
struct MPTBlock {
|
||||
norm1: LayerNorm, // Do we need the low-precision variant?
|
||||
attn: GroupedQueryAttention,
|
||||
@@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
||||
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
|
||||
}
|
||||
|
||||
-#[derive(Debug)]
|
||||
+#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
wte: Embedding,
|
||||
blocks: Vec<MPTBlock>,
|
||||
|
Reference in New Issue
Block a user