Make the falcon model cloneable. (#2067)

This commit is contained in:
Laurent Mazare
2024-04-15 09:39:03 +02:00
committed by GitHub
parent 8ad822a983
commit af955f260c

View File

@ -120,7 +120,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
Ok(x21) Ok(x21)
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct FalconRotaryEmbedding { struct FalconRotaryEmbedding {
inv_freq: Tensor, inv_freq: Tensor,
cache: Option<(usize, Tensor, Tensor)>, cache: Option<(usize, Tensor, Tensor)>,
@ -186,7 +186,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct FalconAttention { struct FalconAttention {
query_key_value: Linear, query_key_value: Linear,
dense: Linear, dense: Linear,
@ -321,7 +321,7 @@ impl FalconAttention {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct FalconMlp { struct FalconMlp {
dense_h_to_4h: Linear, dense_h_to_4h: Linear,
dense_4h_to_h: Linear, dense_4h_to_h: Linear,
@ -346,7 +346,7 @@ impl FalconMlp {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
struct FalconDecoderLayer { struct FalconDecoderLayer {
inp_layernorm: LayerNorm, inp_layernorm: LayerNorm,
self_attention: FalconAttention, self_attention: FalconAttention,
@ -412,7 +412,7 @@ impl FalconDecoderLayer {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Falcon { pub struct Falcon {
word_embeddings: Embedding, word_embeddings: Embedding,
blocks: Vec<FalconDecoderLayer>, blocks: Vec<FalconDecoderLayer>,