mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Make some model cloneable. (#1125)
This commit is contained in:
@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RmsNorm {
|
||||
inner: candle_nn::LayerNorm,
|
||||
span: tracing::Span,
|
||||
@ -27,6 +28,7 @@ impl RmsNorm {
|
||||
}
|
||||
|
||||
// QMatMul wrapper adding some tracing.
|
||||
#[derive(Debug, Clone)]
|
||||
struct QMatMul {
|
||||
inner: candle::quantized::QMatMul,
|
||||
span: tracing::Span,
|
||||
@ -45,6 +47,7 @@ impl QMatMul {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
@ -167,6 +170,7 @@ impl LayerWeights {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
layers: Vec<LayerWeights>,
|
||||
|
Reference in New Issue
Block a user