Add a quantized variant of llama2.c (#1197)

* Add a quantized variant of llama2.c

* Clippy fixes.
This commit is contained in:
Laurent Mazare
2023-10-27 15:34:06 +01:00
committed by GitHub
parent 916619f70b
commit e2826e70b3
5 changed files with 287 additions and 38 deletions

View File

@@ -36,9 +36,9 @@ pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
     pub use_kv_cache: bool,
     #[allow(clippy::type_complexity)]
-    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
-    cos: Tensor,
-    sin: Tensor,
+    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+    pub cos: Tensor,
+    pub sin: Tensor,
     device: Device,
 }
@@ -75,7 +75,7 @@ impl Cache {
         })
     }
-    fn mask(&self, t: usize) -> Result<Tensor> {
+    pub fn mask(&self, t: usize) -> Result<Tensor> {
         let mut masks = self.masks.lock().unwrap();
         if let Some(mask) = masks.get(&t) {
             Ok(mask.clone())