PyO3: Add optional candle.onnx module (#1282)

* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
This commit is contained in:
Lukas Kreussel
2023-11-08 06:37:50 +01:00
committed by GitHub
parent 7920b45c8a
commit f3a4f3db76
10 changed files with 343 additions and 6 deletions

View File

@ -19,12 +19,14 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod utils;
use utils::wrap_err;
mod shape;
use shape::{PyShape, PyShapeWithHole};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[cfg(feature = "onnx")]
mod onnx;
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
Ok(())
}
#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;