mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add Pixtral. (#2521)
* Add Pixtral. * More pixtral vision encoder. * Sketch a pixtral example. * Sketch a pixtral example. * Better image loading. * Support loading images embedded in safetensor files. * Clippy fixes. * Add the llava multimodal adapter. * Add more of the llava bits. * Add the pixtral config. * More pixtral inference. * Add the text generation bits. * Get the example to work. * Bugfix. * Run some bits of the model in f32. * Blessed version :) * Better rope frequency computations. * README update.
This commit is contained in:
@ -14,6 +14,7 @@ use std::sync::Arc;
|
||||
pub struct VarBuilderArgs<'a, B: Backend> {
|
||||
data: Arc<TensorData<B>>,
|
||||
path: Vec<String>,
|
||||
pub dtype: DType,
|
||||
_phantom: std::marker::PhantomData<&'a B>,
|
||||
}
|
||||
|
||||
@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
|
||||
Self {
|
||||
data: self.data.clone(),
|
||||
path: self.path.clone(),
|
||||
dtype: self.dtype,
|
||||
_phantom: self._phantom,
|
||||
}
|
||||
}
|
||||
@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
|
||||
|
||||
struct TensorData<B: Backend> {
|
||||
backend: B,
|
||||
pub dtype: DType,
|
||||
pub device: Device,
|
||||
}
|
||||
|
||||
@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
|
||||
let data = TensorData {
|
||||
backend,
|
||||
dtype,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Self {
|
||||
data: Arc::new(data),
|
||||
path: vec![],
|
||||
dtype,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
Self {
|
||||
data: self.data.clone(),
|
||||
path: vec![],
|
||||
dtype: self.dtype,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
Self {
|
||||
data: self.data.clone(),
|
||||
path: vec![prefix.to_string()],
|
||||
dtype: self.dtype,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
Self {
|
||||
data: self.data.clone(),
|
||||
path,
|
||||
dtype: self.dtype,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
|
||||
/// The dtype used by default.
|
||||
pub fn dtype(&self) -> DType {
|
||||
self.data.dtype
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// Clone the VarBuilder tweaking its dtype
|
||||
pub fn to_dtype(&self, dtype: DType) -> Self {
|
||||
Self {
|
||||
data: self.data.clone(),
|
||||
path: self.path.clone(),
|
||||
dtype,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn path(&self, tensor_name: &str) -> String {
|
||||
@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
name: &str,
|
||||
hints: B::Hints,
|
||||
) -> Result<Tensor> {
|
||||
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
|
||||
self.get_with_hints_dtype(s, name, hints, self.dtype)
|
||||
}
|
||||
|
||||
/// Retrieve the tensor associated with the given name at the current path.
|
||||
@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> {
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
) -> Self {
|
||||
let data = TensorData {
|
||||
backend,
|
||||
dtype,
|
||||
device,
|
||||
};
|
||||
let data = TensorData { backend, device };
|
||||
Self {
|
||||
data: Arc::new(data),
|
||||
path: vec![],
|
||||
dtype,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> {
|
||||
let path = self.path.clone();
|
||||
let backend = Rename::new(self, renamer);
|
||||
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
|
||||
let data = TensorData {
|
||||
backend,
|
||||
dtype,
|
||||
device,
|
||||
};
|
||||
let data = TensorData { backend, device };
|
||||
Self {
|
||||
data: Arc::new(data),
|
||||
dtype,
|
||||
path,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
|
Reference in New Issue
Block a user