mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Make more models cloneable. (#1203)
This commit is contained in:
@ -7,7 +7,7 @@ use std::sync::Arc;
|
|||||||
pub use crate::models::stable_lm::Config;
|
pub use crate::models::stable_lm::Config;
|
||||||
use crate::models::stable_lm::RotaryEmbedding;
|
use crate::models::stable_lm::RotaryEmbedding;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
struct MLP {
|
struct MLP {
|
||||||
gate_proj: Linear,
|
gate_proj: Linear,
|
||||||
@ -43,7 +43,7 @@ impl Module for MLP {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct Attention {
|
struct Attention {
|
||||||
q_proj: Linear,
|
q_proj: Linear,
|
||||||
k_proj: Linear,
|
k_proj: Linear,
|
||||||
@ -168,7 +168,7 @@ impl Attention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct DecoderLayer {
|
struct DecoderLayer {
|
||||||
self_attn: Attention,
|
self_attn: Attention,
|
||||||
mlp: MLP,
|
mlp: MLP,
|
||||||
@ -213,7 +213,7 @@ impl DecoderLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
embed_tokens: Embedding,
|
embed_tokens: Embedding,
|
||||||
layers: Vec<DecoderLayer>,
|
layers: Vec<DecoderLayer>,
|
||||||
|
@ -93,7 +93,7 @@ impl Default for Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerNorm {
|
struct T5LayerNorm {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
variance_epsilon: f64,
|
variance_epsilon: f64,
|
||||||
@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5DenseActDense {
|
struct T5DenseActDense {
|
||||||
wi: QMatMul,
|
wi: QMatMul,
|
||||||
wo: QMatMul,
|
wo: QMatMul,
|
||||||
@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5DenseGatedActDense {
|
struct T5DenseGatedActDense {
|
||||||
wi_0: QMatMul,
|
wi_0: QMatMul,
|
||||||
wi_1: QMatMul,
|
wi_1: QMatMul,
|
||||||
@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerFF {
|
struct T5LayerFF {
|
||||||
dense_act: Option<T5DenseActDense>,
|
dense_act: Option<T5DenseActDense>,
|
||||||
gated_dense_act: Option<T5DenseGatedActDense>,
|
gated_dense_act: Option<T5DenseGatedActDense>,
|
||||||
@ -236,7 +236,7 @@ impl Module for T5LayerFF {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5Attention {
|
struct T5Attention {
|
||||||
q: QMatMul,
|
q: QMatMul,
|
||||||
k: QMatMul,
|
k: QMatMul,
|
||||||
@ -431,7 +431,7 @@ impl T5Attention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerSelfAttention {
|
struct T5LayerSelfAttention {
|
||||||
self_attention: T5Attention,
|
self_attention: T5Attention,
|
||||||
layer_norm: T5LayerNorm,
|
layer_norm: T5LayerNorm,
|
||||||
@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerCrossAttention {
|
struct T5LayerCrossAttention {
|
||||||
cross_attention: T5Attention,
|
cross_attention: T5Attention,
|
||||||
layer_norm: T5LayerNorm,
|
layer_norm: T5LayerNorm,
|
||||||
@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5Block {
|
struct T5Block {
|
||||||
self_attn: T5LayerSelfAttention,
|
self_attn: T5LayerSelfAttention,
|
||||||
cross_attn: Option<T5LayerCrossAttention>,
|
cross_attn: Option<T5LayerCrossAttention>,
|
||||||
@ -583,7 +583,7 @@ impl T5Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5Stack {
|
struct T5Stack {
|
||||||
block: Vec<T5Block>,
|
block: Vec<T5Block>,
|
||||||
shared: Arc<Embedding>,
|
shared: Arc<Embedding>,
|
||||||
@ -633,7 +633,7 @@ impl T5Stack {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct T5EncoderModel {
|
pub struct T5EncoderModel {
|
||||||
encoder: T5Stack,
|
encoder: T5Stack,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -666,7 +666,7 @@ impl T5EncoderModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct T5ForConditionalGeneration {
|
pub struct T5ForConditionalGeneration {
|
||||||
encoder: T5Stack,
|
encoder: T5Stack,
|
||||||
decoder: T5Stack,
|
decoder: T5Stack,
|
||||||
|
@ -118,7 +118,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerNorm {
|
struct T5LayerNorm {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
variance_epsilon: f64,
|
variance_epsilon: f64,
|
||||||
@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5DenseActDense {
|
struct T5DenseActDense {
|
||||||
wi: Linear,
|
wi: Linear,
|
||||||
wo: Linear,
|
wo: Linear,
|
||||||
@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5DenseGatedActDense {
|
struct T5DenseGatedActDense {
|
||||||
wi_0: Linear,
|
wi_0: Linear,
|
||||||
wi_1: Linear,
|
wi_1: Linear,
|
||||||
@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerFF {
|
struct T5LayerFF {
|
||||||
dense_act: Option<T5DenseActDense>,
|
dense_act: Option<T5DenseActDense>,
|
||||||
gated_dense_act: Option<T5DenseGatedActDense>,
|
gated_dense_act: Option<T5DenseGatedActDense>,
|
||||||
@ -261,7 +261,7 @@ impl Module for T5LayerFF {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5Attention {
|
struct T5Attention {
|
||||||
q: Linear,
|
q: Linear,
|
||||||
k: Linear,
|
k: Linear,
|
||||||
@ -456,7 +456,7 @@ impl T5Attention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerSelfAttention {
|
struct T5LayerSelfAttention {
|
||||||
self_attention: T5Attention,
|
self_attention: T5Attention,
|
||||||
layer_norm: T5LayerNorm,
|
layer_norm: T5LayerNorm,
|
||||||
@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5LayerCrossAttention {
|
struct T5LayerCrossAttention {
|
||||||
cross_attention: T5Attention,
|
cross_attention: T5Attention,
|
||||||
layer_norm: T5LayerNorm,
|
layer_norm: T5LayerNorm,
|
||||||
@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5Block {
|
struct T5Block {
|
||||||
self_attn: T5LayerSelfAttention,
|
self_attn: T5LayerSelfAttention,
|
||||||
cross_attn: Option<T5LayerCrossAttention>,
|
cross_attn: Option<T5LayerCrossAttention>,
|
||||||
@ -608,7 +608,7 @@ impl T5Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct T5Stack {
|
struct T5Stack {
|
||||||
block: Vec<T5Block>,
|
block: Vec<T5Block>,
|
||||||
shared: Arc<Embedding>,
|
shared: Arc<Embedding>,
|
||||||
@ -658,7 +658,7 @@ impl T5Stack {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct T5EncoderModel {
|
pub struct T5EncoderModel {
|
||||||
encoder: T5Stack,
|
encoder: T5Stack,
|
||||||
device: Device,
|
device: Device,
|
||||||
@ -691,7 +691,7 @@ impl T5EncoderModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct T5ForConditionalGeneration {
|
pub struct T5ForConditionalGeneration {
|
||||||
encoder: T5Stack,
|
encoder: T5Stack,
|
||||||
decoder: T5Stack,
|
decoder: T5Stack,
|
||||||
|
Reference in New Issue
Block a user