mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
* quantized models(awq/squeezellm/...) have multiple data type tensors, use 'get_with_hints_dtype' to load tensors with given dtype
This commit is contained in:
@ -178,16 +178,27 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
name: &str,
|
name: &str,
|
||||||
hints: B::Hints,
|
hints: B::Hints,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let path = self.path(name);
|
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
|
||||||
self.data
|
|
||||||
.backend
|
|
||||||
.get(s.into(), &path, hints, self.data.dtype, &self.data.device)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieve the tensor associated with the given name at the current path.
|
/// Retrieve the tensor associated with the given name at the current path.
|
||||||
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
|
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
|
||||||
self.get_with_hints(s, name, Default::default())
|
self.get_with_hints(s, name, Default::default())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve the tensor associated with the given name & dtype at the current path.
|
||||||
|
pub fn get_with_hints_dtype<S: Into<Shape>>(
|
||||||
|
&self,
|
||||||
|
s: S,
|
||||||
|
name: &str,
|
||||||
|
hints: B::Hints,
|
||||||
|
dtype: DType,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let path = self.path(name);
|
||||||
|
self.data
|
||||||
|
.backend
|
||||||
|
.get(s.into(), &path, hints, dtype, &self.data.device)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Zeros;
|
struct Zeros;
|
||||||
|
Reference in New Issue
Block a user