Move the common quantized-nn code to a shared module. (#1063)

This commit is contained in:
Laurent Mazare
2023-10-09 06:22:22 +01:00
committed by GitHub
parent 59ab6d7832
commit 392fe02fba
7 changed files with 100 additions and 166 deletions

View File

@ -2,38 +2,13 @@
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::Embedding;
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
use serde::Deserialize;
use std::sync::Arc;
#[derive(Debug)]
pub struct Embedding {
inner: candle_nn::Embedding,
span: tracing::Span,
}
impl Embedding {
pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
let inner = candle_nn::Embedding::new(embeddings, d2);
let span = tracing::span!(tracing::Level::TRACE, "embedding");
Ok(Self { inner, span })
}
pub fn embeddings(&self) -> &Tensor {
self.inner.embeddings()
}
}
impl Module for Embedding {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
fn default_relative_attention_max_distance() -> usize {
128
}