use candle::{DType, Device, Result, Shape, Tensor, WithDType}; use std::collections::HashMap; use std::sync::Arc; #[allow(dead_code)] #[derive(Clone)] struct NamedVar { path: String, dtype: DType, shape: Shape, } #[derive(Clone)] pub struct VarBuilder { path: Vec, vars: std::rc::Rc>>, default_dtype: DType, default_device: Device, tensors: Arc>>, } #[allow(dead_code)] pub struct VarStore { vars: Vec, } impl VarBuilder { pub fn new(device: &Device, tensors: Option>) -> Self { let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![])); Self { path: vec![], vars, default_dtype: B::DTYPE, tensors: Arc::new(tensors), default_device: device.clone(), } } pub fn len(&self) -> usize { self.vars.borrow().len() } pub fn var>(&mut self, s: &str, shape: S) -> Result { let shape = shape.into(); let path = format!("{}.{s}", self.path.join(".")); let mut vars = self.vars.borrow_mut(); let parameter = match self.tensors.as_ref() { None => Tensor::zeros(&shape, self.default_dtype, &self.default_device)?, Some(tensors) => match tensors.get(&path) { Some(tensor) => tensor.to_device(&self.default_device)?, None => panic!("cannot find tensor for {path}"), }, }; vars.push(NamedVar { path, dtype: self.default_dtype, shape, }); Ok(parameter) } pub fn into_store(self) -> VarStore { let vars = self.vars.borrow(); VarStore { vars: vars.to_vec(), } } } impl std::ops::Div for &VarBuilder { type Output = VarBuilder; fn div(self, rhs: S) -> VarBuilder { let mut path = self.path.clone(); path.push(rhs.to_string()); VarBuilder { path, vars: self.vars.clone(), default_dtype: self.default_dtype, default_device: self.default_device.clone(), tensors: self.tensors.clone(), } } } impl std::ops::Div for VarBuilder { type Output = VarBuilder; fn div(self, rhs: S) -> VarBuilder { &self / rhs } }