mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
98 lines
3.0 KiB
Rust
98 lines
3.0 KiB
Rust
use candle::quantized::QTensor;
|
|
use candle::{Device, Result, Shape};
|
|
use std::sync::Arc;
|
|
|
|
// VarBuilder specialized for QTensors
|
|
pub struct VarBuilder {
|
|
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
|
|
path: Vec<String>,
|
|
device: Device,
|
|
}
|
|
|
|
impl VarBuilder {
|
|
pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
|
let mut file = std::fs::File::open(p)?;
|
|
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
|
|
let mut data = std::collections::HashMap::new();
|
|
for tensor_name in content.tensor_infos.keys() {
|
|
let tensor = content.tensor(&mut file, tensor_name)?;
|
|
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
|
}
|
|
Ok(Self {
|
|
data: Arc::new(data),
|
|
path: Vec::new(),
|
|
device: Device::Cpu,
|
|
})
|
|
}
|
|
|
|
pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
|
|
let mut cursor = std::io::Cursor::new(buffer);
|
|
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
|
|
let mut data = std::collections::HashMap::new();
|
|
for tensor_name in content.tensor_infos.keys() {
|
|
let tensor = content.tensor(&mut cursor, tensor_name)?;
|
|
data.insert(tensor_name.to_string(), Arc::new(tensor));
|
|
}
|
|
Ok(Self {
|
|
data: Arc::new(data),
|
|
path: Vec::new(),
|
|
device: Device::Cpu,
|
|
})
|
|
}
|
|
|
|
pub fn pp<S: ToString>(&self, s: S) -> Self {
|
|
let mut path = self.path.clone();
|
|
path.push(s.to_string());
|
|
Self {
|
|
data: self.data.clone(),
|
|
path,
|
|
device: self.device.clone(),
|
|
}
|
|
}
|
|
|
|
fn path(&self, tensor_name: &str) -> String {
|
|
if self.path.is_empty() {
|
|
tensor_name.to_string()
|
|
} else {
|
|
[&self.path.join("."), tensor_name].join(".")
|
|
}
|
|
}
|
|
|
|
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Arc<QTensor>> {
|
|
let path = self.path(name);
|
|
match self.data.get(&path) {
|
|
None => {
|
|
candle::bail!("cannot find tensor {name}")
|
|
}
|
|
Some(qtensor) => {
|
|
let shape = s.into();
|
|
if qtensor.shape() != &shape {
|
|
candle::bail!(
|
|
"shape mismatch for {name}, got {:?}, expected {shape:?}",
|
|
qtensor.shape()
|
|
)
|
|
}
|
|
Ok(qtensor.clone())
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
|
|
let path = self.path(name);
|
|
match self.data.get(&path) {
|
|
None => {
|
|
candle::bail!("cannot find tensor {name}")
|
|
}
|
|
Some(qtensor) => Ok(qtensor.clone()),
|
|
}
|
|
}
|
|
|
|
pub fn device(&self) -> &Device {
|
|
&self.device
|
|
}
|
|
|
|
pub fn contains_key(&self, key: &str) -> bool {
|
|
self.data.contains_key(key)
|
|
}
|
|
}
|