microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/onnxprocess/_tensor.py
628 lines · mode: code
| 1 | import torch |
| 2 | import builtins |
| 3 | import functools |
| 4 | import numpy as np |
| 5 | from onnx import onnx_pb as onnx_proto |
| 6 | from typing import List, Tuple, Optional, Union, Any, ContextManager, overload, Iterator, NamedTuple |
| 7 | from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout # noqa |
| 8 | from torch import strided, memory_format, contiguous_format, StringType # noqa |
| 9 | |
| 10 | from ._onnx_ops import ox as _ox |
| 11 | from .._ortapi2 import OrtPyFunction |
| 12 | |
| 13 | |
| 14 | class _EagerTensor: |
| 15 | def __init__(self, _t, name=None, sess=None, raw_data: Any = None): |
| 16 | self._t = _t if isinstance(_t, torch.Tensor) else torch.tensor(_t) |
| 17 | if isinstance(name, (tuple, list)): |
| 18 | assert len(name) == 1, "Multiple names for one tensor!" |
| 19 | name = name[0] |
| 20 | self.name = '' if name is None else name |
| 21 | self.raw_data = raw_data |
| 22 | self.symbolic_shape = [] |
| 23 | |
| 24 | def __repr__(self): |
| 25 | if self.raw_data is not None: |
| 26 | return "name: {}, \"{}\"".format(self.name, str(self.raw_data)) |
| 27 | else: |
| 28 | return "name: {}, {}, dtype={}".format(self.name, repr(self._t), str(self._t.dtype)) |
| 29 | |
| 30 | _all_ops = {} |
| 31 | |
| 32 | @property |
| 33 | def value(self) -> Union[torch.Tensor, Any]: |
| 34 | return self.raw_data if self.raw_data else self._t |
| 35 | |
| 36 | @property |
| 37 | def t(self): |
| 38 | return self._t |
| 39 | |
| 40 | @property |
| 41 | def dtype(self): |
| 42 | return self._t.dtype |
| 43 | |
| 44 | @property |
| 45 | def onnx_type(self): |
| 46 | return self.to_onnx_type(self._t.dtype) |
| 47 | |
| 48 | @classmethod |
| 49 | def is_numeric(cls, np_arr): |
| 50 | return np_arr.dtype.kind in set('buifc') |
| 51 | |
| 52 | @classmethod |
| 53 | def set_active_session(cls, sess): |
| 54 | """ |
| 55 | set the active operator tracing log session. if sess is None, the active session will be removed |
| 56 | :param sess: |
| 57 | :return: |
| 58 | """ |
| 59 | if not hasattr(cls, '_active_session'): |
| 60 | cls._active_session = sess |
| 61 | if sess is None: |
| 62 | raise RuntimeError("unset the active session twice!") |
| 63 | else: |
| 64 | if sess is not None: |
| 65 | raise RuntimeError("The active session already assigned!") |
| 66 | delattr(cls, '_active_session') |
| 67 | |
| 68 | @classmethod |
| 69 | def get_trace_session(cls): |
| 70 | if not hasattr(cls, '_active_session'): |
| 71 | raise RuntimeError("the tracing not started yet!") |
| 72 | return cls._active_session # noqa |
| 73 | |
| 74 | @classmethod |
| 75 | def get_container(cls): |
| 76 | return cls.get_trace_session().container |
| 77 | |
| 78 | @classmethod |
| 79 | def from_onnx(cls, raw_val, ort_sess, name): |
| 80 | raw_data = None |
| 81 | if cls.is_numeric(raw_val): |
| 82 | val = torch.from_numpy(raw_val) |
| 83 | else: |
| 84 | # only keep the shape and the value was stored by it-self. |
| 85 | val = torch.empty(*raw_val.shape, dtype=torch.uint8) |
| 86 | raw_data = raw_val |
| 87 | t = cls(val, name, ort_sess, raw_data) |
| 88 | return t |
| 89 | |
| 90 | @classmethod |
| 91 | def from_torch(cls, _t, name): |
| 92 | t_name = name if name is not None else "id_{}".format(id(_t)) |
| 93 | ts = cls(_t, t_name) |
| 94 | return ts |
| 95 | |
| 96 | @classmethod |
| 97 | # torch.tensor prototype |
| 98 | def mytensor(cls, data: Any, dtype: Optional[_dtype] = None, device: Union[_device, str, None] = None, requires_grad: _bool = False): # noqa |
| 99 | y = torch.tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) |
| 100 | val = _ox.make_tensor(cls.to_onnx_type(y.dtype), list(y.size()), |
| 101 | [data] if isinstance(data, (int, float, str, bool)) else data) |
| 102 | s = _ox.constant([], [_ox.get_unique_tensor_name('const')], cls.get_container(), None, value=val) |
| 103 | return cls.from_torch(y, s) |
| 104 | |
| 105 | def numpy(self): |
| 106 | return self._t.numpy() if self.raw_data is None else self.raw_data |
| 107 | |
| 108 | def item(self): |
| 109 | return self.numpy().item() |
| 110 | |
| 111 | def get_shape(self): |
| 112 | return self.t.size() if len(self.symbolic_shape) == 0 else self.symbolic_shape |
| 113 | |
| 114 | def _to_binary_tensor_args(self, other): |
| 115 | # convert self, other to [self, other], but if either is a number, convert that to a constant |
| 116 | x, y = self, other |
| 117 | if isinstance(y, (int, float, bool, np.ndarray)): |
| 118 | y = self.mytensor(y) |
| 119 | elif isinstance(x, (int, float, bool, np.ndarray)): |
| 120 | x = self.mytensor(x) |
| 121 | return x, y |
| 122 | |
| 123 | _dup_id = 0 |
| 124 | |
| 125 | def __copy__(self): |
| 126 | new_t = _EagerTensor.from_torch(self.t, self.name + '_{}'.format(_EagerTensor._dup_id)) |
| 127 | self._dup_id += 1 |
| 128 | new_t.raw_data = self.raw_data |
| 129 | return new_t |
| 130 | |
| 131 | def __add__(self, other): |
| 132 | x0, x1 = self._to_binary_tensor_args(other) |
| 133 | y = torch.add(x0._t, x1._t) |
| 134 | s = _ox.add(*_EagerTensor.ox_args([x0, x1])) |
| 135 | return self.from_torch(y, s) |
| 136 | |
| 137 | def __sub__(self, other): |
| 138 | x0, x1 = self._to_binary_tensor_args(other) |
| 139 | y = torch.sub(x0._t, x1._t) |
| 140 | s = _ox.sub(*_EagerTensor.ox_args([x0, x1])) |
| 141 | return self.from_torch(y, s) |
| 142 | |
| 143 | def __mul__(self, other): |
| 144 | x0, x1 = self._to_binary_tensor_args(other) |
| 145 | y = torch.mul(x0._t, x1._t) |
| 146 | s = _ox.mul(*_EagerTensor.ox_args([x0, x1])) |
| 147 | return self.from_torch(y, s) |
| 148 | |
| 149 | def __div__(self, other): |
| 150 | x0, x1 = self._to_binary_tensor_args(other) |
| 151 | y = torch.div(x0._t, x1._t) |
| 152 | s = _ox.div(*_EagerTensor.ox_args([x0, x1])) |
| 153 | return self.from_torch(y, s) |
| 154 | |
| 155 | def __pow__(self, other): |
| 156 | x0, x1 = self._to_binary_tensor_args(other) |
| 157 | y = torch.pow(x0._t, x1._t) |
| 158 | s = _ox.pow(*_EagerTensor.ox_args([x0, x1])) |
| 159 | return self.from_torch(y, s) |
| 160 | |
| 161 | def __matmul__(self, other): |
| 162 | x0, x1 = self._to_binary_tensor_args(other) |
| 163 | y = torch.matmul(x0._t, x1._t) |
| 164 | s = _ox.matmul(*_EagerTensor.ox_args([x0, x1])) |
| 165 | return self.from_torch(y, s) |
| 166 | |
| 167 | def __lt__(self, other): |
| 168 | x0, x1 = self._to_binary_tensor_args(other) |
| 169 | y = torch.less(x0._t, x1._t) |
| 170 | s = _ox.less(*_EagerTensor.ox_args([x0, x1])) |
| 171 | return self.from_torch(y, s) |
| 172 | |
| 173 | def __le__(self, other): |
| 174 | x0, x1 = self._to_binary_tensor_args(other) |
| 175 | y = torch.less_equal(x0._t, x1._t) |
| 176 | s = _ox.less_equal(*_EagerTensor.ox_args([x0, x1])) |
| 177 | return self.from_torch(y, s) |
| 178 | |
| 179 | def __eq__(self, other): |
| 180 | x0, x1 = self._to_binary_tensor_args(other) |
| 181 | y = torch.equal(x0._t, x1._t) |
| 182 | s = _ox.equal(*_EagerTensor.ox_args([x0, x1])) |
| 183 | return self.from_torch(y, s) |
| 184 | |
| 185 | def __ne__(self, other): |
| 186 | x0, x1 = self._to_binary_tensor_args(other) |
| 187 | y = torch.not_equal(x0._t, x1._t) |
| 188 | s = _ox.not_equal(*_EagerTensor.ox_args([x0, x1])) |
| 189 | return self.from_torch(y, s) |
| 190 | |
| 191 | def __gt__(self, other): |
| 192 | x0, x1 = self._to_binary_tensor_args(other) |
| 193 | y = torch.greater(x0._t, x1._t) |
| 194 | s = _ox.greater(*_EagerTensor.ox_args([x0, x1])) |
| 195 | return self.from_torch(y, s) |
| 196 | |
| 197 | def __ge__(self, other): |
| 198 | x0, x1 = self._to_binary_tensor_args(other) |
| 199 | y = torch.greater_equal(x0._t, x1._t) |
| 200 | s = _ox.greater_equal(*_EagerTensor.ox_args([x0, x1])) |
| 201 | return self.from_torch(y, s) |
| 202 | |
| 203 | def __invert__(self): |
| 204 | if self.t.dtype is torch.bool: |
| 205 | y = torch.logical_not(self.t) |
| 206 | s = _ox.not_op(*self.my_args()) |
| 207 | return self.from_torch(y, s) |
| 208 | else: |
| 209 | raise NotImplementedError("no numeric tensor inverse supported yet.") |
| 210 | |
| 211 | def __neg__(self): |
| 212 | y = torch.neg([self.t]) |
| 213 | s = _ox.neg(*self.my_args()) |
| 214 | return self.from_torch(y, s) |
| 215 | |
| 216 | def __not__(self): |
| 217 | y = torch.logical_not(self.t) |
| 218 | s = _ox.not_op(*self.my_args()) |
| 219 | return self.from_torch(y, s) |
| 220 | |
| 221 | def __or__(self, other): |
| 222 | x0, x1 = self._to_binary_tensor_args(other) |
| 223 | y = torch.logical_or(x0._t, x1._t) |
| 224 | s = _ox.or_op(*_EagerTensor.ox_args([x0, x1])) |
| 225 | return self.from_torch(y, s) |
| 226 | |
| 227 | def __getitem__(self, indices): |
| 228 | y = self.value.__getitem__(indices) |
| 229 | |
| 230 | # normalize indices to tuples of slices |
| 231 | # Formats encountered: |
| 232 | # - a single int |
| 233 | # - a tuple of (int or slice) |
| 234 | if not isinstance(indices, (tuple, list)): # single item: make it a tuple |
| 235 | indices = (indices,) |
| 236 | squeeze = [axis for axis, index in enumerate(indices) if |
| 237 | isinstance(index, int)] # which axes had a single index? |
| 238 | indices = tuple( |
| 239 | index if isinstance(index, slice) else slice(index, index + 1 if index != -1 else None, 1) for index in |
| 240 | indices) # make all tuple items of type Slice |
| 241 | bs, es, ss, ds = [], [], [], [] |
| 242 | INT_MAX = 2 ** 63 - 1 |
| 243 | for axis, index in enumerate(indices): |
| 244 | if not isinstance(index, slice): |
| 245 | raise ValueError("Index expected") |
| 246 | if index.start is None and index.stop is None: # [:] can be skipped |
| 247 | continue |
| 248 | b, e, s = index.start, index.stop, index.step |
| 249 | bs.append(b if b is not None else 0) |
| 250 | es.append(e if e is not None else INT_MAX) |
| 251 | ss.append(s if s is not None else 1) |
| 252 | ds.append(axis) |
| 253 | s = _ox.slice(*self.my_args(), starts=bs, ends=es, axes=ds, steps=ss) |
| 254 | if squeeze: # single index means we must drop the axis |
| 255 | s = _ox.squeeze(*self.ox_name_args(s), axes=squeeze) |
| 256 | |
| 257 | return self.from_torch(y, s) |
| 258 | |
| 259 | def __getattribute__(self, attr): |
| 260 | """ |
| 261 | A little hack that allows to call unary operators in a chaining fashion, |
| 262 | e.g. x.shape() instead of ox.shape(x). |
| 263 | """ |
| 264 | if attr in _EagerTensor._all_ops: |
| 265 | f = _EagerTensor._all_ops[attr] |
| 266 | return functools.partial(f, self) |
| 267 | else: |
| 268 | return object.__getattribute__(self, attr) |
| 269 | |
| 270 | @classmethod |
| 271 | def ox_name_args(cls, input_names, output_names=None): |
| 272 | """ |
| 273 | generate the arguments for ONNX model builder. |
| 274 | :param input_names: input name list |
| 275 | :param output_names: output name list, can be None, or [None]*output_n |
| 276 | :return: input_names, output_names, container, operator_name |
| 277 | """ |
| 278 | container = cls.get_trace_session().container |
| 279 | if output_names is None: |
| 280 | output_names = [None] # by default, there is only one output |
| 281 | |
| 282 | output_names = [_ox.get_unique_tensor_name(str(n_)) |
| 283 | if output_names[n_] is None else |
| 284 | output_names[n_] for n_ in range(len(output_names))] |
| 285 | operator_name = None |
| 286 | return input_names, output_names, container, operator_name |
| 287 | |
| 288 | @classmethod |
| 289 | def ort_verify(cls, ts_from, ts_to): |
| 290 | result, model = cls.get_trace_session().runops(ts_from, ts_to) |
| 291 | for idx in range(len(ts_to)): |
| 292 | if not np.allclose(ts_to[idx].numpy(), result[idx]): |
| 293 | # ONNX cannot be import globally, which is conflict with torch.onnx |
| 294 | import onnx # noqa |
| 295 | onnx.save_model(model, 'mt_debmodel.onnx') |
| 296 | raise RuntimeError("ONNXRuntime Result is not same pytorch!") |
| 297 | |
| 298 | def create_and_verify(self, value, name, additional_inputs=None): |
| 299 | ts_y = self.from_torch(value, name) |
| 300 | inputs = [self] + ([] if additional_inputs is None else additional_inputs) |
| 301 | self.ort_verify(inputs, [ts_y]) |
| 302 | return ts_y |
| 303 | |
| 304 | @classmethod |
| 305 | def ox_args(cls, tensors, output_names=None): |
| 306 | input_names = [ts_ if isinstance(ts_, str) else ts_.name for ts_ in tensors] |
| 307 | return cls.ox_name_args(input_names, output_names) |
| 308 | |
| 309 | def my_args(self): |
| 310 | return self.ox_args([self]) |
| 311 | |
| 312 | @staticmethod |
| 313 | def normalize_seq(list_or_tuple): |
| 314 | return [x.value.item() if isinstance(x, _EagerTensor) else x for x in list_or_tuple] |
| 315 | |
| 316 | @staticmethod |
| 317 | def to_onnx_type(torch_type): |
| 318 | ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL, |
| 319 | torch.float32: onnx_proto.TensorProto.FLOAT, |
| 320 | torch.long: onnx_proto.TensorProto.INT64, |
| 321 | torch.int32: onnx_proto.TensorProto.INT32} |
| 322 | # ... |
| 323 | return ty_dict.get(torch_type, onnx_proto.TensorProto.STRING) |
| 324 | |
| 325 | def long(self): |
| 326 | y = self._t.long() |
| 327 | s = _ox.cast(*self.my_args(), to=onnx_proto.TensorProto.INT64) |
| 328 | return self.create_and_verify(y, s[0]) |
| 329 | |
| 330 | def cumsum(self, dim: _int, *, dtype: Optional[_dtype] = None): # noqa |
| 331 | y = self._t.cumsum(dim, dtype=dtype) |
| 332 | s = _ox.cumsum(*self.my_args(), axis=dim) |
| 333 | return self.create_and_verify(y, s[0]) |
| 334 | |
| 335 | def size(self): |
| 336 | y = self._t.size() |
| 337 | s = _ox.shape(*self.my_args()) |
| 338 | return self.create_and_verify(y, s[0]) |
| 339 | |
| 340 | def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False): |
| 341 | y = self._t.type(dtype, non_blocking) |
| 342 | s = _ox.cast(*self.my_args(), to=self.to_onnx_type(dtype)) |
| 343 | return self.create_and_verify(y, s) |
| 344 | |
| 345 | def to(self, device): |
| 346 | y = self._t.to(device) |
| 347 | s = _ox.identity(*self.my_args()) |
| 348 | return self.create_and_verify(y, s[0]) |
| 349 | |
| 350 | def cpu(self): |
| 351 | y = self._t.cpu() |
| 352 | s = _ox.identity(*self.my_args()) |
| 353 | return self.create_and_verify(y, s[0]) |
| 354 | |
| 355 | def detach(self): |
| 356 | y = self._t.detach() |
| 357 | s = _ox.identity(*self.my_args()) |
| 358 | return self.create_and_verify(y, s[0]) |
| 359 | |
| 360 | def clone(self): |
| 361 | y = self._t.clone() |
| 362 | s = _ox.identity(*self.my_args()) |
| 363 | return self.create_and_verify(y, s[0]) |
| 364 | |
| 365 | def masked_fill(self, mask, value): |
| 366 | y = self._t.masked_fill(mask.value, value) |
| 367 | if not isinstance(value, _EagerTensor): |
| 368 | value = _EagerTensor.mytensor(value) |
| 369 | s = _ox.where(*_EagerTensor.ox_args([mask, value, self])) |
| 370 | return self.create_and_verify(y, s[0], additional_inputs=[mask, value]) |
| 371 | |
| 372 | def unsqueeze(self, dim: _int): |
| 373 | y = self._t.unsqueeze(dim) |
| 374 | s = _ox.unsqueeze(*self.my_args(), [dim]) |
| 375 | return self.create_and_verify(y, s[0]) |
| 376 | |
| 377 | def squeeze(self, dim: _int): |
| 378 | y = self._t.squeeze(dim) |
| 379 | s = _ox.squeeze(*self.my_args(), [dim]) |
| 380 | return self.create_and_verify(y, s[0]) |
| 381 | |
| 382 | |
def _create_ox_sequence(*size):
    """Emit graph nodes that produce a 1-D INT64 tensor holding ``size``.

    If every element is a plain int, a single Constant is emitted. If any element
    is an ``_EagerTensor``, each element becomes a one-element piece: an
    ``_ox.unsqueeze`` node for traced tensors and a Constant for ints. The pieces
    are then joined with Concat.

    :return: whatever the final ``_ox`` builder call returns (callers index it with ``[0]``).
    """
    container = _EagerTensor.get_container()
    has_traced_dim = builtins.any(isinstance(dim_, _EagerTensor) for dim_ in size)
    if not has_traced_dim:
        ts_size = _ox.make_tensor(onnx_proto.TensorProto.INT64, [len(size)], size)
        return _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=ts_size)

    pieces = []
    for dim_ in size:
        if isinstance(dim_, _EagerTensor):
            pieces.append(_ox.unsqueeze(*_EagerTensor.ox_args([dim_]))[0])
        else:
            one_dim = _ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [dim_])
            pieces.append(
                _ox.constant([], [_ox.get_unique_tensor_name('const')], container, None, value=one_dim)[0])
    return _ox.concat(pieces, [_ox.get_unique_tensor_name('concat')], container, None)
