Make some model cloneable. (#1125)

This commit is contained in:
Laurent Mazare
2023-10-18 19:30:47 +01:00
committed by GitHub
parent 620c94d12e
commit 185b54a33b
5 changed files with 25 additions and 20 deletions

View File

@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::LayerNorm,
span: tracing::Span,
@ -27,6 +28,7 @@ impl RmsNorm {
}
// QMatMul wrapper adding some tracing.
#[derive(Debug, Clone)]
struct QMatMul {
inner: candle::quantized::QMatMul,
span: tracing::Span,
@ -45,6 +47,7 @@ impl QMatMul {
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
attention_wk: QMatMul,
@ -167,6 +170,7 @@ impl LayerWeights {
}
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,