Fixing quantized llama demo on metal. (#1703)

This commit is contained in:
Nicolas Patry
2024-02-13 16:28:56 +01:00
committed by GitHub
parent ad73e93da2
commit c1b418586c
4 changed files with 34 additions and 13 deletions

View File

@ -233,6 +233,7 @@ pub struct Content {
pub hparams: HParams,
pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
}
impl Content {
@ -252,11 +253,13 @@ impl Content {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor);
}
let device = device.clone();
Ok(Self {
magic,
hparams,
vocab,
tensors,
device,
})
}

View File

@ -14,6 +14,10 @@ impl QMetalStorage {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
}

View File

@ -76,6 +76,14 @@ impl QStorage {
}
}
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
#[cfg(feature = "metal")]
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
@ -336,6 +344,10 @@ impl QTensor {
self.storage.dtype()
}
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn rank(&self) -> usize {
self.shape.rank()
}