Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Shared the quantized var-builder code. (#952)

* Shared the quantized var-builder code.
* Fix compilation.
candle-transformers/src/lib.rs

@@ -2,4 +2,5 @@ pub mod generation;
 pub mod models;
 pub mod object_detection;
 pub mod pipelines;
+pub mod quantized_var_builder;
 pub mod utils;
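Making the module public in the crate root is what lets several quantized models share one builder instead of each carrying its own copy. A minimal sketch of how another quantized model inside candle-transformers could now pull it in (the file name, prefix, tensor name and shape are hypothetical):

// e.g. in a hypothetical candle-transformers/src/models/quantized_foo.rs
use crate::quantized_var_builder::VarBuilder;

fn load_head(vb: &VarBuilder) -> candle::Result<std::sync::Arc<candle::quantized::QTensor>> {
    // "lm_head.weight" with shape (vocab, hidden) is a made-up example.
    vb.pp("lm_head").get((32_000, 512), "weight")
}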
candle-transformers/src/models/quantized_t5.rs

@@ -1,88 +1,12 @@
 // T5 Text Model, quantized version
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use candle::quantized::QTensor;
-use candle::{DType, Device, Module, Result, Shape, Tensor, D};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
 use serde::Deserialize;
 use std::sync::Arc;
 
-// VarBuilder specialized for QTensors
-pub struct VarBuilder {
-    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
-    path: Vec<String>,
-    device: Device,
-}
-
-impl VarBuilder {
-    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
-        let mut file = std::fs::File::open(p)?;
-        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
-        let mut data = std::collections::HashMap::new();
-        for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut file, tensor_name)?;
-            data.insert(tensor_name.to_string(), Arc::new(tensor));
-        }
-        Ok(Self {
-            data: Arc::new(data),
-            path: Vec::new(),
-            device: Device::Cpu,
-        })
-    }
-
-    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
-        let mut cursor = std::io::Cursor::new(buffer);
-        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
-        let mut data = std::collections::HashMap::new();
-        for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name)?;
-            data.insert(tensor_name.to_string(), Arc::new(tensor));
-        }
-        Ok(Self {
-            data: Arc::new(data),
-            path: Vec::new(),
-            device: Device::Cpu,
-        })
-    }
-
-    fn pp<S: ToString>(&self, s: S) -> Self {
-        let mut path = self.path.clone();
-        path.push(s.to_string());
-        Self {
-            data: self.data.clone(),
-            path,
-            device: self.device.clone(),
-        }
-    }
-
-    fn path(&self, tensor_name: &str) -> String {
-        if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        }
-    }
-
-    fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
-        let path = self.path(name);
-        match self.data.get(&path) {
-            None => {
-                candle::bail!("cannot find tensor {name}")
-            }
-            Some(qtensor) => {
-                let shape = s.into();
-                if qtensor.shape() != &shape {
-                    candle::bail!(
-                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
-                        qtensor.shape()
-                    )
-                }
-                Ok(qtensor.clone())
-            }
-        }
-    }
-}
-
 #[derive(Debug)]
 struct Embedding {
     inner: candle_nn::Embedding,
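Note that the old import path keeps working: quantized_t5 now re-exports the shared type with pub use, so callers that referred to the builder through the T5 module do not need to change. A small sketch of the equivalence, assuming the quantized T5 model is exposed as models::quantized_t5:

use candle_transformers::models::quantized_t5;
use candle_transformers::quantized_var_builder;

// The re-export makes these the same type, so this compiles.
fn same_builder(vb: quantized_var_builder::VarBuilder) -> quantized_t5::VarBuilder {
    vb
}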
@@ -91,7 +15,7 @@ struct Embedding {
 
 impl Embedding {
     fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
+        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
         let inner = candle_nn::Embedding::new(embeddings, d2);
         let span = tracing::span!(tracing::Level::TRACE, "embedding");
         Ok(Self { inner, span })
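This is the pattern the quantized T5 layers use throughout: fetch the quantized tensor by name, then dequantize it on the builder's device to obtain a regular (non-quantized) tensor that plain candle_nn layers can consume. A standalone sketch of the same calls, with made-up dimensions and a made-up tensor name:

use candle_transformers::quantized_var_builder::VarBuilder;

fn embedding_from_quantized(vb: &VarBuilder) -> candle::Result<candle_nn::Embedding> {
    // (vocab_size, hidden_size) and the "weight" name are illustrative values.
    let (vocab_size, hidden_size) = (32_000, 512);
    let weights = vb.get((vocab_size, hidden_size), "weight")?.dequantize(vb.device())?;
    Ok(candle_nn::Embedding::new(weights, hidden_size))
}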
@@ -230,7 +154,7 @@ struct T5LayerNorm {
 
 impl T5LayerNorm {
     fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
+        let weight = vb.get(h, "weight")?.dequantize(vb.device())?;
         Ok(Self {
             weight,
             variance_epsilon: eps,
@@ -775,7 +699,7 @@ impl T5EncoderModel {
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
             encoder,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "encoder"),
         })
     }
@@ -840,7 +764,7 @@ impl T5ForConditionalGeneration {
             tie_word_embeddings,
             lm_head,
             shared,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
             span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
             span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
         })
candle-transformers/src/quantized_var_builder.rs (new file, 83 lines)

@@ -0,0 +1,83 @@
+use candle::quantized::QTensor;
+use candle::{Device, Result, Shape};
+use std::sync::Arc;
+
+// VarBuilder specialized for QTensors
+pub struct VarBuilder {
+    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
+    path: Vec<String>,
+    device: Device,
+}
+
+impl VarBuilder {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let mut file = std::fs::File::open(p)?;
+        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut file, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
+    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+        let mut cursor = std::io::Cursor::new(buffer);
+        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
+    pub fn pp<S: ToString>(&self, s: S) -> Self {
+        let mut path = self.path.clone();
+        path.push(s.to_string());
+        Self {
+            data: self.data.clone(),
+            path,
+            device: self.device.clone(),
+        }
+    }
+
+    fn path(&self, tensor_name: &str) -> String {
+        if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        }
+    }
+
+    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
+        let path = self.path(name);
+        match self.data.get(&path) {
+            None => {
+                candle::bail!("cannot find tensor {name}")
+            }
+            Some(qtensor) => {
+                let shape = s.into();
+                if qtensor.shape() != &shape {
+                    candle::bail!(
+                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
+                        qtensor.shape()
+                    )
+                }
+                Ok(qtensor.clone())
+            }
+        }
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+}
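Taken together, the new module gives every quantized model the same loading flow: build a VarBuilder from a GGUF file (or an in-memory buffer), narrow it to a weight prefix with pp, and fetch shape-checked quantized tensors with get. A minimal sketch of that flow, with a hypothetical file path, tensor name, and shape:

use candle_transformers::quantized_var_builder::VarBuilder;

fn main() -> candle::Result<()> {
    // "model.gguf" and the tensor name/shape below are placeholders for illustration.
    let vb = VarBuilder::from_gguf("model.gguf")?;

    // pp() appends a prefix segment, so this looks up "encoder.embed_tokens.weight".
    let embed = vb.pp("encoder").pp("embed_tokens").get((32_000, 512), "weight")?;

    // get() returns the tensor still quantized; dequantize on the builder's device
    // when a regular Tensor is needed. As written, both constructors set the
    // device to Device::Cpu.
    let _embed_dequantized = embed.dequantize(vb.device())?;
    Ok(())
}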