diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index bf5d5b43..4ccbaf17 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -5,14 +5,14 @@ use crate::VarMap; use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; -use std::rc::Rc; +use std::sync::Arc; /// A structure used to retrieve variables, these variables can either come from storage or be /// generated via some form of initialization. /// /// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`. pub struct VarBuilderArgs<'a, B: Backend> { - data: Rc>, + data: Arc>, path: Vec, _phantom: std::marker::PhantomData<&'a B>, } @@ -43,7 +43,7 @@ struct TensorData { /// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most /// of the time. The main restriction is that it doesn't allow for specific args (besides /// initialization hints). -pub trait Backend { +pub trait Backend: Send + Sync { type Hints: Default; /// Retrieve a tensor with some target shape. @@ -59,7 +59,7 @@ pub trait Backend { fn contains_tensor(&self, name: &str) -> bool; } -pub trait SimpleBackend { +pub trait SimpleBackend: Send + Sync { /// Retrieve a tensor based on a target name and shape. fn get( &self, @@ -99,7 +99,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { device: dev.clone(), }; Self { - data: Rc::new(data), + data: Arc::new(data), path: vec![], _phantom: std::marker::PhantomData, } @@ -333,7 +333,7 @@ impl<'a> VarBuilder<'a> { device, }; Self { - data: Rc::new(data), + data: Arc::new(data), path: vec![], _phantom: std::marker::PhantomData, }