Move the common quantized-nn code to a shared module. (#1063)

This commit is contained in:
Laurent Mazare
2023-10-09 06:22:22 +01:00
committed by GitHub
parent 59ab6d7832
commit 392fe02fba
7 changed files with 100 additions and 166 deletions

View File

@ -1,39 +1,9 @@
use super::Config;
use crate::models::{quantized_t5::Embedding, with_tracing::QMatMul};
use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, LayerNorm, Module};
#[derive(Debug)]
struct Linear {
weight: QMatMul,
bias: Option<Tensor>,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let x = x.apply(&self.weight)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
let weight = QMatMul::new(in_dim, out_dim, vb)?;
Ok(Linear {
weight,
bias: Some(bias),
})
}
fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let weight = QMatMul::new(in_dim, out_dim, vb)?;
Ok(Linear { weight, bias: None })
}
fn conv1d(
in_channels: usize,
out_channels: usize,
@ -48,12 +18,6 @@ fn conv1d(
Ok(Conv1d::new(weight, Some(bias), config))
}
fn layer_norm(size: usize, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
Ok(candle_nn::LayerNorm::new(weight, bias, 1e-5))
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
struct MultiHeadAttention {
query: Linear,
@ -178,10 +142,10 @@ impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
let attn_ln = layer_norm(n_state, 1e-5, vb.pp("self_attn_layer_norm"))?;
let cross_attn = if ca {
let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
let cross_attn_ln = layer_norm(n_state, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
Some((cross_attn, cross_attn_ln))
} else {
None
@ -189,7 +153,7 @@ impl ResidualAttentionBlock {
let n_mlp = n_state * 4;
let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
let mlp_ln = layer_norm(n_state, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
attn,
attn_ln,
@ -281,7 +245,7 @@ impl AudioEncoder {
ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
let ln_post = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
Ok(Self {
conv1,
conv2,
@ -343,7 +307,7 @@ impl TextDecoder {
ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(format!("layers.{i}")))
})
.collect::<Result<Vec<_>>>()?;
let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
let ln = layer_norm(n_state, 1e-5, vb.pp("layer_norm"))?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();