microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2cf9bab611e9ad563822dee69c44e23bd017fadc

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/onnxprocess/_tensor.py

628lines · modecode

1import torch
2import builtins
3import functools
4import numpy as np
5from onnx import onnx_pb as onnx_proto
6from typing import List, Tuple, Optional, Union, Any, ContextManager, overload, Iterator, NamedTuple
7from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout # noqa
8from torch import strided, memory_format, contiguous_format, StringType # noqa
9
10from ._onnx_ops import ox as _ox
11from .._ortapi2 import OrtPyFunction
12
13
class _EagerTensor:
    """
    A tensor wrapper that evaluates eagerly with torch while emitting the matching `_ox` operator call.

    Each instance carries the torch value (`_t`), the tensor name used when calling `_ox`
    operators (`name`), and optionally `raw_data` for values torch cannot hold (see `from_onnx`).
    """

    # Functions that can also be called as bound methods (see __getattribute__).
    # This attribute is rebound to the real table after the module-level functions are defined.
    _all_ops = {}

    # Class-wide counter used by __copy__ to give every duplicate a distinct name suffix.
    _dup_id = 0

    def __init__(self, _t, name=None, sess=None, raw_data: Any = None):
        """
        :param _t: a torch.Tensor, or anything `torch.tensor` can convert
        :param name: the tensor name, or a one-element list/tuple holding it
        :param sess: accepted for interface compatibility; not read by this constructor
        :param raw_data: the original value when it cannot be represented by torch
        """
        self._t = _t if isinstance(_t, torch.Tensor) else torch.tensor(_t)
        if isinstance(name, (tuple, list)):
            assert len(name) == 1, "Multiple names for one tensor!"
            name = name[0]
        self.name = '' if name is None else name
        self.raw_data = raw_data
        self.symbolic_shape = []

    def __repr__(self):
        if self.raw_data is not None:
            return "name: {}, \"{}\"".format(self.name, str(self.raw_data))
        return "name: {}, {}, dtype={}".format(self.name, repr(self._t), str(self._t.dtype))

    @property
    def value(self) -> Union[torch.Tensor, Any]:
        """The raw data when present, otherwise the torch tensor."""
        # Compare against None explicitly: raw_data may be a numpy array, whose truth value
        # is ambiguous for more than one element and False for an empty one.
        return self.raw_data if self.raw_data is not None else self._t

    @property
    def t(self):
        """The underlying torch tensor."""
        return self._t

    @property
    def dtype(self):
        """The torch dtype of the underlying tensor."""
        return self._t.dtype

    @property
    def onnx_type(self):
        """The ONNX element type mapped from the torch dtype (see `to_onnx_type`)."""
        return self.to_onnx_type(self._t.dtype)

    @classmethod
    def is_numeric(cls, np_arr):
        """Return True when the numpy array's dtype kind is bool, (un)signed int, float or complex."""
        return np_arr.dtype.kind in set('buifc')

    @classmethod
    def set_active_session(cls, sess):
        """
        set the active operator tracing log session. if sess is None, the active session will be removed
        :param sess: the session to activate, or None to remove the active one
        :return: None
        :raises RuntimeError: when removing without an active session, or assigning while one is active
        """
        if not hasattr(cls, '_active_session'):
            # Validate before assigning, so a rejected call cannot leave a stale `None`
            # session behind that would make the next real assignment fail.
            if sess is None:
                raise RuntimeError("unset the active session twice!")
            cls._active_session = sess
        else:
            if sess is not None:
                raise RuntimeError("The active session already assigned!")
            delattr(cls, '_active_session')

    @classmethod
    def get_trace_session(cls):
        """
        :return: the active tracing session
        :raises RuntimeError: when no session has been set
        """
        if not hasattr(cls, '_active_session'):
            raise RuntimeError("the tracing not started yet!")
        return cls._active_session  # noqa

    @classmethod
    def get_container(cls):
        """Return the `container` attribute of the active tracing session."""
        return cls.get_trace_session().container

    @classmethod
    def from_onnx(cls, raw_val, ort_sess, name):
        """
        Wrap a numpy value.
        :param raw_val: the numpy array
        :param ort_sess: forwarded to the constructor's `sess` argument
        :param name: the tensor name
        :return: the new tensor; non-numeric arrays are kept in `raw_data`
        """
        raw_data = None
        if cls.is_numeric(raw_val):
            val = torch.from_numpy(raw_val)
        else:
            # only keep the shape and the value was stored by it-self.
            val = torch.empty(*raw_val.shape, dtype=torch.uint8)
            raw_data = raw_val
        return cls(val, name, ort_sess, raw_data)

    @classmethod
    def from_torch(cls, _t, name):
        """Wrap a torch value; when `name` is None a name is derived from `id(_t)`."""
        t_name = name if name is not None else "id_{}".format(id(_t))
        return cls(_t, t_name)

    @classmethod
    # torch.tensor prototype
    def mytensor(cls, data: Any, dtype: Optional[_dtype] = None, device: Union[_device, str, None] = None, requires_grad: _bool = False):  # noqa
        """Build a torch tensor from `data` and emit a matching `_ox.constant` into the active container."""
        y = torch.tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        val = _ox.make_tensor(cls.to_onnx_type(y.dtype), list(y.size()),
                              [data] if isinstance(data, (int, float, str, bool)) else data)
        s = _ox.constant([], [_ox.get_unique_tensor_name('const')], cls.get_container(), None, value=val)
        return cls.from_torch(y, s)

    def numpy(self):
        """Return `raw_data` when present, otherwise the torch tensor converted to numpy."""
        return self._t.numpy() if self.raw_data is None else self.raw_data

    def item(self):
        """Return the single element as a Python scalar (via `numpy().item()`)."""
        return self.numpy().item()

    def get_shape(self):
        """Return `symbolic_shape` when it is non-empty, otherwise the torch size."""
        return self.symbolic_shape if len(self.symbolic_shape) > 0 else self.t.size()

    def _to_binary_tensor_args(self, other):
        """Return `(self, other)`, wrapping a plain number or numpy array `other` as a constant."""
        # `self` is always an _EagerTensor here, so only the right operand can need conversion.
        if isinstance(other, (int, float, bool, np.ndarray)):
            other = self.mytensor(other)
        return self, other

    def _binary_op(self, other, torch_fn, ox_fn):
        """Evaluate `torch_fn` on both operands, then emit `ox_fn` with the operands' names."""
        x0, x1 = self._to_binary_tensor_args(other)
        y = torch_fn(x0._t, x1._t)
        s = ox_fn(*_EagerTensor.ox_args([x0, x1]))
        return self.from_torch(y, s)

    def __copy__(self):
        """Return a tensor sharing the torch value and raw data, under a new suffixed name."""
        new_t = _EagerTensor.from_torch(self.t, self.name + '_{}'.format(_EagerTensor._dup_id))
        # Bump the class attribute; `self._dup_id += 1` would only create an instance
        # attribute and leave the shared counter (read above) stuck at its old value.
        _EagerTensor._dup_id += 1
        new_t.raw_data = self.raw_data
        return new_t

    def __add__(self, other):
        return self._binary_op(other, torch.add, _ox.add)

    def __sub__(self, other):
        return self._binary_op(other, torch.sub, _ox.sub)

    def __mul__(self, other):
        return self._binary_op(other, torch.mul, _ox.mul)

    def __div__(self, other):
        return self._binary_op(other, torch.div, _ox.div)

    # Python 3 dispatches the `/` operator to __truediv__; __div__ is kept for existing callers.
    __truediv__ = __div__

    def __pow__(self, other):
        return self._binary_op(other, torch.pow, _ox.pow)

    def __matmul__(self, other):
        return self._binary_op(other, torch.matmul, _ox.matmul)

    def __lt__(self, other):
        return self._binary_op(other, torch.less, _ox.less)

    def __le__(self, other):
        return self._binary_op(other, torch.less_equal, _ox.less_equal)

    def __eq__(self, other):
        # torch.eq compares element-wise, like __ne__ (torch.not_equal) does;
        # torch.equal would collapse the comparison into one Python bool.
        return self._binary_op(other, torch.eq, _ox.equal)

    def __ne__(self, other):
        return self._binary_op(other, torch.not_equal, _ox.not_equal)

    def __gt__(self, other):
        return self._binary_op(other, torch.greater, _ox.greater)

    def __ge__(self, other):
        return self._binary_op(other, torch.greater_equal, _ox.greater_equal)

    def __invert__(self):
        """Logical negation for bool tensors.

        :raises NotImplementedError: for any other dtype
        """
        if self.t.dtype is torch.bool:
            y = torch.logical_not(self.t)
            s = _ox.not_op(*self.my_args())
            return self.from_torch(y, s)
        raise NotImplementedError("no numeric tensor inverse supported yet.")

    def __neg__(self):
        # torch.neg takes the tensor itself; wrapping it in a list raises a TypeError.
        y = torch.neg(self.t)
        s = _ox.neg(*self.my_args())
        return self.from_torch(y, s)

    def __not__(self):
        y = torch.logical_not(self.t)
        s = _ox.not_op(*self.my_args())
        return self.from_torch(y, s)

    def __or__(self, other):
        return self._binary_op(other, torch.logical_or, _ox.or_op)

    def __getitem__(self, indices):
        """
        Index the value and emit `_ox.slice` (plus `_ox.squeeze` for integer indices).
        :param indices: a single int or slice, or a tuple/list of them
        """
        y = self.value.__getitem__(indices)

        # normalize indices to tuples of slices
        # Formats encountered:
        #  - a single int
        #  - a tuple of (int or slice)
        if not isinstance(indices, (tuple, list)):  # single item: make it a tuple
            indices = (indices,)
        # which axes had a single index?
        squeeze = [axis for axis, index in enumerate(indices) if isinstance(index, int)]
        # make all tuple items of type slice; index -1 needs an open end, since -1 + 1 == 0
        indices = tuple(
            index if isinstance(index, slice) else slice(index, index + 1 if index != -1 else None, 1)
            for index in indices)
        bs, es, ss, ds = [], [], [], []
        INT_MAX = 2 ** 63 - 1
        for axis, index in enumerate(indices):
            if not isinstance(index, slice):
                raise ValueError("Index expected")
            # [:] can be skipped, but [::k] cannot: its step still has to reach the emitted slice
            if index.start is None and index.stop is None and index.step in (None, 1):
                continue
            b, e, s = index.start, index.stop, index.step
            bs.append(b if b is not None else 0)
            es.append(e if e is not None else INT_MAX)
            ss.append(s if s is not None else 1)
            ds.append(axis)
        s = _ox.slice(*self.my_args(), starts=bs, ends=es, axes=ds, steps=ss)
        if squeeze:  # single index means we must drop the axis
            s = _ox.squeeze(*self.ox_name_args(s), axes=squeeze)

        return self.from_torch(y, s)

    def __getattribute__(self, attr):
        """
        A little hack that allows to call unary operators in a chaining fashion,
        e.g. x.shape() instead of ox.shape(x).
        """
        if attr in _EagerTensor._all_ops:
            f = _EagerTensor._all_ops[attr]
            return functools.partial(f, self)
        return object.__getattribute__(self, attr)

    @classmethod
    def ox_name_args(cls, input_names, output_names=None):
        """
        generate the arguments for ONNX model builder.
        :param input_names: input name list
        :param output_names: output name list, can be None, or [None]*output_n
        :return: input_names, output_names, container, operator_name
        """
        container = cls.get_trace_session().container
        if output_names is None:
            output_names = [None]  # by default, there is only one output

        # every missing output name is replaced by a generated one seeded with its position
        output_names = [_ox.get_unique_tensor_name(str(idx_)) if name_ is None else name_
                        for idx_, name_ in enumerate(output_names)]
        operator_name = None
        return input_names, output_names, container, operator_name

    @classmethod
    def ort_verify(cls, ts_from, ts_to):
        """
        Run the traced ops through the session's `runops` and compare with the torch results.
        :param ts_from: input tensors
        :param ts_to: expected output tensors
        :raises RuntimeError: on the first mismatch, after saving the model to 'mt_debmodel.onnx'
        """
        result, model = cls.get_trace_session().runops(ts_from, ts_to)
        for idx, expected in enumerate(ts_to):
            if not np.allclose(expected.numpy(), result[idx]):
                # ONNX cannot be import globally, which is conflict with torch.onnx
                import onnx  # noqa
                onnx.save_model(model, 'mt_debmodel.onnx')
                raise RuntimeError("ONNXRuntime Result is not same pytorch!")

    def create_and_verify(self, value, name, additional_inputs=None):
        """Wrap `value` under `name` and verify it, using `self` plus `additional_inputs` as inputs."""
        ts_y = self.from_torch(value, name)
        inputs = [self] + ([] if additional_inputs is None else additional_inputs)
        self.ort_verify(inputs, [ts_y])
        return ts_y

    @classmethod
    def ox_args(cls, tensors, output_names=None):
        """Like `ox_name_args`, but accepts tensors (or plain name strings) as inputs."""
        input_names = [ts_ if isinstance(ts_, str) else ts_.name for ts_ in tensors]
        return cls.ox_name_args(input_names, output_names)

    def my_args(self):
        """Operator arguments with this tensor as the only input."""
        return self.ox_args([self])

    @staticmethod
    def normalize_seq(list_or_tuple):
        """Replace every _EagerTensor in the sequence by its scalar value; keep other items as is."""
        return [x.value.item() if isinstance(x, _EagerTensor) else x for x in list_or_tuple]

    @staticmethod
    def to_onnx_type(torch_type):
        """Map a torch dtype to an ONNX element type; unlisted dtypes map to STRING."""
        ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL,
                   torch.float32: onnx_proto.TensorProto.FLOAT,
                   torch.long: onnx_proto.TensorProto.INT64,
                   torch.int32: onnx_proto.TensorProto.INT32}
        # ...
        return ty_dict.get(torch_type, onnx_proto.TensorProto.STRING)

    def _identity_of(self, y):
        """Emit `_ox.identity` for this tensor and pair it with the already computed value `y`."""
        s = _ox.identity(*self.my_args())
        return self.create_and_verify(y, s[0])

    def long(self):
        y = self._t.long()
        s = _ox.cast(*self.my_args(), to=onnx_proto.TensorProto.INT64)
        return self.create_and_verify(y, s[0])

    def cumsum(self, dim: _int, *, dtype: Optional[_dtype] = None):  # noqa
        # NOTE(review): `dtype` only reaches the torch call, not the `_ox.cumsum` call.
        y = self._t.cumsum(dim, dtype=dtype)
        s = _ox.cumsum(*self.my_args(), axis=dim)
        return self.create_and_verify(y, s[0])

    def size(self):
        y = self._t.size()
        s = _ox.shape(*self.my_args())
        return self.create_and_verify(y, s[0])

    def type(self, dtype: Union[str, _dtype], non_blocking: _bool = False):
        y = self._t.type(dtype, non_blocking)
        s = _ox.cast(*self.my_args(), to=self.to_onnx_type(dtype))
        # pass the single output name, like the sibling methods do
        return self.create_and_verify(y, s[0])

    def to(self, device):
        return self._identity_of(self._t.to(device))

    def cpu(self):
        return self._identity_of(self._t.cpu())

    def detach(self):
        return self._identity_of(self._t.detach())

    def clone(self):
        return self._identity_of(self._t.clone())

    def masked_fill(self, mask, value):
        """
        :param mask: the mask tensor (an _EagerTensor)
        :param value: the fill value, a plain number or an _EagerTensor
        """
        # torch needs the plain value; an _EagerTensor fill value is unwrapped for the torch call
        torch_value = value.value if isinstance(value, _EagerTensor) else value
        y = self._t.masked_fill(mask.value, torch_value)
        if not isinstance(value, _EagerTensor):
            value = _EagerTensor.mytensor(value)
        s = _ox.where(*_EagerTensor.ox_args([mask, value, self]))
        return self.create_and_verify(y, s[0], additional_inputs=[mask, value])

    def unsqueeze(self, dim: _int):
        y = self._t.unsqueeze(dim)
        s = _ox.unsqueeze(*self.my_args(), [dim])
        return self.create_and_verify(y, s[0])

    def squeeze(self, dim: _int):
        y = self._t.squeeze(dim)
        s = _ox.squeeze(*self.my_args(), [dim])
        return self.create_and_verify(y, s[0])
