From 85bea43e5b088b94612b0fd7ed8f09261dc79d52 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 27 Oct 2023 17:59:19 +0200
Subject: [PATCH] Make the whisper model cloneable (#1200)

* Add a quantized variant of llama2.c

* Clippy fixes.

* Make the whisper model cloneable.
---
 candle-transformers/src/models/whisper/model.rs           | 7 ++++++-
 candle-transformers/src/models/whisper/quantized_model.rs | 5 +++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
index 2a58afaf..6078944c 100644
--- a/candle-transformers/src/models/whisper/model.rs
+++ b/candle-transformers/src/models/whisper/model.rs
@@ -9,7 +9,7 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<E
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+#[derive(Debug, Clone)]
 struct MultiHeadAttention {
     query: Linear,
     key: Linear,
@@ -162,6 +163,7 @@ impl MultiHeadAttention {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+#[derive(Debug, Clone)]
 struct ResidualAttentionBlock {
     attn: MultiHeadAttention,
     attn_ln: LayerNorm,
@@ -241,6 +243,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+#[derive(Debug, Clone)]
 pub struct AudioEncoder {
     conv1: Conv1d,
     conv2: Conv1d,
@@ -316,6 +319,7 @@ impl AudioEncoder {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+#[derive(Debug, Clone)]
 pub struct TextDecoder {
     token_embedding: Embedding,
     positional_embedding: Tensor,
@@ -380,6 +384,7 @@ impl TextDecoder {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+#[derive(Debug, Clone)]
 pub struct Whisper {
     pub encoder: AudioEncoder,
     pub decoder: TextDecoder,
diff --git a/candle-transformers/src/models/whisper/quantized_model.rs b/candle-transformers/src/models/whisper/quantized_model.rs
index f0aead49..43ea4177 100644
--- a/candle-transformers/src/models/whisper/quantized_model.rs
+++ b/candle-transformers/src/models/whisper/quantized_model.rs
@@ -19,6 +19,7 @@ fn conv1d(
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
+#[derive(Debug, Clone)]
 struct MultiHeadAttention {
     query: Linear,
     key: Linear,
@@ -128,6 +129,7 @@ impl MultiHeadAttention {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
+#[derive(Debug, Clone)]
 struct ResidualAttentionBlock {
     attn: MultiHeadAttention,
     attn_ln: LayerNorm,
@@ -206,6 +208,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
+#[derive(Debug, Clone)]
 pub struct AudioEncoder {
     conv1: Conv1d,
     conv2: Conv1d,
@@ -281,6 +284,7 @@ impl AudioEncoder {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
+#[derive(Debug, Clone)]
 pub struct TextDecoder {
     token_embedding: Embedding,
     positional_embedding: Tensor,
@@ -347,6 +351,7 @@ impl TextDecoder {
 }
 
 // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
+#[derive(Debug, Clone)]
 pub struct Whisper {
     pub encoder: AudioEncoder,
     pub decoder: TextDecoder,
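
With every sub-module deriving `Clone`, a loaded `Whisper` can be duplicated after a
single weight load, e.g. to give each decoding task its own model value and decoder-side
state. The sketch below shows one intended usage; it assumes the workspace alias of
candle-core as `candle` and the `Whisper::load` constructor from `model.rs`, while the
helper name, checkpoint path, and dtype are illustrative only, not part of this patch:

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;
    use candle_transformers::models::whisper::{model::Whisper, Config};

    // Illustrative helper (not part of the patch): load the weights once,
    // then hand out cheap clones of the model.
    fn load_and_clone(weights: &str, config: Config) -> Result<(Whisper, Whisper)> {
        let device = Device::Cpu;
        // mmap a safetensors checkpoint; the path and dtype are placeholders
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DType::F32, &device)? };
        let model = Whisper::load(&vb, config)?;
        // candle tensors are reference-counted handles, so this clone copies
        // handles and small bookkeeping, not the underlying weight buffers
        let second = model.clone();
        Ok((model, second))
    }

Note that cloning copies whatever KV-cache tensors the attention layers currently hold
(as shallow handles), so for per-task isolation the natural point to clone is right
after loading, before any decoding has run.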