mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Use an arc in the varbuilder rather than rc. (#757)
* Use an arc in the varbuilder rather than rc. * Require the backends to be send. * Request send and sync.
This commit is contained in:
@ -5,14 +5,14 @@ use crate::VarMap;
|
|||||||
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
|
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
|
||||||
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
use safetensors::{slice::IndexOp, tensor::SafeTensors};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::rc::Rc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// A structure used to retrieve variables, these variables can either come from storage or be
|
/// A structure used to retrieve variables, these variables can either come from storage or be
|
||||||
/// generated via some form of initialization.
|
/// generated via some form of initialization.
|
||||||
///
|
///
|
||||||
/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
|
/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
|
||||||
pub struct VarBuilderArgs<'a, B: Backend> {
|
pub struct VarBuilderArgs<'a, B: Backend> {
|
||||||
data: Rc<TensorData<B>>,
|
data: Arc<TensorData<B>>,
|
||||||
path: Vec<String>,
|
path: Vec<String>,
|
||||||
_phantom: std::marker::PhantomData<&'a B>,
|
_phantom: std::marker::PhantomData<&'a B>,
|
||||||
}
|
}
|
||||||
@ -43,7 +43,7 @@ struct TensorData<B: Backend> {
|
|||||||
/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
|
/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
|
||||||
/// of the time. The main restriction is that it doesn't allow for specific args (besides
|
/// of the time. The main restriction is that it doesn't allow for specific args (besides
|
||||||
/// initialization hints).
|
/// initialization hints).
|
||||||
pub trait Backend {
|
pub trait Backend: Send + Sync {
|
||||||
type Hints: Default;
|
type Hints: Default;
|
||||||
|
|
||||||
/// Retrieve a tensor with some target shape.
|
/// Retrieve a tensor with some target shape.
|
||||||
@ -59,7 +59,7 @@ pub trait Backend {
|
|||||||
fn contains_tensor(&self, name: &str) -> bool;
|
fn contains_tensor(&self, name: &str) -> bool;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait SimpleBackend {
|
pub trait SimpleBackend: Send + Sync {
|
||||||
/// Retrieve a tensor based on a target name and shape.
|
/// Retrieve a tensor based on a target name and shape.
|
||||||
fn get(
|
fn get(
|
||||||
&self,
|
&self,
|
||||||
@ -99,7 +99,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
device: dev.clone(),
|
device: dev.clone(),
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
data: Rc::new(data),
|
data: Arc::new(data),
|
||||||
path: vec![],
|
path: vec![],
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
@ -333,7 +333,7 @@ impl<'a> VarBuilder<'a> {
|
|||||||
device,
|
device,
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
data: Rc::new(data),
|
data: Arc::new(data),
|
||||||
path: vec![],
|
path: vec![],
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user