381
382
def _create_ox_sequence(*size):
    """
    Emit the operators that turn `size` into a single INT64 sequence.

    When every item is a plain number, one constant holding all of them is emitted.
    When at least one item is an _EagerTensor, every item becomes its own piece
    (an unsqueeze of the tensor, or a one-element constant) and the pieces are concatenated.
    :param size: plain numbers and/or _EagerTensor items
    :return: the result of the final `_ox.constant` / `_ox.concat` call
    """
    container = _EagerTensor.get_container()
    has_traced_item = builtins.any(isinstance(item_, _EagerTensor) for item_ in size)

    if not has_traced_item:
        whole = _ox.make_tensor(onnx_proto.TensorProto.INT64, [len(size)], size)
        return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=whole)

    def _piece_name(item_):
        # one name per item: tensors are unsqueezed, numbers become one-element constants
        if isinstance(item_, _EagerTensor):
            return _ox.unsqueeze(*_EagerTensor.ox_args([item_]))[0]
        single = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [item_])
        return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=single)[0]

    piece_names = [_piece_name(item_) for item_ in size]
    return _ox.concat(piece_names, [_ox.get_unique_tensor_name('concat')], container, None)
398
399
def _create_ox_sequence_constant(*size, init_value=None, onnx_type=None):
    """
    Emit a `constant_of_shape` operator whose shape comes from `size`.
    :param size: shape items, forwarded to `_create_ox_sequence`
    :param init_value: the fill value
    :param onnx_type: the ONNX element type of the fill value; FLOAT when None
    :return: the first item of the `_ox.constant_of_shape` result
    """
    fill_type = onnx_proto.TensorProto.FLOAT if onnx_type is None else onnx_type
    shape_names = _create_ox_sequence(*size)
    fill_tensor = _ox.make_tensor(fill_type, [1], [init_value])

    container = _EagerTensor.get_container()
    outputs = _ox.constant_of_shape(
        shape_names, [_ox.get_unique_tensor_name('cos')], container, None, value=fill_tensor)
    return outputs[0]
