From e15862cfdb8b80d0ef3b7b2a3d32c9863e120246 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 24 Sep 2023 12:55:07 +0100
Subject: [PATCH] Shared the quantized var-builder code. (#952)

* Shared the quantized var-builder code.

* Fix compilation.
---
 candle-transformers/src/lib.rs                |  1 +
 .../src/models/quantized_t5.rs                | 88 ++-----------------
 .../src/quantized_var_builder.rs              | 83 +++++++++++++++++
 3 files changed, 90 insertions(+), 82 deletions(-)
 create mode 100644 candle-transformers/src/quantized_var_builder.rs

diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs
index b83e5056..a4c7ddf7 100644
--- a/candle-transformers/src/lib.rs
+++ b/candle-transformers/src/lib.rs
@@ -2,4 +2,5 @@ pub mod generation;
 pub mod models;
 pub mod object_detection;
 pub mod pipelines;
+pub mod quantized_var_builder;
 pub mod utils;
diff --git a/candle-transformers/src/models/quantized_t5.rs b/candle-transformers/src/models/quantized_t5.rs
index a86dfcb3..d2fa0e2d 100644
--- a/candle-transformers/src/models/quantized_t5.rs
+++ b/candle-transformers/src/models/quantized_t5.rs
@@ -1,88 +1,12 @@
 // T5 Text Model, quantized version
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
 
-use candle::quantized::QTensor;
-use candle::{DType, Device, Module, Result, Shape, Tensor, D};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::Activation;
 use serde::Deserialize;
 use std::sync::Arc;
 
-// VarBuilder specialized for QTensors
-pub struct VarBuilder {
-    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
-    path: Vec<String>,
-    device: Device,
-}
-
-impl VarBuilder {
-    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
-        let mut file = std::fs::File::open(p)?;
-        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
-        let mut data = std::collections::HashMap::new();
-        for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut file, tensor_name)?;
-            data.insert(tensor_name.to_string(), Arc::new(tensor));
-        }
-        Ok(Self {
-            data: Arc::new(data),
-            path: Vec::new(),
-            device: Device::Cpu,
-        })
-    }
-
-    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
-        let mut cursor = std::io::Cursor::new(buffer);
-        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
-        let mut data = std::collections::HashMap::new();
-        for tensor_name in content.tensor_infos.keys() {
-            let tensor = content.tensor(&mut cursor, tensor_name)?;
-            data.insert(tensor_name.to_string(), Arc::new(tensor));
-        }
-        Ok(Self {
-            data: Arc::new(data),
-            path: Vec::new(),
-            device: Device::Cpu,
-        })
-    }
-
-    fn pp<S: ToString>(&self, s: S) -> Self {
-        let mut path = self.path.clone();
-        path.push(s.to_string());
-        Self {
-            data: self.data.clone(),
-            path,
-            device: self.device.clone(),
-        }
-    }
-
-    fn path(&self, tensor_name: &str) -> String {
-        if self.path.is_empty() {
-            tensor_name.to_string()
-        } else {
-            [&self.path.join("."), tensor_name].join(".")
-        }
-    }
-
-    fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
-        let path = self.path(name);
-        match self.data.get(&path) {
-            None => {
-                candle::bail!("cannot find tensor {name}")
-            }
-            Some(qtensor) => {
-                let shape = s.into();
-                if qtensor.shape() != &shape {
-                    candle::bail!(
-                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
-                        qtensor.shape()
-                    )
-                }
-                Ok(qtensor.clone())
-            }
-        }
-    }
-}
-
 #[derive(Debug)]
 struct Embedding {
     inner: candle_nn::Embedding,
@@ -91,7 +15,7 @@ struct Embedding {
 
 impl Embedding {
     fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((d1, d2), "weight")?.dequantize(&vb.device)?;
+        let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
         let inner = candle_nn::Embedding::new(embeddings, d2);
         let span = tracing::span!(tracing::Level::TRACE, "embedding");
         Ok(Self { inner, span })
@@ -230,7 +154,7 @@ struct T5LayerNorm {
 
 impl T5LayerNorm {
     fn load(h: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
-        let weight = vb.get(h, "weight")?.dequantize(&vb.device)?;
+        let weight = vb.get(h, "weight")?.dequantize(vb.device())?;
         Ok(Self {
             weight,
             variance_epsilon: eps,
@@ -775,7 +699,7 @@ impl T5EncoderModel {
         let encoder = T5Stack::load(false, vb.pp("encoder"), &shared, cfg)?;
         Ok(Self {
             encoder,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "encoder"),
         })
     }
@@ -840,7 +764,7 @@ impl T5ForConditionalGeneration {
             tie_word_embeddings,
             lm_head,
             shared,
-            device: vb.device.clone(),
+            device: vb.device().clone(),
             span_decode: tracing::span!(tracing::Level::TRACE, "decode"),
             span_decode_head: tracing::span!(tracing::Level::TRACE, "decode-head"),
         })
diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs
new file mode 100644
index 00000000..259496d6
--- /dev/null
+++ b/candle-transformers/src/quantized_var_builder.rs
@@ -0,0 +1,83 @@
+use candle::quantized::QTensor;
+use candle::{Device, Result, Shape};
+use std::sync::Arc;
+
+// VarBuilder specialized for QTensors
+pub struct VarBuilder {
+    data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
+    path: Vec<String>,
+    device: Device,
+}
+
+impl VarBuilder {
+    pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let mut file = std::fs::File::open(p)?;
+        let content = candle::quantized::gguf_file::Content::read(&mut file)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut file, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
+    pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
+        let mut cursor = std::io::Cursor::new(buffer);
+        let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
+        let mut data = std::collections::HashMap::new();
+        for tensor_name in content.tensor_infos.keys() {
+            let tensor = content.tensor(&mut cursor, tensor_name)?;
+            data.insert(tensor_name.to_string(), Arc::new(tensor));
+        }
+        Ok(Self {
+            data: Arc::new(data),
+            path: Vec::new(),
+            device: Device::Cpu,
+        })
+    }
+
+    pub fn pp<S: ToString>(&self, s: S) -> Self {
+        let mut path = self.path.clone();
+        path.push(s.to_string());
+        Self {
+            data: self.data.clone(),
+            path,
+            device: self.device.clone(),
+        }
+    }
+
+    fn path(&self, tensor_name: &str) -> String {
+        if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        }
+    }
+
+    pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
+        let path = self.path(name);
+        match self.data.get(&path) {
+            None => {
+                candle::bail!("cannot find tensor {name}")
+            }
+            Some(qtensor) => {
+                let shape = s.into();
+                if qtensor.shape() != &shape {
+                    candle::bail!(
+                        "shape mismatch for {name}, got {:?}, expected {shape:?}",
+                        qtensor.shape()
+                    )
+                }
+                Ok(qtensor.clone())
+            }
+        }
+    }
+
+    pub fn device(&self) -> &Device {
+        &self.device
+    }
+}