| 398 | |
| 399 | |
def _create_ox_sequence_constant(*size, init_value=None, onnx_type=None):
    """Emit a ConstantOfShape node of shape ``size`` filled with ``init_value``.

    :param size: dimensions (ints and/or ``_EagerTensor`` scalars), see ``_create_ox_sequence``.
    :param init_value: fill value.
    :param onnx_type: ONNX element type of the fill value; FLOAT when None.
    :return: the name of the ConstantOfShape output.
    """
    fill_type = onnx_proto.TensorProto.FLOAT if onnx_type is None else onnx_type
    shape_names = _create_ox_sequence(*size)
    fill = _ox.make_tensor(fill_type, [1], [init_value])

    container = _EagerTensor.get_container()
    out_names = _ox.constant_of_shape(shape_names, [_ox.get_unique_tensor_name('cos')], container, None,
                                      value=fill)
    return out_names[0]
| 409 | |
| 410 | |
def empty(*size: Union[_int, _EagerTensor], memory_format: Optional[memory_format] = None, out: Optional[_EagerTensor] = None,
          dtype: _dtype = None, layout: _layout = strided, device: Union[_device, str, None] = None,
          requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """Traced ``torch.empty``; the graph side is a ConstantOfShape filled with 0.

    ``size`` may be given as varargs or as a single list/tuple (including ``torch.Size``),
    as in torch. Elements may be ints or ``_EagerTensor`` scalars.
    """
    # Unwrap a single list/tuple/torch.Size so the ONNX shape tensor receives the individual
    # dimensions rather than one nested sequence.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = size[0]
    n_size = _EagerTensor.normalize_seq(size)
    # pass the size as one sequence so an empty size (0-d result) is also accepted
    y = torch.empty(n_size, memory_format=memory_format, out=out,
                    dtype=dtype, layout=layout, device=device, requires_grad=requires_grad)
    s = _create_ox_sequence_constant(*size, init_value=0., onnx_type=_EagerTensor.to_onnx_type(y.dtype))
    return _EagerTensor.from_torch(y, s)
| 422 | |
| 423 | |
def zeros(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
          device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """Traced ``torch.zeros``; the graph side is a ConstantOfShape filled with 0.

    ``size`` may be given as varargs or as a single list/tuple (including ``torch.Size``),
    as in torch. Elements may be ints or ``_EagerTensor`` scalars.
    """
    # Unwrap a single list/tuple/torch.Size so the ONNX shape tensor receives the individual
    # dimensions rather than one nested sequence.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = size[0]
    n_size = _EagerTensor.normalize_seq(size)
    # pass the size as one sequence so an empty size (0-d result) is also accepted
    y = torch.zeros(n_size, out=out, dtype=dtype,
                    layout=layout, device=device, requires_grad=requires_grad)
    s = _create_ox_sequence_constant(*size, init_value=0, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
    return _EagerTensor.from_torch(y, s)
| 434 | |
| 435 | |
def ones(*size: Union[_int, _EagerTensor], out: Optional[_EagerTensor] = None, dtype: _dtype = None, layout: _layout = strided,
         device: Union[_device, str, None] = None, requires_grad: _bool = False) -> _EagerTensor:  # noqa
    """Traced ``torch.ones``; the graph side is a ConstantOfShape filled with 1.

    ``size`` may be given as varargs or as a single list/tuple (including ``torch.Size``),
    as in torch. Elements may be ints or ``_EagerTensor`` scalars.
    """
    # Unwrap a single list/tuple/torch.Size so the ONNX shape tensor receives the individual
    # dimensions rather than one nested sequence.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = size[0]
    n_size = _EagerTensor.normalize_seq(size)
    # pass the size as one sequence so an empty size (0-d result) is also accepted
    y = torch.ones(n_size, out=out, dtype=dtype,
                   layout=layout, device=device, requires_grad=requires_grad)
    s = _create_ox_sequence_constant(*size, init_value=1, onnx_type=_EagerTensor.to_onnx_type(y.dtype))
    return _EagerTensor.from_torch(y, s)
| 446 | |
| 447 | |
def repeat(input_ts: _EagerTensor, *repeats: Union[_int, _EagerTensor]) -> _EagerTensor:  # noqa
    """Traced ``Tensor.repeat``; the graph side is a Tile whose repeats come from ``_create_ox_sequence``.

    ``repeats`` may be given as varargs or as a single list/tuple (including ``torch.Size``),
    as in torch. Elements may be ints or ``_EagerTensor`` scalars.
    """
    # Unwrap a single list/tuple/torch.Size so the ONNX repeats tensor receives the individual
    # counts rather than one nested sequence.
    if len(repeats) == 1 and isinstance(repeats[0], (list, tuple)):
        repeats = repeats[0]
    n_size = _EagerTensor.normalize_seq(repeats)
    y = input_ts.t.repeat(*n_size)
    seq = _create_ox_sequence(*repeats)
    s = _ox.tile(*input_ts.my_args(), repeats=seq[0])
    return _EagerTensor.from_torch(y, s[0])
| 457 | |
| 458 | |
def argmax(input_ts: _EagerTensor, dim: Optional[_int] = None, keepdim: _bool = False) -> _EagerTensor:  # noqa
    """Traced ``torch.argmax``: eager result plus an ArgMax node (``axis=dim``, ``keepdims=keepdim``).

    NOTE(review): torch flattens the input when ``dim`` is None; confirm how ``_ox.argmax``
    treats ``axis=None``.
    """
    eager_result = torch.argmax(input_ts.value, dim, keepdim)
    out_names = _ox.argmax(*input_ts.my_args(), axis=dim, keepdims=keepdim)
    return _EagerTensor.from_torch(eager_result, out_names)
| 463 | |
| 464 | |
def softmax(input_ts: _EagerTensor, dim: _int, dtype: Optional[_dtype] = None) -> _EagerTensor:
    """Traced ``torch.softmax`` along ``dim``; ``dtype`` only affects the eager computation."""
    eager_result = torch.softmax(input_ts.value, dim, dtype)
    out_names = _ox.softmax(*input_ts.my_args(), axis=dim)
    return _EagerTensor.from_torch(eager_result, out_names)
| 469 | |
| 470 | |
def cat(tensors: Union[Tuple[_EagerTensor, ...], List[_EagerTensor]],
        dim, *, out: Optional[_EagerTensor] = None) -> _EagerTensor:  # noqa
    """Traced ``torch.cat`` along ``dim`` (ONNX Concat); the result is verified against ONNX Runtime."""
    eager_values = [ts_.value for ts_ in tensors]
    joined = torch.cat(eager_values, dim, out=out)
    out_names = _ox.concat(*_EagerTensor.ox_args(tensors), dim)
    result = _EagerTensor.from_torch(joined, out_names[0])
    _EagerTensor.ort_verify(tensors, [result])
    return result
| 478 | |
| 479 | |
def all(input_ts: _EagerTensor, out: Optional[_EagerTensor]=None) -> _EagerTensor:  # noqa
    """Traced ``torch.all``.

    Graph side: Cast to INT64 -> ReduceMin(axes=[-1]) -> Greater(min, 0). ``out`` is accepted but unused.
    NOTE(review): torch.all reduces over every element, while this ReduceMin only reduces the last
    axis; the verification presumably only agrees for inputs with one non-trivial axis -- confirm.
    """
    container = _EagerTensor.get_container()
    eager_result = torch.all(input_ts.value)
    as_int = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    min_names = [_ox.get_unique_tensor_name('reducemin')]
    reduced = _ox.reducemin(as_int, min_names, container, None, axes=[-1])
    zero_names = [_ox.get_unique_tensor_name('const')]
    zero = _ox.constant([], zero_names, container, None,
                        value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
    compare_inputs = reduced + zero
    compared = _ox.greater(compare_inputs, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(eager_result, compared[0])
| 489 | |
| 490 | |
def any(input_ts: _EagerTensor, out: Optional[_EagerTensor]=None) -> _EagerTensor:  # noqa
    """Traced ``torch.any``.

    Graph side: Cast to INT64 -> ReduceSum(axes=[-1]) -> Greater(sum, 0). ``out`` is accepted but unused.
    NOTE(review): torch.any reduces over every element, while this ReduceSum only reduces the last
    axis; the verification presumably only agrees for inputs with one non-trivial axis -- confirm.
    """
    container = _EagerTensor.get_container()
    eager_result = torch.any(input_ts.value)
    as_int = _ox.cast(*input_ts.my_args(), to=onnx_proto.TensorProto.INT64)
    sum_names = [_ox.get_unique_tensor_name('reducesum')]
    reduced = _ox.reducesum(as_int, sum_names, container, None, axes=[-1])
    zero_names = [_ox.get_unique_tensor_name('const')]
    zero = _ox.constant([], zero_names, container, None,
                        value=_ox.make_tensor(onnx_proto.TensorProto.INT64, [1], [0]))
    compare_inputs = reduced + zero
    compared = _ox.greater(compare_inputs, [_ox.get_unique_tensor_name('greater')], container, None)
    return input_ts.create_and_verify(eager_result, compared[0])
| 500 | |
| 501 | |
def reshape(input_ts: _EagerTensor, shape: _size):
    """Traced ``Tensor.reshape`` (ONNX Reshape to ``shape``), verified against ONNX Runtime."""
    reshaped = input_ts.t.reshape(shape)
    out_names = _ox.reshape(*input_ts.my_args(), desired_shape=shape)
    return input_ts.create_and_verify(reshaped, out_names[0])
| 506 | |
| 507 | |
def transpose(input_ts: _EagerTensor, dim0: _int, dim1: _int):
    """Traced ``Tensor.transpose``: swap axes ``dim0`` and ``dim1`` (ONNX Transpose with a full perm)."""
    swapped = input_ts.t.transpose(dim0, dim1)
    perm = list(range(swapped.dim()))
    # plain list indexing, so negative dims address axes from the end as in torch
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    out_names = _ox.transpose(*input_ts.my_args(), perm=perm)
    return input_ts.create_and_verify(swapped, out_names[0])
| 514 | |
| 515 | |
| 516 | class _LoopIterator: |
| 517 | def __init__(self, ctx): |
| 518 | self.context = ctx |
| 519 | |
| 520 | def __iter__(self): |
| 521 | return self |
| 522 | |
| 523 | def __next__(self): |
| 524 | if self.context.is_stopped(): |
| 525 | _EagerTensor.get_trace_session().pop_container() |
| 526 | raise StopIteration |
| 527 | return self.context.current() |
| 528 | |
| 529 | |
| 530 | class _ControlFlowContext: |
| 531 | def __init__(self): |
| 532 | self.condition_i = None |
| 533 | self.condition = None |
| 534 | self.loop_count = None |
| 535 | self.iteration_num = None |
| 536 | self.states_i = [] |
| 537 | self.loop_states = [] |
| 538 | self.scan_outputs = [] |
| 539 | self.sub_graph = None |
| 540 | |
| 541 | def flow_output(self, cond, *outputs): |
| 542 | assert len(outputs) >= len(self.loop_states), "The loop body doesn't return enough objects" |
| 543 | if self.sub_graph is None: |
| 544 | trc = _EagerTensor.get_trace_session() |
| 545 | self.sub_graph = trc.build_graph(trc.container, |
| 546 | [self.iteration_num, self.condition] + self.loop_states, |
| 547 | [cond] + list(outputs)) |
| 548 | |
| 549 | self.condition = cond |
| 550 | c_state = len(self.loop_states) |
| 551 | self.loop_states = list(outputs[:c_state]) |
| 552 | if len(self.scan_outputs) == 0: |
| 553 | sc = [_EagerTensor(torch.unsqueeze(sci_.value, 0), 'sc_' + sci_.name) for sci_ in outputs[c_state:]] |
| 554 | self.scan_outputs = sc |
| 555 | else: |
| 556 | next_extra_vars = [] |
| 557 | for idx_, ext_ in enumerate(outputs[c_state:]): |
| 558 | et = self.scan_outputs[idx_] |
| 559 | next_extra_vars.append(_EagerTensor( |
| 560 | torch.cat([et.value, torch.unsqueeze(outputs[c_state + idx_].value, 0)]), name=et.name)) |
| 561 | self.scan_outputs = next_extra_vars |
| 562 | self.iteration_num.value.add_(1) |
| 563 | |
| 564 | def current(self): |
| 565 | return [self.iteration_num] + list(self.loop_states) |
| 566 | |
| 567 | def finalize(self): |
| 568 | # generate the outputs from the enclosing scope variables |
| 569 | full_outputs = [_EagerTensor(o_.value, 'lp_' + o_.name) for o_ in self.loop_states + self.scan_outputs] |
| 570 | _ox.loop(*_EagerTensor.ox_args( |
| 571 | [self.loop_count, self.condition_i] + list(self.states_i), |
| 572 | [ts_.name for ts_ in full_outputs]), body=self.sub_graph) |
| 573 | return tuple(full_outputs) |
| 574 | |
| 575 | def is_stopped(self): |
| 576 | return self.condition.item() is False or self.iteration_num.item() >= self.loop_count.item() |
| 577 | |
| 578 | def loop(self, loop_c, condition, *states): |
| 579 | self.condition = condition |
| 580 | self.condition_i = condition |
| 581 | self.states_i = states |
| 582 | _EagerTensor.get_trace_session().stack_container() |
| 583 | self.iteration_num = _EagerTensor.mytensor(0) |
| 584 | # clone the variables for the sub graph. |
| 585 | self.loop_states = [_EagerTensor(st_.value, st_.name) for st_ in states] |
| 586 | self.loop_count = loop_c |
| 587 | loop_b = _LoopIterator(self) |
| 588 | return iter(loop_b) |
| 589 | |
| 590 | |
def control_flow():
    """Create a fresh control-flow context for tracing an eager loop."""
    ctx = _ControlFlowContext()
    return ctx
| 593 | |
| 594 | |
class _TracingEagerOp(OrtPyFunction):
    """OrtPyFunction whose calls are also recorded into the active trace.

    ``_EagerTensor`` arguments are passed to ONNX Runtime as numpy arrays. Every session
    output is wrapped with ``_EagerTensor.from_onnx``, and a ``model_call`` node embedding
    ``self.onnx_model`` is emitted.
    """

    def __call__(self, *args, **kwargs):
        np_args = [a_.numpy() if isinstance(a_, _EagerTensor) else a_ for a_ in args]
        result = super().__call__(*np_args, **kwargs)
        if not isinstance(result, (list, tuple)):
            result = [result]

        outputs = []
        for idx_, meta_ in enumerate(self.ort_session.get_outputs()):
            outputs.append(_EagerTensor.from_onnx(result[idx_], self.ort_session, meta_.name))

        _ox.model_call(*_EagerTensor.ox_args(args, output_names=[o_.name for o_ in outputs]),
                       oxml=self.onnx_model)
        if len(outputs) > 1:
            return tuple(outputs)
        return outputs[0]
| 607 | |
| 608 | |
def op_from_customop(op_type, *args, **kwargs) -> _TracingEagerOp:
    """Build a traced eager op for custom operator ``op_type`` via ``_TracingEagerOp.from_customop``."""
    factory = _TracingEagerOp.from_customop
    return factory(op_type, *args, **kwargs)
| 611 | |
| 612 | |
def op_from_model(path_or_model, *args, **kwargs) -> _TracingEagerOp:
    """Build a traced eager op from an ONNX model (or its path) via ``_TracingEagerOp.from_model``."""
    factory = _TracingEagerOp.from_model
    return factory(path_or_model, *args, **kwargs)
| 615 | |
| 616 | |
# Free functions callable as tensor methods (e.g. ``x.reshape(shape)``) via _EagerTensor.__getattribute__.
_EagerTensor._all_ops = dict(argmax=argmax,
                             softmax=softmax,
                             reshape=reshape,
                             transpose=transpose,
                             repeat=repeat,
                             any=any,
                             all=all)

# Public module-level aliases.
tensor = _EagerTensor.mytensor
tensor_from_onnx = _EagerTensor.from_onnx
tensor_from_torch = _EagerTensor.from_torch
tensor_set_session = _EagerTensor.set_active_session
| 629 | |