mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Preliminary support for SDXL. (#647)
* Preliminary support for SDXL. * More SDXL support. * More SDXL. * Use the proper clip config. * Querying for existing tensors. * More robust test.
This commit is contained in:
@ -52,6 +52,8 @@ pub trait Backend {
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Tensor>;
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool;
|
||||
}
|
||||
|
||||
pub trait SimpleBackend {
|
||||
@ -64,6 +66,8 @@ pub trait SimpleBackend {
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Tensor>;
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool;
|
||||
}
|
||||
|
||||
impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
||||
@ -78,6 +82,10 @@ impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
||||
) -> Result<Tensor> {
|
||||
self.as_ref().get(s, name, h, dtype, dev)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.as_ref().contains_tensor(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
@ -94,6 +102,8 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a new `VarBuilder` adding `s` to the current prefix. This can be think of as `cd`
|
||||
/// into a directory.
|
||||
pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
|
||||
let mut path = self.path.clone();
|
||||
path.push(s.to_string());
|
||||
@ -109,10 +119,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
self.push_prefix(s)
|
||||
}
|
||||
|
||||
/// The device used by default.
|
||||
pub fn device(&self) -> &Device {
|
||||
&self.data.device
|
||||
}
|
||||
|
||||
/// The dtype used by default.
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.data.dtype
|
||||
}
|
||||
@ -125,6 +137,14 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
}
|
||||
}
|
||||
|
||||
/// This returns true only if a tensor with the passed in name is available. E.g. when passed
|
||||
/// `a`, true is returned if `prefix.a` exists but false is returned if only `prefix.a.b`
|
||||
/// exists.
|
||||
pub fn contains_tensor(&self, tensor_name: &str) -> bool {
|
||||
let path = self.path(tensor_name);
|
||||
self.data.backend.contains_tensor(&path)
|
||||
}
|
||||
|
||||
/// Retrieve the tensor associated with the given name at the current path.
|
||||
pub fn get_with_hints<S: Into<Shape>>(
|
||||
&self,
|
||||
@ -149,6 +169,10 @@ impl SimpleBackend for Zeros {
|
||||
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
|
||||
Tensor::zeros(s, dtype, dev)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, _name: &str) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl SimpleBackend for HashMap<String, Tensor> {
|
||||
@ -179,6 +203,10 @@ impl SimpleBackend for HashMap<String, Tensor> {
|
||||
}
|
||||
tensor.to_device(dev)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl SimpleBackend for VarMap {
|
||||
@ -192,6 +220,10 @@ impl SimpleBackend for VarMap {
|
||||
) -> Result<Tensor> {
|
||||
VarMap::get(self, s, name, h, dtype, dev)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.data().lock().unwrap().contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
struct SafeTensorWithRouting<'a> {
|
||||
@ -228,6 +260,10 @@ impl<'a> SimpleBackend for SafeTensorWithRouting<'a> {
|
||||
}
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.routing.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl SimpleBackend for candle::npy::NpzTensors {
|
||||
@ -257,6 +293,10 @@ impl SimpleBackend for candle::npy::NpzTensors {
|
||||
}
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.get(name).map_or(false, |v| v.is_some())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
@ -425,4 +465,8 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
|
||||
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
|
||||
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.0.routing.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user