microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2cf9bab611e9ad563822dee69c44e23bd017fadc

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/pnp/_base.py

124 lines · mode: code

1import io
2import onnx
3import torch
4from typing import Any
5from onnx.onnx_pb import TensorProto
6from torch.onnx import TrainingMode, export as _export
7
8from ._onnx_ops import OPSET_TO_IR_VERSION
9
10
def _export_f(model, *args,
              opset_version=None,
              output_path=None,
              output_seq=0,
              export_params=True,
              verbose=False,
              input_names=None,
              output_names=None,
              operator_export_type=None,
              do_constant_folding=True,
              dynamic_axes=None,
              keep_initializers_as_inputs=None,
              custom_opsets=None):
    """Export *model* to ONNX in eval mode and return the reloaded ONNX model.

    The model is serialized into an in-memory buffer with ``torch.onnx.export``
    (training mode forced to ``TrainingMode.EVAL``). The buffer is then reloaded
    with ``onnx.load_model``, and the model's ``ir_version`` is overwritten from
    ``OPSET_TO_IR_VERSION`` using the default-domain opset import.

    Args:
        model: the module (or ScriptModule) to export.
        *args: example inputs; passed to the exporter as a single tuple.
        opset_version: ONNX opset version handed to the exporter.
        output_path: if not None, the model is also saved to this path.
        output_seq: when > 0 and ``output_path`` ends with ``.onnx``, the
            saved file name becomes ``<stem>.<output_seq>.onnx``.
        export_params, verbose, input_names, output_names,
        operator_export_type, do_constant_folding, dynamic_axes,
        keep_initializers_as_inputs, custom_opsets: forwarded unchanged to
            ``torch.onnx.export``.

    Returns:
        The exported model as loaded by ``onnx.load_model``.

    Raises:
        KeyError: if the default-domain opset version has no entry in
            ``OPSET_TO_IR_VERSION``.
    """
    with io.BytesIO() as f:
        _export(model, args, f,
                export_params=export_params, verbose=verbose,
                training=TrainingMode.EVAL, input_names=input_names,
                output_names=output_names,
                operator_export_type=operator_export_type, opset_version=opset_version,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                keep_initializers_as_inputs=keep_initializers_as_inputs,
                custom_opsets=custom_opsets)
        # Must read the buffer while it is still open (the `with` closes it).
        mdl = onnx.load_model(io.BytesIO(f.getvalue()))

    for ops in mdl.opset_import:
        # The default ONNX domain may appear as either '' or 'ai.onnx'.
        if ops.domain in ('', 'ai.onnx'):
            mdl.ir_version = OPSET_TO_IR_VERSION[ops.version]

    if output_path is not None:
        if output_seq > 0:
            # The previous `output_path.replace(...)` discarded its result
            # (str is immutable), so output_seq was silently ignored. Only the
            # trailing extension is rewritten, so a '.onnx' appearing earlier
            # in the path (e.g. in a directory name) is left alone.
            suffix = '.onnx'
            if output_path.endswith(suffix):
                output_path = '{}.{}{}'.format(
                    output_path[:-len(suffix)], output_seq, suffix)
        onnx.save_model(mdl, output_path)
    return mdl
