microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
55e9c4965e1dcb0102960c101f7ff2c1b2384c31

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

ocos/pyfunc/pyfunc.cc

448 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <complex>
#include <cstring>
#include <fstream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL ocos_python_ARRAY_API
#include <numpy/arrayobject.h>

#include <pybind11/iostream.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>

#include "utils.h"
#include "pykernel.h"
#include "kernels/string_hash.hpp"
23
24namespace py = pybind11;
25
26const int PyCustomOpDef::undefined = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
27const int PyCustomOpDef::dt_float = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // maps to c type float
28const int PyCustomOpDef::dt_uint8 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; // maps to c type uint8_t
29const int PyCustomOpDef::dt_int8 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; // maps to c type int8_t
30const int PyCustomOpDef::dt_uint16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; // maps to c type uint16_t
31const int PyCustomOpDef::dt_int16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; // maps to c type int16_t
32const int PyCustomOpDef::dt_int32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; // maps to c type int32_t
33const int PyCustomOpDef::dt_int64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; // maps to c type int64_t
34const int PyCustomOpDef::dt_string = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; // maps to c++ type std::string
35const int PyCustomOpDef::dt_bool = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
36const int PyCustomOpDef::dt_float16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
37const int PyCustomOpDef::dt_double = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; // maps to c type double
38const int PyCustomOpDef::dt_uint32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; // maps to c type uint32_t
39const int PyCustomOpDef::dt_uint64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; // maps to c type uint64_t
40// complex with float32 real and imaginary components
41const int PyCustomOpDef::dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
42// complex with float64 real and imaginary components
43const int PyCustomOpDef::dt_complex128 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
44// Non-IEEE floating-point format based on IEEE754 single-precision
45const int PyCustomOpDef::dt_bfloat16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
46
47static int to_numpy(ONNXTensorElementDataType dt) {
48 switch (dt) {
49 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
50 return NPY_FLOAT;
51 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
52 return NPY_UINT8;
53 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
54 return NPY_INT8;
55 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
56 return NPY_UINT16;
57 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
58 return NPY_INT16;
59 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
60 return NPY_INT32;
61 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
62 return NPY_INT64;
63 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
64 return NPY_BOOL;
65 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
66 return NPY_FLOAT16;
67 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
68 return NPY_DOUBLE;
69 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
70 return NPY_UINT32;
71 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
72 return NPY_UINT64;
73 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
74 return NPY_COMPLEX64;
75 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
76 return NPY_COMPLEX128;
77 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
78 return NPY_OBJECT;
79 default:
80 throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
81 }
82}
83
84static size_t element_size(ONNXTensorElementDataType dt) {
85 switch (dt) {
86 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
87 return sizeof(float);
88 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
89 return sizeof(uint8_t);
90 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
91 return sizeof(int8_t);
92 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
93 return sizeof(uint16_t);
94 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
95 return sizeof(int16_t);
96 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
97 return sizeof(int32_t);
98 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
99 return sizeof(int64_t);
100 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
101 return sizeof(bool);
102 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
103 return sizeof(uint16_t);
104 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
105 return sizeof(double);
106 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
107 return sizeof(uint32_t);
108 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
109 return sizeof(uint64_t);
110 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
111 return sizeof(std::complex<float>);
112 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
113 return sizeof(std::complex<double>);
114 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
115 return sizeof(std::string*);
116 default:
117 throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
118 }
119}
120
121static ONNXTensorElementDataType from_numpy(int dt) {
122 switch (dt) {
123 case NPY_FLOAT:
124 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
125 case NPY_UINT8:
126 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
127 case NPY_INT8:
128 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
129 case NPY_UINT16:
130 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
131 case NPY_INT16:
132 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
133 case NPY_INT32:
134 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
135 case NPY_INT64:
136 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
137 case NPY_BOOL:
138 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
139 case NPY_FLOAT16:
140 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
141 case NPY_DOUBLE:
142 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
143 case NPY_UINT32:
144 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
145 case NPY_UINT64:
146 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
147 case NPY_COMPLEX64:
148 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
149 case NPY_COMPLEX128:
150 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
151 case NPY_OBJECT:
152 case NPY_STRING:
153 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
154 default:
155 throw std::runtime_error("No corresponding ONNX data type/Tensor data Type.");
156 }
157}
158
159struct PyCustomOpDefImpl : public PyCustomOpDef {
160 typedef std::vector<int64_t> shape_t;
161 static int64_t calc_size_from_shape(const shape_t& sp) {
162 size_t c = 1;
163 for (auto it = sp.begin(); it != sp.end(); ++it) {
164 c *= *it;
165 }
166 return c;
167 }
168
169 static py::object BuildPyObjFromTensor(const void* p, const shape_t& shape, ONNXTensorElementDataType dtype) {
170 std::vector<npy_intp> npy_dims;
171 for (auto n : shape) {
172 npy_dims.push_back(n);
173 }
174 const int numpy_type = to_numpy(dtype);
175 py::object obj = py::reinterpret_steal<py::object>(PyArray_SimpleNew(
176 static_cast<int>(shape.size()), npy_dims.data(), numpy_type));
177 void* out_ptr = static_cast<void*>(
178 PyArray_DATA(reinterpret_cast<PyArrayObject*>(obj.ptr())));
179
180 if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
181 py::object* outObj = static_cast<py::object*>(out_ptr);
182 auto size = calc_size_from_shape(shape);
183 const std::string* src = (const std::string*)p;
184 for (int i = 0; i < size; i++, src++) {
185 outObj[i] = py::cast(*src);
186 }
187 } else {
188 size_t size_type = element_size(dtype);
189 memcpy(out_ptr, p, size_type * calc_size_from_shape(shape));
190 }
191 return obj;
192 }
193
194 static py::object InvokePyFunction(uint64_t id, const py::object& feed) {
195 return (*op_invoker)(id, feed);
196 }
197
198 using callback_t = std::function<py::object(uint64_t id, const py::object&)>;
199 static std::unique_ptr<callback_t> op_invoker;
200};
201
202std::unique_ptr<PyCustomOpDefImpl::callback_t> PyCustomOpDefImpl::op_invoker;
203// static py::function g_pyfunc_caller;
204// static std::mutex op_mutex;
205// static std::condition_variable op_cv;
206// static bool is_ready = false;
207
208typedef struct {
209 const OrtValue* input_X;
210 ONNXTensorElementDataType dtype;
211 std::vector<int64_t> dimensions;
212} InputInformation;
213
214void PyCustomOpKernel::Compute(OrtKernelContext* context) {
215 // std::unique_lock<std::mutex> lck(op_mutex);
216 // is_ready = true;
217 // op_cv.notify_all();
218 // std::this_thread::sleep_for(std::chrono::milliseconds(5000));
219 size_t n_inputs = ort_.KernelContext_GetInputCount(context);
220 size_t n_outputs = ort_.KernelContext_GetOutputCount(context);
221
222 // Setup inputs
223 std::vector<InputInformation> inputs;
224 inputs.reserve(n_inputs);
225 for (size_t index = 0; index < n_inputs; ++index) {
226 const OrtValue* input_X = ort_.KernelContext_GetInput(context, index);
227 std::vector<int64_t> i_dimensions;
228 OrtTensorTypeAndShapeInfo* i_info = ort_.GetTensorTypeAndShape(input_X);
229 i_dimensions = ort_.GetTensorShape(i_info);
230 ONNXTensorElementDataType i_dtype = ort_.GetTensorElementType(i_info);
231 ort_.ReleaseTensorTypeAndShapeInfo(i_info);
232 inputs.push_back(InputInformation{input_X, i_dtype, i_dimensions});
233 }
234
235 /* Acquire GIL before calling Python code, due to it was released in sess.run */
236 py::gil_scoped_acquire acquire;
237
238 // TODO: Direct-Buffer-Access doesn't work for some reason.
239 // OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
240 // int64_t size = ort_.GetTensorShapeElementCount(output_info);
241 // ort_.ReleaseTensorTypeAndShapeInfo(output_info);
242 // py::buffer_info buf(
243 // const_cast<void *>(X), /* Pointer to buffer */
244 // sizeof(float), /* Size of one scalar */
245 // py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
246 // 2, /* Number of dimensions */
247 // {2, 3}, /* Buffer dimensions */
248 // {sizeof(float) * dimensions.data()[1], /* Strides (in bytes) for each index */
249 // sizeof(float)});
250
251 {
252 py::list pyinputs;
253 for (auto it = inputs.begin(); it != inputs.end(); ++it) {
254 py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(
255 (const void*)ort_.GetTensorData<float>(it->input_X), it->dimensions, it->dtype);
256 pyinputs.append(input0);
257 }
258
259 // Call python function id, shape, flat coefficient.
260 py::tuple fetch = PyCustomOpDefImpl::InvokePyFunction(obj_id_, pyinputs);
261 int64_t rid = fetch[0].cast<int64_t>();
262 assert(rid == obj_id_);
263
264 // Setup output.
265 for (size_t no = 0; no < n_outputs; ++no) {
266 auto dims = fetch[1 + no * 2].cast<std::vector<int64_t>>();
267 OrtValue* output = ort_.KernelContext_GetOutput(context, no, dims.data(), dims.size());
268 OrtTensorTypeAndShapeInfo* o_info = ort_.GetTensorTypeAndShape(output);
269 ONNXTensorElementDataType o_dtype = ort_.GetTensorElementType(o_info);
270 const void* Y = (const void*)ort_.GetTensorData<float>(output);
271 ort_.ReleaseTensorTypeAndShapeInfo(o_info);
272 void* out = (void*)ort_.GetTensorMutableData<float>(output);
273
274 if (o_dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
275 auto retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
276 std::string* type_outPtr = (std::string*)out;
277 std::string* end = type_outPtr + retval.size();
278 const std::string* source = (const std::string*)retval.data();
279 for (; type_outPtr != end; ++type_outPtr, ++source) {
280 *type_outPtr = *source;
281 }
282 } else {
283 py::array retval = fetch[2 + no * 2].cast<py::array>();
284 if (element_size(o_dtype) != retval.itemsize()) {
285 switch (o_dtype) {
286 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
287 retval = fetch[2 + no * 2].cast<py::array_t<float>>();
288 break;
289 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
290 retval = fetch[2 + no * 2].cast<py::array_t<uint8_t>>();
291 break;
292 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
293 retval = fetch[2 + no * 2].cast<py::array_t<int8_t>>();
294 break;
295 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
296 retval = fetch[2 + no * 2].cast<py::array_t<uint16_t>>();
297 break;
298 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
299 retval = fetch[2 + no * 2].cast<py::array_t<int16_t>>();
300 break;
301 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
302 retval = fetch[2 + no * 2].cast<py::array_t<int32_t>>();
303 break;
304 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
305 retval = fetch[2 + no * 2].cast<py::array_t<int64_t>>();
306 break;
307 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
308 retval = fetch[2 + no * 2].cast<py::array_t<bool>>();
309 break;
310 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
311 throw std::runtime_error(MakeString(
312 "Type float16 not supported by python customops api"));
313 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
314 retval = fetch[2 + no * 2].cast<py::array_t<double>>();
315 break;
316 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
317 retval = fetch[2 + no * 2].cast<py::array_t<uint32_t>>();
318 break;
319 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
320 retval = fetch[2 + no * 2].cast<py::array_t<uint64_t>>();
321 break;
322 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
323 retval = fetch[2 + no * 2].cast<py::array_t<std::complex<float>>>();
324 break;
325 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
326 retval = fetch[2 + no * 2].cast<py::array_t<std::complex<double>>>();
327 break;
328 default:
329 throw std::runtime_error(MakeString(
330 "Type mismatch between declared output element size (",
331 element_size(o_dtype), ") and python element size (",
332 retval.itemsize(), ")"));
333 }
334 }
335 size_t size = element_size(o_dtype);
336 memcpy(out, retval.data(), size * retval.size());
337 }
338 }
339
340 py::gil_scoped_release release;
341
342 // TODO: the return value from the python callback function doesn't work in pybind11&numpy.
343 // py::gil_scoped_acquire acquire;
344 // int64_t rid = fetch[0].cast<int64_t>();
345 // assert(rid == obj_id_);
346 // size_t ntp = fetch.size() - 1;
347 // py::object ret_val = fetch[ntp];
348 // auto p2 = PyArray_FROM_O(ret_val.ptr());
349 // PyArrayObject* darray = reinterpret_cast<PyArrayObject*>(p2);
350 // std::vector<int64_t> dims;
351 // const int npy_type = PyArray_TYPE(darray);
352 // {
353 // int ndim = PyArray_NDIM(darray);
354 // const npy_intp* npy_dims = PyArray_DIMS(darray);
355 // dims.resize(ndim);
356 // std::copy(npy_dims, npy_dims+ndim, dims.begin());
357 // }
358
359 // auto element_type = PyCustomOpDefImpl::from_numpy(npy_type);
360 // OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dims.data(), dims.size());
361 // float* out = ort_.GetTensorMutableData<float>(output);
362 // const void* pyOut = PyData(darray, dd);
363
364 // memcpy(out, pyOut, PyArray_NBYTES((PyArrayObject*)NULL));
365 }
366}
367
368std::vector<PyCustomOpFactory>& PyCustomOpDef_python_operator_list() {
369 static std::vector<PyCustomOpFactory> lst_custom_opdef;
370 return lst_custom_opdef;
371}
372
373void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
374 // No need to protect against concurrent access, GIL is doing that.
375 PyCustomOpDef_python_operator_list().push_back(PyCustomOpFactory(cod));
376}
377
378const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count) {
379 // The result must stay alive
380 std::vector<PyCustomOpFactory>& copy = PyCustomOpDef_python_operator_list();
381 if (count < copy.size())
382 return &(copy[count]);
383 return nullptr;
384}
385
386const OrtCustomOp* FetchPyCustomOps(size_t& count) {
387 auto ptr = PyCustomOpDef_FetchPyCustomOps(count);
388 if (ptr == nullptr)
389 return nullptr;
390 return ptr;
391}
392
393// static std::ofstream logger;
394static int init_numpy() {
395 import_array();
396 // logger.open("./ggtest.log.txt", std::ofstream::out | std::ofstream::app);
397 // logger << "first line." << std::endl;
398 return 0;
399}
400
401uint64_t hash_64(const std::string& str, uint64_t num_buckets, bool fast) {
402 if (fast) {
403 return Hash64Fast(str.c_str(), str.size()) % static_cast<uint64_t>(num_buckets);
404 }
405 return Hash64(str.c_str(), str.size()) % static_cast<uint64_t>(num_buckets);
406}
407
408void AddGlobalMethods(pybind11::module& m) {
409 m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); });
410 m.def("hash_64", &hash_64, "Computes a uint64 hash for a string (from tensorflow).");
411}
412
413void AddObjectMethods(pybind11::module& m) {
414 pybind11::class_<PyCustomOpDef>(m, "PyCustomOpDef")
415 .def(pybind11::init<>())
416 .def_readwrite("op_type", &PyCustomOpDef::op_type)
417 .def_readwrite("obj_id", &PyCustomOpDef::obj_id)
418 .def_readwrite("input_types", &PyCustomOpDef::input_types)
419 .def_readwrite("output_types", &PyCustomOpDef::output_types)
420 .def_static("install_hooker", [](py::object obj) {
421 PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
422 .def_readonly_static("undefined", &PyCustomOpDef::undefined)
423 .def_readonly_static("dt_float", &PyCustomOpDef::dt_float)
424 .def_readonly_static("dt_uint8", &PyCustomOpDef::dt_uint8)
425 .def_readonly_static("dt_int8", &PyCustomOpDef::dt_int8)
426 .def_readonly_static("dt_uint16", &PyCustomOpDef::dt_uint16)
427 .def_readonly_static("dt_int16", &PyCustomOpDef::dt_int16)
428 .def_readonly_static("dt_int32", &PyCustomOpDef::dt_int32)
429 .def_readonly_static("dt_int64", &PyCustomOpDef::dt_int64)
430 .def_readonly_static("dt_string", &PyCustomOpDef::dt_string)
431 .def_readonly_static("dt_bool", &PyCustomOpDef::dt_bool)
432 .def_readonly_static("dt_float16", &PyCustomOpDef::dt_float16)
433 .def_readonly_static("dt_double", &PyCustomOpDef::dt_double)
434 .def_readonly_static("dt_uint32", &PyCustomOpDef::dt_uint32)
435 .def_readonly_static("dt_uint64", &PyCustomOpDef::dt_uint64)
436 .def_readonly_static("dt_complex64", &PyCustomOpDef::dt_complex64)
437 .def_readonly_static("dt_complex128", &PyCustomOpDef::dt_complex128)
438 .def_readonly_static("dt_bfloat16", &PyCustomOpDef::dt_bfloat16);
439}
440
441PYBIND11_MODULE(_ortcustomops, m) {
442 m.doc() = "pybind11 stateful interface to ORT Custom Ops library";
443 //TODO: RegisterExceptions(m);
444
445 init_numpy();
446 AddGlobalMethods(m);
447 AddObjectMethods(m);
448}
449