Use an arc in the varbuilder rather than rc. (#757)

* Use an arc in the varbuilder rather than rc.

* Require the backends to be send.

* Request send and sync.
This commit is contained in:
Laurent Mazare
2023-09-06 16:29:09 +02:00
committed by GitHub
parent dcf708559d
commit bdc9d46fe3

View File

@ -5,14 +5,14 @@ use crate::VarMap;
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
use safetensors::{slice::IndexOp, tensor::SafeTensors}; use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::sync::Arc;
/// A structure used to retrieve variables, these variables can either come from storage or be /// A structure used to retrieve variables, these variables can either come from storage or be
/// generated via some form of initialization. /// generated via some form of initialization.
/// ///
/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`. /// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
pub struct VarBuilderArgs<'a, B: Backend> { pub struct VarBuilderArgs<'a, B: Backend> {
data: Rc<TensorData<B>>, data: Arc<TensorData<B>>,
path: Vec<String>, path: Vec<String>,
_phantom: std::marker::PhantomData<&'a B>, _phantom: std::marker::PhantomData<&'a B>,
} }
@ -43,7 +43,7 @@ struct TensorData<B: Backend> {
/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most /// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
/// of the time. The main restriction is that it doesn't allow for specific args (besides /// of the time. The main restriction is that it doesn't allow for specific args (besides
/// initialization hints). /// initialization hints).
pub trait Backend { pub trait Backend: Send + Sync {
type Hints: Default; type Hints: Default;
/// Retrieve a tensor with some target shape. /// Retrieve a tensor with some target shape.
@ -59,7 +59,7 @@ pub trait Backend {
fn contains_tensor(&self, name: &str) -> bool; fn contains_tensor(&self, name: &str) -> bool;
} }
pub trait SimpleBackend { pub trait SimpleBackend: Send + Sync {
/// Retrieve a tensor based on a target name and shape. /// Retrieve a tensor based on a target name and shape.
fn get( fn get(
&self, &self,
@ -99,7 +99,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
device: dev.clone(), device: dev.clone(),
}; };
Self { Self {
data: Rc::new(data), data: Arc::new(data),
path: vec![], path: vec![],
_phantom: std::marker::PhantomData, _phantom: std::marker::PhantomData,
} }
@ -333,7 +333,7 @@ impl<'a> VarBuilder<'a> {
device, device,
}; };
Self { Self {
data: Rc::new(data), data: Arc::new(data),
path: vec![], path: vec![],
_phantom: std::marker::PhantomData, _phantom: std::marker::PhantomData,
} }