mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Move the common quantized-nn code to a shared module. (#1063)
This commit is contained in:
@ -2,38 +2,13 @@
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
|
||||
|
||||
use crate::models::with_tracing::QMatMul;
|
||||
use crate::quantized_nn::Embedding;
|
||||
pub use crate::quantized_var_builder::VarBuilder;
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::Activation;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedding {
|
||||
inner: candle_nn::Embedding,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
|
||||
let inner = candle_nn::Embedding::new(embeddings, d2);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
pub fn embeddings(&self) -> &Tensor {
|
||||
self.inner.embeddings()
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Embedding {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_relative_attention_max_distance() -> usize {
|
||||
128
|
||||
}
|
||||
|
Reference in New Issue
Block a user