mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Make more models cloneable. (#1203)
This commit is contained in:
@@ -118,7 +118,7 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerNorm {
|
||||
weight: Tensor,
|
||||
variance_epsilon: f64,
|
||||
@@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5DenseActDense {
|
||||
wi: Linear,
|
||||
wo: Linear,
|
||||
@@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5DenseGatedActDense {
|
||||
wi_0: Linear,
|
||||
wi_1: Linear,
|
||||
@@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerFF {
|
||||
dense_act: Option<T5DenseActDense>,
|
||||
gated_dense_act: Option<T5DenseGatedActDense>,
|
||||
@@ -261,7 +261,7 @@ impl Module for T5LayerFF {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5Attention {
|
||||
q: Linear,
|
||||
k: Linear,
|
||||
@@ -456,7 +456,7 @@ impl T5Attention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerSelfAttention {
|
||||
self_attention: T5Attention,
|
||||
layer_norm: T5LayerNorm,
|
||||
@@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerCrossAttention {
|
||||
cross_attention: T5Attention,
|
||||
layer_norm: T5LayerNorm,
|
||||
@@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5Block {
|
||||
self_attn: T5LayerSelfAttention,
|
||||
cross_attn: Option<T5LayerCrossAttention>,
|
||||
@@ -608,7 +608,7 @@ impl T5Block {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5Stack {
|
||||
block: Vec<T5Block>,
|
||||
shared: Arc<Embedding>,
|
||||
@@ -658,7 +658,7 @@ impl T5Stack {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct T5EncoderModel {
|
||||
encoder: T5Stack,
|
||||
device: Device,
|
||||
@@ -691,7 +691,7 @@ impl T5EncoderModel {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct T5ForConditionalGeneration {
|
||||
encoder: T5Stack,
|
||||
decoder: T5Stack,
|
||||
|
Reference in New Issue
Block a user