Make some model cloneable. (#1125)

This commit is contained in:
Laurent Mazare
2023-10-18 19:30:47 +01:00
committed by GitHub
parent 620c94d12e
commit 185b54a33b
5 changed files with 25 additions and 20 deletions

View File

@ -40,7 +40,7 @@ impl Config {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct GroupedQueryAttention {
wqkv: Linear,
out_proj: Linear,
@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct Ffn {
up_proj: Linear,
down_proj: Linear,
@ -169,7 +169,7 @@ impl Module for Ffn {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct MPTBlock {
norm1: LayerNorm, // Do we need the low-precision variant?
attn: GroupedQueryAttention,
@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Model {
wte: Embedding,
blocks: Vec<MPTBlock>,