Shared the quantized var-builder code. (#952)

* Shared the quantized var-builder code.

* Fix compilation.
This commit is contained in:
Laurent Mazare
2023-09-24 12:55:07 +01:00
committed by GitHub
parent 4aeb449017
commit e15862cfdb
3 changed files with 90 additions and 82 deletions

View File

@ -2,4 +2,5 @@ pub mod generation;
pub mod models;
pub mod object_detection;
pub mod pipelines;
pub mod quantized_var_builder;
pub mod utils;

View File

@ -1,88 +1,12 @@
// T5 Text Model, quantized version
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
use candle::quantized::QTensor;
use candle::{DType, Device, Module, Result, Shape, Tensor, D};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::Activation;
use serde::Deserialize;
use std::sync::Arc;
// VarBuilder specialized for QTensors
pub struct VarBuilder {
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
path: Vec<String>,
device: Device,
}
impl VarBuilder {
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut file, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: Device::Cpu,
})
}
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut cursor, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: Device::Cpu,
})
}
fn pp<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
data: self.data.clone(),
path,
device: self.device.clone(),
}
}
fn path(&self, tensor_name: &str) -> String {
if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
}
}
fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
let path = self.path(name);
match self.data.get(&path) {
None => {
candle::bail!("cannot find tensor {name}")
}
Some(qtensor) => {
let shape = s.into();
if qtensor.shape() != &shape {
candle::bail!(
"shape mismatch for {name}, got {:?}, expected {shape:?}",
qtensor.shape()
)
}
Ok(qtensor.clone())
}
}
}
}
#[derive(Debug)]
struct Embedding {
inner: candle_nn::Embedding,
@ -91,7 +15,7 @@ struct Embedding {
impl Embedding {
fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
let inner = candle_nn::Embedding::new(embeddings, d2);
let span = tracing::span!(tracing::Level::TRACE, "embedding");
Ok(Self { inner, span })
@ -230,7 +154,7 @@ struct T5LayerNorm {
impl T5LayerNorm {
fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
let weight = vb.get(h, "weight")?.dequantize(vb.device())?;
Ok(Self {
weight,
variance_epsilon: eps,
@ -775,7 +699,7 @@ impl T5EncoderModel {
let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
Ok(Self {
encoder,
device: vb.device.clone(),
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "encoder"),
})
}
@ -840,7 +764,7 @@ impl T5ForConditionalGeneration {
tie_word_embeddings,
lm_head,
shared,
device: vb.device.clone(),
device: vb.device().clone(),
span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
})

View File

@ -0,0 +1,83 @@
use candle::quantized::QTensor;
use candle::{Device, Result, Shape};
use std::sync::Arc;
// VarBuilder specialized for QTensors
pub struct VarBuilder {
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
path: Vec<String>,
device: Device,
}
impl VarBuilder {
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut file, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: Device::Cpu,
})
}
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
let tensor = content.tensor(&mut cursor, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
device: Device::Cpu,
})
}
pub fn pp<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
data: self.data.clone(),
path,
device: self.device.clone(),
}
}
fn path(&self, tensor_name: &str) -> String {
if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
}
}
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
let path = self.path(name);
match self.data.get(&path) {
None => {
candle::bail!("cannot find tensor {name}")
}
Some(qtensor) => {
let shape = s.into();
if qtensor.shape() != &shape {
candle::bail!(
"shape mismatch for {name}, got {:?}, expected {shape:?}",
qtensor.shape()
)
}
Ok(qtensor.clone())
}
}
}
pub fn device(&self) -> &Device {
&self.device
}
}