mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Make some model cloneable. (#1125)
This commit is contained in:
@@ -39,7 +39,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
inner: candle_nn::RmsNorm,
|
inner: candle_nn::RmsNorm,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -60,7 +60,7 @@ impl Module for RmsNorm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct RotaryEmbedding {
|
struct RotaryEmbedding {
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
@@ -111,7 +111,7 @@ impl RotaryEmbedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
struct MLP {
|
struct MLP {
|
||||||
gate_proj: Linear,
|
gate_proj: Linear,
|
||||||
@@ -160,7 +160,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
|
|||||||
unimplemented!("compile with '--features flash-attn'")
|
unimplemented!("compile with '--features flash-attn'")
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct Attention {
|
struct Attention {
|
||||||
q_proj: Linear,
|
q_proj: Linear,
|
||||||
k_proj: Linear,
|
k_proj: Linear,
|
||||||
@@ -279,7 +279,7 @@ impl Attention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct DecoderLayer {
|
struct DecoderLayer {
|
||||||
self_attn: Attention,
|
self_attn: Attention,
|
||||||
mlp: MLP,
|
mlp: MLP,
|
||||||
@@ -322,7 +322,7 @@ impl DecoderLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
embed_tokens: candle_nn::Embedding,
|
embed_tokens: candle_nn::Embedding,
|
||||||
layers: Vec<DecoderLayer>,
|
layers: Vec<DecoderLayer>,
|
||||||
|
@@ -74,7 +74,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct Embedding {
|
struct Embedding {
|
||||||
wte: E,
|
wte: E,
|
||||||
}
|
}
|
||||||
@@ -106,7 +106,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct RotaryEmbedding {
|
struct RotaryEmbedding {
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
@@ -172,7 +172,7 @@ impl RotaryEmbedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
struct MLP {
|
struct MLP {
|
||||||
fc1: Linear,
|
fc1: Linear,
|
||||||
@@ -199,7 +199,7 @@ impl Module for MLP {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct CausalLMHead {
|
struct CausalLMHead {
|
||||||
ln: candle_nn::LayerNorm,
|
ln: candle_nn::LayerNorm,
|
||||||
linear: Linear,
|
linear: Linear,
|
||||||
@@ -221,7 +221,7 @@ impl Module for CausalLMHead {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(clippy::upper_case_acronyms)]
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
struct MHA {
|
struct MHA {
|
||||||
wqkv: Linear,
|
wqkv: Linear,
|
||||||
@@ -310,7 +310,7 @@ impl MHA {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct ParallelBlock {
|
struct ParallelBlock {
|
||||||
ln: candle_nn::LayerNorm,
|
ln: candle_nn::LayerNorm,
|
||||||
mixer: MHA,
|
mixer: MHA,
|
||||||
@@ -345,7 +345,7 @@ impl ParallelBlock {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MixFormerSequentialForCausalLM {
|
pub struct MixFormerSequentialForCausalLM {
|
||||||
embedding: Embedding,
|
embedding: Embedding,
|
||||||
blocks: Vec<ParallelBlock>,
|
blocks: Vec<ParallelBlock>,
|
||||||
|
@@ -40,7 +40,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct GroupedQueryAttention {
|
struct GroupedQueryAttention {
|
||||||
wqkv: Linear,
|
wqkv: Linear,
|
||||||
out_proj: Linear,
|
out_proj: Linear,
|
||||||
@@ -148,7 +148,7 @@ pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct Ffn {
|
struct Ffn {
|
||||||
up_proj: Linear,
|
up_proj: Linear,
|
||||||
down_proj: Linear,
|
down_proj: Linear,
|
||||||
@@ -169,7 +169,7 @@ impl Module for Ffn {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
struct MPTBlock {
|
struct MPTBlock {
|
||||||
norm1: LayerNorm, // Do we need the low-precision variant?
|
norm1: LayerNorm, // Do we need the low-precision variant?
|
||||||
attn: GroupedQueryAttention,
|
attn: GroupedQueryAttention,
|
||||||
@@ -240,7 +240,7 @@ pub(crate) fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
|
|||||||
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
|
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
wte: Embedding,
|
wte: Embedding,
|
||||||
blocks: Vec<MPTBlock>,
|
blocks: Vec<MPTBlock>,
|
||||||
|
@@ -7,6 +7,7 @@ use candle_nn::{Embedding, Module};
|
|||||||
|
|
||||||
pub const MAX_SEQ_LEN: usize = 4096;
|
pub const MAX_SEQ_LEN: usize = 4096;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
inner: candle_nn::LayerNorm,
|
inner: candle_nn::LayerNorm,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -27,6 +28,7 @@ impl RmsNorm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QMatMul wrapper adding some tracing.
|
// QMatMul wrapper adding some tracing.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct QMatMul {
|
struct QMatMul {
|
||||||
inner: candle::quantized::QMatMul,
|
inner: candle::quantized::QMatMul,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -45,6 +47,7 @@ impl QMatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
struct LayerWeights {
|
struct LayerWeights {
|
||||||
attention_wq: QMatMul,
|
attention_wq: QMatMul,
|
||||||
attention_wk: QMatMul,
|
attention_wk: QMatMul,
|
||||||
@@ -167,6 +170,7 @@ impl LayerWeights {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct ModelWeights {
|
pub struct ModelWeights {
|
||||||
tok_embeddings: Embedding,
|
tok_embeddings: Embedding,
|
||||||
layers: Vec<LayerWeights>,
|
layers: Vec<LayerWeights>,
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
use candle::{Module, Result, Tensor};
|
use candle::{Module, Result, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Embedding {
|
pub struct Embedding {
|
||||||
inner: candle_nn::Embedding,
|
inner: candle_nn::Embedding,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -26,7 +26,7 @@ impl Module for Embedding {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Linear {
|
pub struct Linear {
|
||||||
inner: candle_nn::Linear,
|
inner: candle_nn::Linear,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -52,7 +52,7 @@ impl Module for Linear {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the conv2d op to provide some tracing.
|
// Wrap the conv2d op to provide some tracing.
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Conv2d {
|
pub struct Conv2d {
|
||||||
inner: candle_nn::Conv2d,
|
inner: candle_nn::Conv2d,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
@@ -78,6 +78,7 @@ pub fn conv2d(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// QMatMul wrapper adding some tracing.
|
// QMatMul wrapper adding some tracing.
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct QMatMul {
|
pub struct QMatMul {
|
||||||
inner: candle::quantized::QMatMul,
|
inner: candle::quantized::QMatMul,
|
||||||
span: tracing::Span,
|
span: tracing::Span,
|
||||||
|
Reference in New Issue
Block a user