Make more models cloneable. (#1203)

This commit is contained in:
Laurent Mazare
2023-10-28 08:43:08 +02:00
committed by GitHub
parent ef33df7ae2
commit 612f5b8156
3 changed files with 26 additions and 26 deletions

View File

@ -7,7 +7,7 @@ use std::sync::Arc;
pub use crate::models::stable_lm::Config; pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding; use crate::models::stable_lm::RotaryEmbedding;
#[derive(Debug)] #[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
struct MLP { struct MLP {
gate_proj: Linear, gate_proj: Linear,
@ -43,7 +43,7 @@ impl Module for MLP {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct Attention { struct Attention {
q_proj: Linear, q_proj: Linear,
k_proj: Linear, k_proj: Linear,
@ -168,7 +168,7 @@ impl Attention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct DecoderLayer { struct DecoderLayer {
self_attn: Attention, self_attn: Attention,
mlp: MLP, mlp: MLP,
@ -213,7 +213,7 @@ impl DecoderLayer {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Model { pub struct Model {
embed_tokens: Embedding, embed_tokens: Embedding,
layers: Vec<DecoderLayer>, layers: Vec<DecoderLayer>,

View File

@ -93,7 +93,7 @@ impl Default for Config {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerNorm { struct T5LayerNorm {
weight: Tensor, weight: Tensor,
variance_epsilon: f64, variance_epsilon: f64,
@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5DenseActDense { struct T5DenseActDense {
wi: QMatMul, wi: QMatMul,
wo: QMatMul, wo: QMatMul,
@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5DenseGatedActDense { struct T5DenseGatedActDense {
wi_0: QMatMul, wi_0: QMatMul,
wi_1: QMatMul, wi_1: QMatMul,
@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerFF { struct T5LayerFF {
dense_act: Option<T5DenseActDense>, dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>, gated_dense_act: Option<T5DenseGatedActDense>,
@ -236,7 +236,7 @@ impl Module for T5LayerFF {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5Attention { struct T5Attention {
q: QMatMul, q: QMatMul,
k: QMatMul, k: QMatMul,
@ -431,7 +431,7 @@ impl T5Attention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerSelfAttention { struct T5LayerSelfAttention {
self_attention: T5Attention, self_attention: T5Attention,
layer_norm: T5LayerNorm, layer_norm: T5LayerNorm,
@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerCrossAttention { struct T5LayerCrossAttention {
cross_attention: T5Attention, cross_attention: T5Attention,
layer_norm: T5LayerNorm, layer_norm: T5LayerNorm,
@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5Block { struct T5Block {
self_attn: T5LayerSelfAttention, self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>, cross_attn: Option<T5LayerCrossAttention>,
@ -583,7 +583,7 @@ impl T5Block {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5Stack { struct T5Stack {
block: Vec<T5Block>, block: Vec<T5Block>,
shared: Arc<Embedding>, shared: Arc<Embedding>,
@ -633,7 +633,7 @@ impl T5Stack {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct T5EncoderModel { pub struct T5EncoderModel {
encoder: T5Stack, encoder: T5Stack,
device: Device, device: Device,
@ -666,7 +666,7 @@ impl T5EncoderModel {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration { pub struct T5ForConditionalGeneration {
encoder: T5Stack, encoder: T5Stack,
decoder: T5Stack, decoder: T5Stack,

View File

@ -118,7 +118,7 @@ impl Config {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerNorm { struct T5LayerNorm {
weight: Tensor, weight: Tensor,
variance_epsilon: f64, variance_epsilon: f64,
@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5DenseActDense { struct T5DenseActDense {
wi: Linear, wi: Linear,
wo: Linear, wo: Linear,
@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5DenseGatedActDense { struct T5DenseGatedActDense {
wi_0: Linear, wi_0: Linear,
wi_1: Linear, wi_1: Linear,
@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerFF { struct T5LayerFF {
dense_act: Option<T5DenseActDense>, dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>, gated_dense_act: Option<T5DenseGatedActDense>,
@ -261,7 +261,7 @@ impl Module for T5LayerFF {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5Attention { struct T5Attention {
q: Linear, q: Linear,
k: Linear, k: Linear,
@ -456,7 +456,7 @@ impl T5Attention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerSelfAttention { struct T5LayerSelfAttention {
self_attention: T5Attention, self_attention: T5Attention,
layer_norm: T5LayerNorm, layer_norm: T5LayerNorm,
@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5LayerCrossAttention { struct T5LayerCrossAttention {
cross_attention: T5Attention, cross_attention: T5Attention,
layer_norm: T5LayerNorm, layer_norm: T5LayerNorm,
@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5Block { struct T5Block {
self_attn: T5LayerSelfAttention, self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>, cross_attn: Option<T5LayerCrossAttention>,
@ -608,7 +608,7 @@ impl T5Block {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct T5Stack { struct T5Stack {
block: Vec<T5Block>, block: Vec<T5Block>,
shared: Arc<Embedding>, shared: Arc<Embedding>,
@ -658,7 +658,7 @@ impl T5Stack {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct T5EncoderModel { pub struct T5EncoderModel {
encoder: T5Stack, encoder: T5Stack,
device: Device, device: Device,
@ -691,7 +691,7 @@ impl T5EncoderModel {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration { pub struct T5ForConditionalGeneration {
encoder: T5Stack, encoder: T5Stack,
decoder: T5Stack, decoder: T5Stack,