Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Shared the quantized var-builder code. (#952)

* Shared the quantized var-builder code.
* Fix compilation.
candle-transformers/src/lib.rs

@@ -2,4 +2,5 @@ pub mod generation;
 pub mod models;
 pub mod object_detection;
 pub mod pipelines;
+pub mod quantized_var_builder;
 pub mod utils;
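With `quantized_var_builder` exported from the crate root, callers outside the module reach the shared builder through a single path. Below is a minimal sketch of the import; the crate name `candle_transformers` and the helper function are assumptions for illustration, not part of the diff.

// Illustrative sketch only, assuming the crate is consumed as `candle_transformers`.
use candle_transformers::quantized_var_builder::VarBuilder;

// Hypothetical helper: load every tensor listed in a GGUF file into the builder.
fn open_quantized_weights(path: &std::path::Path) -> candle::Result<VarBuilder> {
    VarBuilder::from_gguf(path)
}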
candle-transformers/src/models/quantized_t5.rs

@@ -1,88 +1,12 @@
 // T5 Text Model, quantized version
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use candle::quantized::QTensor;
-use candle::{DType, Device, Module, Result, Shape, Tensor, D};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
 use serde::Deserialize;
 use std::sync::Arc;
 
-// VarBuilder specialized for QTensors
-pub struct VarBuilder {
-    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
-    path: Vec<String>,
-    device: Device,
-}
-
-impl VarBuilder {
-    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
-        let mut file = std::fs::File::open(p)?;
-        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
-        let mut data = std::collections::HashMap::new();
-        for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut file, tensor_name)?;
-            data.insert(tensor_name.to_string(), Arc::new(tensor));
-        }
-        Ok(Self {
-            data: Arc::new(data),
-            path: Vec::new(),
-            device: Device::Cpu,
-        })
-    }
-
-    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
-        let mut cursor = std::io::Cursor::new(buffer);
-        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
-        let mut data = std::collections::HashMap::new();
-        for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name)?;
-            data.insert(tensor_name.to_string(), Arc::new(tensor));
-        }
-        Ok(Self {
-            data: Arc::new(data),
-            path: Vec::new(),
-            device: Device::Cpu,
-        })
-    }
-
-    fn pp<S: ToString>(&self, s: S) -> Self {
-        let mut path = self.path.clone();
-        path.push(s.to_string());
-        Self {
-            data: self.data.clone(),
-            path,
-            device: self.device.clone(),
-        }
-    }
-
-    fn path(&self, tensor_name: &str) -> String {
-        if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        }
-    }
-
-    fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
-        let path = self.path(name);
-        match self.data.get(&path) {
-            None => {
-                candle::bail!("cannot find tensor {name}")
-            }
-            Some(qtensor) => {
-                let shape = s.into();
-                if qtensor.shape() != &shape {
-                    candle::bail!(
-                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
-                        qtensor.shape()
-                    )
-                }
-                Ok(qtensor.clone())
-            }
-        }
-    }
-}
-
 #[derive(Debug)]
 struct Embedding {
     inner: candle_nn::Embedding,
@@ -91,7 +15,7 @@ struct Embedding {
 
 impl Embedding {
     fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
+        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
         let inner = candle_nn::Embedding::new(embeddings, d2);
         let span = tracing::span!(tracing::Level::TRACE, "embedding");
         Ok(Self { inner, span })
@@ -230,7 +154,7 @@ struct T5LayerNorm {
 
 impl T5LayerNorm {
     fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
+        let weight = vb.get(h, "weight")?.dequantize(vb.device())?;
         Ok(Self {
             weight,
             variance_epsilon: eps,
@@ -775,7 +699,7 @@ impl T5EncoderModel {
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
             encoder,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "encoder"),
         })
     }
@@ -840,7 +764,7 @@ impl T5ForConditionalGeneration {
             tie_word_embeddings,
             lm_head,
             shared,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
             span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
             span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
         })
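Because the builder's fields are private to the new module, the T5 hunks above replace every direct `vb.device` field access with the `vb.device()` accessor. A small sketch of the resulting load pattern, using a hypothetical layer that is not part of the diff:

use candle::{Result, Tensor};
use candle_transformers::quantized_var_builder::VarBuilder;

// Hypothetical minimal layer, only to illustrate the get + dequantize pattern.
struct Scale {
    weight: Tensor,
}

impl Scale {
    fn load(size: usize, vb: VarBuilder) -> Result<Self> {
        // Look up the quantized tensor by name and expected shape, then
        // dequantize it on the device exposed by the accessor.
        let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
        Ok(Self { weight })
    }
}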
candle-transformers/src/quantized_var_builder.rs (new file, 83 lines)

@@ -0,0 +1,83 @@
+use candle::quantized::QTensor;
+use candle::{Device, Result, Shape};
+use std::sync::Arc;
+
+// VarBuilder specialized for QTensors
+pub struct VarBuilder {
+    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
+    path: Vec<String>,
+    device: Device,
+}
+
+impl VarBuilder {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let mut file = std::fs::File::open(p)?;
+        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut file, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
+    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+        let mut cursor = std::io::Cursor::new(buffer);
+        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
+    pub fn pp<S: ToString>(&self, s: S) -> Self {
+        let mut path = self.path.clone();
+        path.push(s.to_string());
+        Self {
+            data: self.data.clone(),
+            path,
+            device: self.device.clone(),
+        }
+    }
+
+    fn path(&self, tensor_name: &str) -> String {
+        if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        }
+    }
+
+    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
+        let path = self.path(name);
+        match self.data.get(&path) {
+            None => {
+                candle::bail!("cannot find tensor {name}")
+            }
+            Some(qtensor) => {
+                let shape = s.into();
+                if qtensor.shape() != &shape {
+                    candle::bail!(
+                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
+                        qtensor.shape()
+                    )
+                }
+                Ok(qtensor.clone())
+            }
+        }
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+}
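Putting the new module together: build a `VarBuilder` from a GGUF file (or an in-memory buffer via `from_gguf_buffer`), scope it with `pp`, and fetch tensors with `get`. A usage sketch follows; the file name, tensor path, and shape are placeholders, not values taken from the diff, and the crate name `candle_transformers` is assumed.

use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> candle::Result<()> {
    // Placeholder file name; tensors are keyed by their full dotted names.
    let vb = VarBuilder::from_gguf("model.gguf")?;

    // `pp` pushes prefixes, so `get` below resolves "encoder.embed_tokens.weight".
    let vb_embed = vb.pp("encoder").pp("embed_tokens");

    // `get` checks the stored shape against the expected one and returns an
    // Arc<QTensor>; `dequantize` materializes it on the builder's device (CPU here).
    let weight = vb_embed.get((32128, 512), "weight")?.dequantize(vb.device())?;
    println!("{:?}", weight.shape());
    Ok(())
}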