mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Make more models cloneable. (#1203)
This commit is contained in:
@ -93,7 +93,7 @@ impl Default for Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerNorm {
|
||||
weight: Tensor,
|
||||
variance_epsilon: f64,
|
||||
@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5DenseActDense {
|
||||
wi: QMatMul,
|
||||
wo: QMatMul,
|
||||
@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5DenseGatedActDense {
|
||||
wi_0: QMatMul,
|
||||
wi_1: QMatMul,
|
||||
@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerFF {
|
||||
dense_act: Option<T5DenseActDense>,
|
||||
gated_dense_act: Option<T5DenseGatedActDense>,
|
||||
@ -236,7 +236,7 @@ impl Module for T5LayerFF {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5Attention {
|
||||
q: QMatMul,
|
||||
k: QMatMul,
|
||||
@ -431,7 +431,7 @@ impl T5Attention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerSelfAttention {
|
||||
self_attention: T5Attention,
|
||||
layer_norm: T5LayerNorm,
|
||||
@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5LayerCrossAttention {
|
||||
cross_attention: T5Attention,
|
||||
layer_norm: T5LayerNorm,
|
||||
@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5Block {
|
||||
self_attn: T5LayerSelfAttention,
|
||||
cross_attn: Option<T5LayerCrossAttention>,
|
||||
@ -583,7 +583,7 @@ impl T5Block {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct T5Stack {
|
||||
block: Vec<T5Block>,
|
||||
shared: Arc<Embedding>,
|
||||
@ -633,7 +633,7 @@ impl T5Stack {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct T5EncoderModel {
|
||||
encoder: T5Stack,
|
||||
device: Device,
|
||||
@ -666,7 +666,7 @@ impl T5EncoderModel {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct T5ForConditionalGeneration {
|
||||
encoder: T5Stack,
|
||||
decoder: T5Stack,
|
||||
|
Reference in New Issue
Block a user