45
46
class _ProcessingModule:
    """Mixin adding ONNX custom-op registration and an ``export`` helper.

    The concrete processing classes in this module combine it with
    ``torch.nn.Module``.
    """

    def __init__(self):
        super(_ProcessingModule, self).__init__()
        # Register the ONNX symbolic functions (no-op after the first call).
        _ProcessingModule.register_customops()

    @staticmethod
    @torch.jit.unused
    def _argsort(g, x, dim, descending):
        """ONNX symbolic for ``argsort``: emit an ``ai.onnx.contrib::ArgSort`` node.

        NOTE(review): ``descending`` is accepted but not forwarded to the
        contrib op; confirm the op's sort order matches what callers expect.
        """
        return g.op('ai.onnx.contrib::ArgSort', x, dim)

    @classmethod
    @torch.jit.unused
    def register_customops(cls):
        """Register the ONNX symbolic functions for the custom ops.

        A ``loaded`` class attribute guards the registration, so repeated
        calls are no-ops. Always returns ``True``.
        """
        if hasattr(cls, 'loaded'):
            return True

        # Registered with opset_version=1 for the '::argsort' op name.
        torch.onnx.register_custom_op_symbolic('::argsort', cls._argsort, 1)
        # ... more

        cls.loaded = True
        return True

    @torch.jit.unused
    def export(self, *args, opset_version=None, script_mode=False, output_path=None, output_seq=0, **kwargs):
        """Export this module to ONNX via ``_export_f`` and return its result.

        Args:
            *args: example inputs forwarded to the exporter.
            opset_version: required ONNX opset version.
            script_mode: if True, the module is compiled with
                ``torch.jit.script`` (unless it already is a
                ``torch.jit.ScriptModule``) before export.
            output_path: optional path the exported model is also saved to.
            output_seq: sequence number forwarded to ``_export_f``.
            **kwargs: further keyword arguments forwarded to ``_export_f``.

        Raises:
            RuntimeError: if ``opset_version`` is not provided.
        """
        if opset_version is None:
            raise RuntimeError('No opset_version found in the kwargs.')
        # Subclasses list torch.nn.Module first in their bases. By default
        # nn.Module.__init__ does not call super().__init__(), so
        # _ProcessingModule.__init__ (and the registration it performs) is
        # skipped for them. Registration is idempotent, so ensure it here,
        # right before the symbolic functions are needed.
        _ProcessingModule.register_customops()
        mod = self
        if script_mode and not isinstance(mod, torch.jit.ScriptModule):
            mod = torch.jit.script(mod)

        return _export_f(mod,
                         *args,
                         opset_version=opset_version,
                         output_path=output_path,
                         output_seq=output_seq, **kwargs)
83
84
class ProcessingTracedModule(torch.nn.Module, _ProcessingModule):
    """Processing module whose ``forward`` delegates to a wrapped callable.

    Args:
        func_obj: callable invoked by ``forward`` with the positional inputs;
            ``forward`` fails its assertion if this is left as ``None``.
    """

    def __init__(self, func_obj=None):
        super().__init__()
        self.func_obj = func_obj

    def forward(self, *args):
        target = self.func_obj
        assert target is not None, "No forward method found."
        return target(*args)
93
94
class ProcessingScriptModule(torch.nn.Module, _ProcessingModule):
    """Processing module exported through TorchScript.

    ``export`` delegates to the inherited ``export`` with ``script_mode=True``;
    every other argument is passed through unchanged.
    """

    @torch.jit.unused
    def export(self, *args, **kwargs):
        inherited_export = super().export
        return inherited_export(*args, script_mode=True, **kwargs)
100
101
class CustomFunction(torch.autograd.Function):
    """Autograd function base whose ONNX symbolic is a contrib-domain op.

    ``symbolic`` emits a node of type ``ai.onnx.contrib::<ClassName>``, where
    ``<ClassName>`` is the name of the (sub)class being exported. The autograd
    hooks are placeholders:
    - ``forward`` and ``jvp`` return ``None``;
    - ``backward`` returns the incoming gradients unchanged.
    """

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        # Placeholder: forward-mode AD is not implemented here.
        return None

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # Identity gradient: hand back the tuple of incoming gradients as-is.
        return grad_outputs

    @classmethod
    def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
        # Placeholder; presumably overridden by subclasses.
        return None

    @classmethod
    def symbolic(cls, g, *args):
        domain_prefix = 'ai.onnx.contrib::'
        return g.op(domain_prefix + cls.__name__, *args)
118
119
# Module-level alias of onnx's ``TensorProto`` (imported from onnx.onnx_pb).
# Presumably exposed so callers can use its element-type constants without
# importing onnx themselves.
tensor_data_type = TensorProto
121
122
def is_processing_module(m):
    """Tell whether *m* is a processing module.

    Args:
        m: any object.

    Returns:
        ``True`` when *m* is an instance of ``_ProcessingModule`` (including
        ``ProcessingTracedModule`` and ``ProcessingScriptModule``), otherwise
        ``False``.
    """
    processing_base = _ProcessingModule
    return isinstance(m, processing_base)
125