Add the cuda mode to llama.

This commit is contained in:
laurent
2023-06-26 10:06:44 +01:00
parent 512d12e38d
commit 59a59f41a6
2 changed files with 17 additions and 9 deletions

View File

@ -15,6 +15,7 @@ pub struct VarBuilder {
path: Vec<String>,
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
default_dtype: DType,
default_device: Device,
tensors: Arc<Option<HashMap<String, Tensor>>>,
}
@ -24,13 +25,14 @@ pub struct VarStore {
}
impl VarBuilder {
pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
Self {
path: vec![],
vars,
default_dtype: B::DTYPE,
tensors: Arc::new(tensors),
default_device: device.clone(),
}
}
@ -43,9 +45,9 @@ impl VarBuilder {
let path = format!("{}.{s}", self.path.join("."));
let mut vars = self.vars.borrow_mut();
let parameter = match self.tensors.as_ref() {
None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
Some(tensors) => match tensors.get(&path) {
Some(tensor) => tensor.clone(),
Some(tensor) => tensor.to_device(&self.default_device)?,
None => panic!("cannot find tensor for {path}"),
},
};
@ -75,6 +77,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
path,
vars: self.vars.clone(),
default_dtype: self.default_dtype,
default_device: self.default_device.clone(),
tensors: self.tensors.clone(),
}
}