mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the cuda mode to llama.
This commit is contained in:
@ -15,6 +15,7 @@ pub struct VarBuilder {
|
||||
path: Vec<String>,
|
||||
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
|
||||
default_dtype: DType,
|
||||
default_device: Device,
|
||||
tensors: Arc<Option<HashMap<String, Tensor>>>,
|
||||
}
|
||||
|
||||
@ -24,13 +25,14 @@ pub struct VarStore {
|
||||
}
|
||||
|
||||
impl VarBuilder {
|
||||
pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
|
||||
pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
|
||||
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
|
||||
Self {
|
||||
path: vec![],
|
||||
vars,
|
||||
default_dtype: B::DTYPE,
|
||||
tensors: Arc::new(tensors),
|
||||
default_device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,9 +45,9 @@ impl VarBuilder {
|
||||
let path = format!("{}.{s}", self.path.join("."));
|
||||
let mut vars = self.vars.borrow_mut();
|
||||
let parameter = match self.tensors.as_ref() {
|
||||
None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
|
||||
None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?,
|
||||
Some(tensors) => match tensors.get(&path) {
|
||||
Some(tensor) => tensor.clone(),
|
||||
Some(tensor) => tensor.to_device(&self.default_device)?,
|
||||
None => panic!("cannot find tensor for {path}"),
|
||||
},
|
||||
};
|
||||
@ -75,6 +77,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
|
||||
path,
|
||||
vars: self.vars.clone(),
|
||||
default_dtype: self.default_dtype,
|
||||
default_device: self.default_device.clone(),
|
||||
tensors: self.tensors.clone(),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user