409
410
def empty(*size: Union[_int, _EagerTensor], memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
          dtype: _dtype = None, layout: _layout = strided, device: Union[_device, str, None] = None,
          requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.empty`.

    The shape may be given as separate items or as one list. The emitted operator is a
    constant-of-shape filled with 0., typed after the torch result's dtype.
    """
    dims = size[0] if len(size) == 1 and isinstance(size[0], list) else size
    concrete_dims = _EagerTensor.normalize_seq(dims)
    torch_result = torch.empty(*concrete_dims, memory_format=memory_format, out=out,
                               dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
    element_type = _EagerTensor.to_onnx_type(torch_result.dtype)
    output_name = _create_ox_sequence_constant(*dims, init_value=0., onnx_type=element_type)
    return _EagerTensor.from_torch(torch_result, output_name)
422
423
def zeros(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
          device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.zeros`.

    The shape may be given as separate items or as one list. The emitted operator is a
    constant-of-shape filled with 0, typed after the torch result's dtype.
    """
    dims = size[0] if len(size) == 1 and isinstance(size[0], list) else size
    concrete_dims = _EagerTensor.normalize_seq(dims)
    torch_result = torch.zeros(*concrete_dims, out=out, dtype=dtype,
                               layout=layout, device=device, requires_grad=requires_grad)
    element_type = _EagerTensor.to_onnx_type(torch_result.dtype)
    output_name = _create_ox_sequence_constant(*dims, init_value=0, onnx_type=element_type)
    return _EagerTensor.from_torch(torch_result, output_name)
434
435
def ones(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
         device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.ones`.

    The shape may be given as separate items or as one list. The emitted operator is a
    constant-of-shape filled with 1, typed after the torch result's dtype.
    """
    dims = size[0] if len(size) == 1 and isinstance(size[0], list) else size
    concrete_dims = _EagerTensor.normalize_seq(dims)
    torch_result = torch.ones(*concrete_dims, out=out, dtype=dtype,
                              layout=layout, device=device, requires_grad=requires_grad)
    element_type = _EagerTensor.to_onnx_type(torch_result.dtype)
    output_name = _create_ox_sequence_constant(*dims, init_value=1, onnx_type=element_type)
    return _EagerTensor.from_torch(torch_result, output_name)
446
447
def repeat(input_ts: _EagerTensor, *repeats: Union[_int, _EagerTensor]) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `Tensor.repeat`, emitted as a tile operator.
    :param input_ts: the tensor to repeat
    :param repeats: repeat counts, as separate items or as one list
    """
    counts = repeats[0] if len(repeats) == 1 and isinstance(repeats[0], list) else repeats
    concrete_counts = _EagerTensor.normalize_seq(counts)
    torch_result = input_ts.t.repeat(*concrete_counts)
    counts_names = _create_ox_sequence(*counts)
    tiled = _ox.tile(*input_ts.my_args(), repeats=counts_names[0])
    return _EagerTensor.from_torch(torch_result, tiled[0])
457
458
def argmax(input_ts: _EagerTensor, dim: Optional[_int] = None, keepdim: _bool = False) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.argmax`.
    :param input_ts: the input tensor
    :param dim: the dimension to reduce, passed to the operator as `axis`
    :param keepdim: whether the reduced dimension is kept, passed as `keepdims`
    """
    torch_result = torch.argmax(input_ts.value, dim, keepdim)
    output_names = _ox.argmax(*input_ts.my_args(), axis=dim, keepdims=keepdim)
    return _EagerTensor.from_torch(torch_result, output_names)
463
464
def softmax(input_ts: _EagerTensor, dim: _int, dtype: Optional[_dtype] = None) -> _EagerTensor:
    """
    Traced counterpart of `torch.softmax`.
    :param input_ts: the input tensor
    :param dim: the dimension to normalize over, passed to the operator as `axis`
    :param dtype: forwarded to torch only; the emitted operator does not receive it
    """
    torch_result = torch.softmax(input_ts.value, dim, dtype)
    output_names = _ox.softmax(*input_ts.my_args(), axis=dim)
    return _EagerTensor.from_torch(torch_result, output_names)
469
470
def cat(tensors: Union[Tuple[_EagerTensor, ...], List[_EagerTensor]],
        dim, *, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.cat`; the result is verified against the traced operator.
    :param tensors: the tensors to join
    :param dim: the dimension to join along
    :param out: forwarded to `torch.cat`
    """
    torch_values = [member_.value for member_ in tensors]
    torch_result = torch.cat(torch_values, dim, out=out)
    output_names = _ox.concat(*_EagerTensor.ox_args(tensors), dim)
    joined = _EagerTensor.from_torch(torch_result, output_names[0])
    _EagerTensor.ort_verify(tensors, [joined])
    return joined
478
479
def all(input_ts: _EagerTensor, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.all`: cast to INT64, reduce-min, then test `> 0`.
    :param input_ts: the input tensor
    :param out: accepted for signature compatibility; not used
    """
    # NOTE(review): torch.all reduces every dimension while the operator is given axes=[-1] — confirm.
    container = _EagerTensor.get_container()
    torch_result = torch.all(input_ts.value)
    as_int64 = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    minimum = _ox.reducemin(as_int64, [_ox.get_unique_tensor_name('reducemin')], container, None, axes=[-1])
    zero_value = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0])
    zero = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=zero_value)
    # `minimum + zero` joins both operator results into the input list of `greater`
    is_positive = _ox.greater(minimum + zero, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(torch_result, is_positive[0])
489
490
def any(input_ts: _EagerTensor, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """
    Traced counterpart of `torch.any`: cast to INT64, reduce-sum, then test `> 0`.
    :param input_ts: the input tensor
    :param out: accepted for signature compatibility; not used
    """
    # NOTE(review): torch.any reduces every dimension while the operator is given axes=[-1] — confirm.
    container = _EagerTensor.get_container()
    torch_result = torch.any(input_ts.value)
    as_int64 = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    total = _ox.reducesum(as_int64, [_ox.get_unique_tensor_name('reducesum')], container, None, axes=[-1])
    zero_value = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0])
    zero = _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=zero_value)
    # `total + zero` joins both operator results into the input list of `greater`
    is_positive = _ox.greater(total + zero, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(torch_result, is_positive[0])
500
501
def reshape(input_ts: _EagerTensor, shape: _size):
    """
    Traced counterpart of `Tensor.reshape`; the result is verified against the traced operator.
    :param input_ts: the input tensor
    :param shape: the target shape, passed to the operator as `desired_shape`
    """
    torch_result = input_ts.t.reshape(shape)
    output_names = _ox.reshape(*input_ts.my_args(), desired_shape=shape)
    return input_ts.create_and_verify(torch_result, output_names[0])
506
507
def transpose(input_ts: _EagerTensor, dim0: _int, dim1: _int):
    """
    Traced counterpart of `Tensor.transpose`; the result is verified against the traced operator.
    :param input_ts: the input tensor
    :param dim0: the first dimension to swap
    :param dim1: the second dimension to swap
    """
    torch_result = input_ts.t.transpose(dim0, dim1)
    # the operator takes a full permutation: the identity order with the two dimensions swapped
    perm = [axis_ for axis_ in range(torch_result.dim())]
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    output_names = _ox.transpose(*input_ts.my_args(), perm=perm)
    return input_ts.create_and_verify(torch_result, output_names[0])
514
515
class _LoopIterator:
    """Iterator over a control-flow context: yields `context.current()` until the context stops."""

    def __init__(self, ctx):
        # the context supplying is_stopped() and current()
        self.context = ctx

    def __iter__(self):
        return self

    def __next__(self):
        if not self.context.is_stopped():
            return self.context.current()
        # the loop is over: pop the container from the trace session before stopping
        _EagerTensor.get_trace_session().pop_container()
        raise StopIteration
528
529
class _ControlFlowContext:
    """
    State of one traced loop: the initial inputs, the per-iteration states,
    the accumulated scan outputs and the loop-body sub graph.
    """

    def __init__(self):
        self.condition_i = None     # the condition passed to loop()
        self.condition = None       # the most recent condition
        self.loop_count = None      # the iteration limit passed to loop()
        self.iteration_num = None   # iteration counter, created by loop()
        self.states_i = []          # the states passed to loop()
        self.loop_states = []       # the most recent states
        self.scan_outputs = []      # extra outputs, stacked over the iterations
        self.sub_graph = None       # built on the first flow_output() call

    def flow_output(self, cond, *outputs):
        """
        Record the result of one loop-body iteration.
        :param cond: the condition produced by this iteration
        :param outputs: the new states first, then any extra outputs to be stacked
        """
        n_states = len(self.loop_states)
        assert len(outputs) >= n_states, "The loop body doesn't return enough objects"
        if self.sub_graph is None:
            session = _EagerTensor.get_trace_session()
            graph_inputs = [self.iteration_num, self.condition] + self.loop_states
            graph_outputs = [cond] + list(outputs)
            self.sub_graph = session.build_graph(session.container, graph_inputs, graph_outputs)

        self.condition = cond
        extras = outputs[n_states:]
        self.loop_states = list(outputs[:n_states])
        if len(self.scan_outputs) != 0:
            # append this iteration's extras to the stacks collected so far
            self.scan_outputs = [
                _EagerTensor(
                    torch.cat([self.scan_outputs[pos_].value, torch.unsqueeze(extra_.value, 0)]),
                    name=self.scan_outputs[pos_].name)
                for pos_, extra_ in enumerate(extras)]
        else:
            # first iteration: start one stack per extra output
            self.scan_outputs = [
                _EagerTensor(torch.unsqueeze(extra_.value, 0), 'sc_' + extra_.name) for extra_ in extras]
        self.iteration_num.value.add_(1)

    def current(self):
        """Return the iteration counter followed by the current states."""
        return [self.iteration_num, *self.loop_states]

    def finalize(self):
        """Emit the loop operator and return its outputs (states, then scan outputs) as a tuple."""
        # generate the outputs from the enclosing scope variables
        produced = self.loop_states + self.scan_outputs
        full_outputs = [_EagerTensor(item_.value, 'lp_' + item_.name) for item_ in produced]
        loop_inputs = [self.loop_count, self.condition_i] + list(self.states_i)
        output_names = [item_.name for item_ in full_outputs]
        _ox.loop(*_EagerTensor.ox_args(loop_inputs, output_names), body=self.sub_graph)
        return tuple(full_outputs)

    def is_stopped(self):
        """Return True when the condition is False or the iteration limit has been reached."""
        if self.condition.item() is False:
            return True
        return self.iteration_num.item() >= self.loop_count.item()

    def loop(self, loop_c, condition, *states):
        """
        Start the loop and return its iterator.
        :param loop_c: the iteration limit
        :param condition: the initial condition
        :param states: the initial states
        """
        self.condition_i = condition
        self.condition = condition
        self.states_i = states
        _EagerTensor.get_trace_session().stack_container()
        self.iteration_num = _EagerTensor.mytensor(0)
        # clone the variables for the sub graph.
        self.loop_states = [_EagerTensor(state_.value, state_.name) for state_ in states]
        self.loop_count = loop_c
        return iter(_LoopIterator(self))
589
590
def control_flow():
    """Create a new control-flow context; call its `loop` method to start a traced loop."""
    context = _ControlFlowContext()
    return context
593
594
class _TracingEagerOp(OrtPyFunction):
    """An OrtPyFunction that accepts _EagerTensor arguments and emits a model-call operator."""

    def __call__(self, *args, **kwargs):
        """
        Run the wrapped function on the numpy values of `args`.
        :return: one _EagerTensor, or a tuple of them when the session has several outputs
        """
        numpy_inputs = [arg_.numpy() if isinstance(arg_, _EagerTensor) else arg_ for arg_ in args]
        raw_results = super().__call__(*numpy_inputs, **kwargs)
        if not isinstance(raw_results, (list, tuple)):
            raw_results = [raw_results]

        wrapped = [_EagerTensor.from_onnx(raw_results[pos_], self.ort_session, meta_.name)
                   for pos_, meta_ in enumerate(self.ort_session.get_outputs())]

        result_names = [item_.name for item_ in wrapped]
        _ox.model_call(*_EagerTensor.ox_args(args, output_names=result_names), oxml=self.onnx_model)
        if len(wrapped) > 1:
            return tuple(wrapped)
        return wrapped[0]
607
608
def op_from_customop(op_type, *args, **kwargs) -> _TracingEagerOp:
    """Build a tracing op for `op_type`; all arguments go to `_TracingEagerOp.from_customop`."""
    traced_op = _TracingEagerOp.from_customop(op_type, *args, **kwargs)
    return traced_op
611
612
def op_from_model(path_or_model, *args, **kwargs) -> _TracingEagerOp:
    """Build a tracing op from a model; all arguments go to `_TracingEagerOp.from_model`."""
    traced_op = _TracingEagerOp.from_model(path_or_model, *args, **kwargs)
    return traced_op
615
616
# Functions that _EagerTensor.__getattribute__ exposes as chained methods,
# keyed by the name each function was defined under.
_EagerTensor._all_ops = {
    op_.__name__: op_
    for op_ in (argmax, softmax, reshape, transpose, repeat, any, all)
}

# Public aliases for the tensor constructors and the session switch.
tensor = _EagerTensor.mytensor
tensor_from_onnx = _EagerTensor.from_onnx
tensor_from_torch = _EagerTensor.from_torch
tensor_set_session = _EagerTensor.set_active_session
629