microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
skottmckay/BuildInfra_AndTestImageLibs

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/pnp/_base.py

119 lines · mode: code

1import io
2import onnx
3import torch
4from typing import Any
5from onnx.onnx_pb import TensorProto
6from torch.onnx import TrainingMode, export as _export
7
8
9def _export_f(model, *args,
10 opset_version=None,
11 output_path=None,
12 output_seq=0,
13 export_params=True,
14 verbose=False,
15 input_names=None,
16 output_names=None,
17 operator_export_type=None,
18 do_constant_folding=True,
19 dynamic_axes=None,
20 keep_initializers_as_inputs=None,
21 custom_opsets=None):
22
23 with io.BytesIO() as f:
24 _export(model, args, f,
25 export_params=export_params, verbose=verbose,
26 training=TrainingMode.EVAL, input_names=input_names,
27 output_names=output_names,
28 operator_export_type=operator_export_type, opset_version=opset_version,
29 do_constant_folding=do_constant_folding,
30 dynamic_axes=dynamic_axes,
31 keep_initializers_as_inputs=keep_initializers_as_inputs,
32 custom_opsets=custom_opsets)
33
34 mdl = onnx.load_model(io.BytesIO(f.getvalue()))
35 if output_path is not None:
36 if output_seq > 0:
37 output_path.replace('.onnx', '.{}.onnx'.format(output_seq))
38 onnx.save_model(mdl, output_path)
39 return mdl
40
41
42class _ProcessingModule:
43
44 def __init__(self):
45 super(_ProcessingModule, self).__init__()
46 _ProcessingModule.register_customops()
47
48 @staticmethod
49 @torch.jit.unused
50 def _argsort(g, x, dim, descending):
51 return g.op('ai.onnx.contrib::ArgSort', x, dim)
52
53 @classmethod
54 @torch.jit.unused
55 def register_customops(cls):
56 if hasattr(cls, 'loaded'):
57 return True
58
59 torch.onnx.register_custom_op_symbolic('::argsort', cls._argsort, 1)
60 # ... more
61
62 cls.loaded = True
63 return True
64
65 @torch.jit.unused
66 def export(self, *args, opset_version=None, script_mode=False, output_path=None, output_seq=0, **kwargs):
67 if opset_version is None:
68 raise RuntimeError('No opset_version found in the kwargs.')
69 mod = self
70 if script_mode and not isinstance(mod, torch.jit.ScriptModule):
71 mod = torch.jit.script(mod)
72
73 return _export_f(mod,
74 *args,
75 opset_version=opset_version,
76 output_path=output_path,
77 output_seq=output_seq, **kwargs)
78
79
class ProcessingTracedModule(torch.nn.Module, _ProcessingModule):
    """Processing module whose ``forward`` delegates to a wrapped callable.

    Args:
        func_obj: callable invoked with the positional arguments given to
            ``forward``. It may be None at construction, but ``forward``
            asserts that it has been set.
    """

    def __init__(self, func_obj=None):
        super(ProcessingTracedModule, self).__init__()
        self.func_obj = func_obj

    def forward(self, *args):
        """Call the wrapped callable with ``*args`` and return its result."""
        target = self.func_obj
        assert target is not None, "No forward method found."
        return target(*args)
88
89
class ProcessingScriptModule(torch.nn.Module, _ProcessingModule):
    """Processing module that is exported with scripting enabled."""

    @torch.jit.unused
    def export(self, *args, **kwargs):
        """Delegate to the inherited ``export`` with ``script_mode=True``.

        All positional and keyword arguments are forwarded unchanged.
        """
        parent = super(ProcessingScriptModule, self)
        return parent.export(*args, script_mode=True, **kwargs)
95
96
class CustomFunction(torch.autograd.Function):
    """Autograd-function base whose ONNX export is an ``ai.onnx.contrib`` op.

    During ONNX export, ``symbolic`` emits a node in the ``ai.onnx.contrib``
    domain named after the concrete class. ``forward`` here is a placeholder
    that returns None; presumably subclasses override it.
    """

    @staticmethod
    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
        # Forward-mode AD is not implemented here; the result is None.
        return None

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # Hand the incoming gradients back unchanged, as a tuple.
        return grad_outputs

    @classmethod
    def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
        # Placeholder computation; returns None.
        return None

    @classmethod
    def symbolic(cls, g, *args):
        # The op type is the (sub)class name, e.g. MyOp -> ai.onnx.contrib::MyOp.
        op_type = f'ai.onnx.contrib::{cls.__name__}'
        return g.op(op_type, *args)
113
114
# Module-level alias for onnx's TensorProto class (imported from onnx.onnx_pb).
# NOTE(review): presumably exposed so callers can refer to ONNX element types
# (e.g. tensor_data_type.FLOAT) without importing onnx themselves — confirm.
tensor_data_type = TensorProto
116
117
def is_processing_module(m):
    """Return True if *m* is an instance of ``_ProcessingModule`` or a subclass."""
    base_cls = _ProcessingModule
    return isinstance(m, base_cls)
120