Add Pixtral. (#2521)

* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
This commit is contained in:
Laurent Mazare
2024-09-30 19:31:14 +02:00
committed by GitHub
parent 2f49e1b534
commit 683ab698de
9 changed files with 822 additions and 19 deletions

View File

@ -14,6 +14,7 @@ use std::sync::Arc;
pub struct VarBuilderArgs<'a, B: Backend> {
data: Arc<TensorData<B>>,
path: Vec<String>,
pub dtype: DType,
_phantom: std::marker::PhantomData<&'a B>,
}
@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: self.path.clone(),
dtype: self.dtype,
_phantom: self._phantom,
}
}
@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
struct TensorData<B: Backend> {
backend: B,
pub dtype: DType,
pub device: Device,
}
@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
let data = TensorData {
backend,
dtype,
device: dev.clone(),
};
Self {
data: Arc::new(data),
path: vec![],
dtype,
_phantom: std::marker::PhantomData,
}
}
@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![],
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![prefix.to_string()],
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path,
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
/// The dtype used by default.
pub fn dtype(&self) -> DType {
self.data.dtype
self.dtype
}
/// Clone the VarBuilder tweaking its dtype
pub fn to_dtype(&self, dtype: DType) -> Self {
Self {
data: self.data.clone(),
path: self.path.clone(),
dtype,
_phantom: std::marker::PhantomData,
}
}
fn path(&self, tensor_name: &str) -> String {
@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
self.get_with_hints_dtype(s, name, hints, self.dtype)
}
/// Retrieve the tensor associated with the given name at the current path.
@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> {
dtype: DType,
device: Device,
) -> Self {
let data = TensorData {
backend,
dtype,
device,
};
let data = TensorData { backend, device };
Self {
data: Arc::new(data),
path: vec![],
dtype,
_phantom: std::marker::PhantomData,
}
}
@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> {
let path = self.path.clone();
let backend = Rename::new(self, renamer);
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
let data = TensorData {
backend,
dtype,
device,
};
let data = TensorData { backend, device };
Self {
data: Arc::new(data),
dtype,
path,
_phantom: std::marker::PhantomData,
}