PyO3: Add optional candle.onnx module (#1282)

* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
This commit is contained in:
Lukas Kreussel
2023-11-08 06:37:50 +01:00
committed by GitHub
parent 7920b45c8a
commit f3a4f3db76
10 changed files with 343 additions and 6 deletions

View File

@ -0,0 +1,5 @@
# Generated content DO NOT EDIT
from .. import onnx
ONNXModel = onnx.ONNXModel
ONNXTensorDescription = onnx.ONNXTensorDescription

View File

@ -0,0 +1,89 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
class ONNXModel:
"""
A wrapper around an ONNX model.
"""
def __init__(self, path: str):
pass
@property
def doc_string(self) -> str:
"""
The doc string of the model.
"""
pass
@property
def domain(self) -> str:
"""
The domain of the operator set of the model.
"""
pass
def initializers(self) -> Dict[str, Tensor]:
"""
Get the weights of the model.
"""
pass
@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The inputs of the model.
"""
pass
@property
def ir_version(self) -> int:
"""
The version of the IR this model targets.
"""
pass
@property
def model_version(self) -> int:
"""
The version of the model.
"""
pass
@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The outputs of the model.
"""
pass
@property
def producer_name(self) -> str:
"""
The producer of the model.
"""
pass
@property
def producer_version(self) -> str:
"""
The version of the producer of the model.
"""
pass
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Run the model on the given inputs.
"""
pass
class ONNXTensorDescription:
"""
A wrapper around an ONNX tensor description.
"""
@property
def dtype(self) -> DType:
"""
The data type of the tensor.
"""
pass
@property
def shape(self) -> Tuple[Union[int, str, Any]]:
"""
The shape of the tensor.
"""
pass