microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b3a300d7bf6f3e99fb907e682f7246ed5ed5805e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

ocos/pyfunc/pyfunc.cc

466 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#include <complex>
#include <fstream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL ocos_python_ARRAY_API
#include <numpy/arrayobject.h>

#include <pybind11/iostream.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>

#include "utils.h"
#include "pykernel.h"
#include "kernels/string_hash.hpp"
#include "kernels/string_common.h"
24
25namespace py = pybind11;
26
27const int PyCustomOpDef::undefined = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
28const int PyCustomOpDef::dt_float = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // maps to c type float
29const int PyCustomOpDef::dt_uint8 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; // maps to c type uint8_t
30const int PyCustomOpDef::dt_int8 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; // maps to c type int8_t
31const int PyCustomOpDef::dt_uint16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; // maps to c type uint16_t
32const int PyCustomOpDef::dt_int16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; // maps to c type int16_t
33const int PyCustomOpDef::dt_int32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; // maps to c type int32_t
34const int PyCustomOpDef::dt_int64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; // maps to c type int64_t
35const int PyCustomOpDef::dt_string = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; // maps to c++ type std::string
36const int PyCustomOpDef::dt_bool = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
37const int PyCustomOpDef::dt_float16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
38const int PyCustomOpDef::dt_double = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; // maps to c type double
39const int PyCustomOpDef::dt_uint32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; // maps to c type uint32_t
40const int PyCustomOpDef::dt_uint64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; // maps to c type uint64_t
41// complex with float32 real and imaginary components
42const int PyCustomOpDef::dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
43// complex with float64 real and imaginary components
44const int PyCustomOpDef::dt_complex128 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
45// Non-IEEE floating-point format based on IEEE754 single-precision
46const int PyCustomOpDef::dt_bfloat16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
47
48static int to_numpy(ONNXTensorElementDataType dt) {
49 switch (dt) {
50 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
51 return NPY_FLOAT;
52 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
53 return NPY_UINT8;
54 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
55 return NPY_INT8;
56 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
57 return NPY_UINT16;
58 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
59 return NPY_INT16;
60 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
61 return NPY_INT32;
62 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
63 return NPY_INT64;
64 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
65 return NPY_BOOL;
66 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
67 return NPY_FLOAT16;
68 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
69 return NPY_DOUBLE;
70 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
71 return NPY_UINT32;
72 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
73 return NPY_UINT64;
74 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
75 return NPY_COMPLEX64;
76 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
77 return NPY_COMPLEX128;
78 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
79 return NPY_OBJECT;
80 default:
81 throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
82 }
83}
84
85static size_t element_size(ONNXTensorElementDataType dt) {
86 switch (dt) {
87 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
88 return sizeof(float);
89 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
90 return sizeof(uint8_t);
91 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
92 return sizeof(int8_t);
93 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
94 return sizeof(uint16_t);
95 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
96 return sizeof(int16_t);
97 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
98 return sizeof(int32_t);
99 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
100 return sizeof(int64_t);
101 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
102 return sizeof(bool);
103 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
104 return sizeof(uint16_t);
105 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
106 return sizeof(double);
107 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
108 return sizeof(uint32_t);
109 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
110 return sizeof(uint64_t);
111 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
112 return sizeof(std::complex<float>);
113 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
114 return sizeof(std::complex<double>);
115 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
116 throw std::runtime_error("OrtValue content cannot be casted into std::string*.");
117 default:
118 throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
119 }
120}
121
122static ONNXTensorElementDataType from_numpy(int dt) {
123 switch (dt) {
124 case NPY_FLOAT:
125 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
126 case NPY_UINT8:
127 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
128 case NPY_INT8:
129 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
130 case NPY_UINT16:
131 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
132 case NPY_INT16:
133 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
134 case NPY_INT32:
135 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
136 case NPY_INT64:
137 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
138 case NPY_BOOL:
139 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
140 case NPY_FLOAT16:
141 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
142 case NPY_DOUBLE:
143 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
144 case NPY_UINT32:
145 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
146 case NPY_UINT64:
147 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
148 case NPY_COMPLEX64:
149 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
150 case NPY_COMPLEX128:
151 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
152 case NPY_OBJECT:
153 case NPY_STRING:
154 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
155 default:
156 throw std::runtime_error("No corresponding ONNX data type/Tensor data Type.");
157 }
158}
159
160struct PyCustomOpDefImpl : public PyCustomOpDef {
161 typedef std::vector<int64_t> shape_t;
162 static int64_t calc_size_from_shape(const shape_t& sp) {
163 size_t c = 1;
164 for (auto it = sp.begin(); it != sp.end(); ++it) {
165 c *= *it;
166 }
167 return c;
168 }
169
170 static py::object BuildPyObjFromTensor(
171 OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
172 const shape_t& shape, ONNXTensorElementDataType dtype) {
173 std::vector<npy_intp> npy_dims;
174 for (auto n : shape) {
175 npy_dims.push_back(n);
176 }
177 const int numpy_type = to_numpy(dtype);
178 py::object obj = py::reinterpret_steal<py::object>(PyArray_SimpleNew(
179 static_cast<int>(shape.size()), npy_dims.data(), numpy_type));
180 void* out_ptr = static_cast<void*>(
181 PyArray_DATA(reinterpret_cast<PyArrayObject*>(obj.ptr())));
182
183 if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
184 py::object* outObj = static_cast<py::object*>(out_ptr);
185 auto size = calc_size_from_shape(shape);
186 std::vector<std::string> src;
187 GetTensorMutableDataString(api, ort, context, value, src);
188 for (int i = 0; i < size; ++i) {
189 outObj[i] = py::cast(src[i]);
190 }
191 } else {
192 const void* p = (const void*)ort.GetTensorData<char>(value);
193 size_t size_type = element_size(dtype);
194 memcpy(out_ptr, p, size_type * calc_size_from_shape(shape));
195 }
196 return obj;
197 }
198
199 static py::object InvokePyFunction(uint64_t id, const py::object& feed, const py::object& attrs) {
200 return (*op_invoker)(id, feed, attrs);
201 }
202
203 using callback_t = std::function<py::object(uint64_t id, const py::object&, const py::object&)>;
204 static std::unique_ptr<callback_t> op_invoker;
205};
206
207std::unique_ptr<PyCustomOpDefImpl::callback_t> PyCustomOpDefImpl::op_invoker;
208// static py::function g_pyfunc_caller;
209// static std::mutex op_mutex;
210// static std::condition_variable op_cv;
211// static bool is_ready = false;
212
213typedef struct {
214 const OrtValue* input_X;
215 ONNXTensorElementDataType dtype;
216 std::vector<int64_t> dimensions;
217} InputInformation;
218
219void PyCustomOpKernel::Compute(OrtKernelContext* context) {
220 // std::unique_lock<std::mutex> lck(op_mutex);
221 // is_ready = true;
222 // op_cv.notify_all();
223 // std::this_thread::sleep_for(std::chrono::milliseconds(5000));
224 size_t n_inputs = ort_.KernelContext_GetInputCount(context);
225 size_t n_outputs = ort_.KernelContext_GetOutputCount(context);
226
227 // Setup inputs
228 std::vector<InputInformation> inputs;
229 inputs.reserve(n_inputs);
230 for (size_t index = 0; index < n_inputs; ++index) {
231 const OrtValue* input_X = ort_.KernelContext_GetInput(context, index);
232 std::vector<int64_t> i_dimensions;
233 OrtTensorTypeAndShapeInfo* i_info = ort_.GetTensorTypeAndShape(input_X);
234 i_dimensions = ort_.GetTensorShape(i_info);
235 ONNXTensorElementDataType i_dtype = ort_.GetTensorElementType(i_info);
236 ort_.ReleaseTensorTypeAndShapeInfo(i_info);
237 inputs.push_back(InputInformation{input_X, i_dtype, i_dimensions});
238 }
239
240 /* Acquire GIL before calling Python code, due to it was released in sess.run */
241 py::gil_scoped_acquire acquire;
242
243 // TODO: Direct-Buffer-Access doesn't work for some reason.
244 // OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
245 // int64_t size = ort_.GetTensorShapeElementCount(output_info);
246 // ort_.ReleaseTensorTypeAndShapeInfo(output_info);
247 // py::buffer_info buf(
248 // const_cast<void *>(X), /* Pointer to buffer */
249 // sizeof(float), /* Size of one scalar */
250 // py::format_descriptor<float>::format(), /* Python struct-style format descriptor */
251 // 2, /* Number of dimensions */
252 // {2, 3}, /* Buffer dimensions */
253 // {sizeof(float) * dimensions.data()[1], /* Strides (in bytes) for each index */
254 // sizeof(float)});
255
256 {
257 py::list pyinputs;
258 for (auto it = inputs.begin(); it != inputs.end(); ++it) {
259 py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(
260 api_, ort_, context, it->input_X, it->dimensions, it->dtype);
261 pyinputs.append(input0);
262 }
263
264 py::dict pyattrs;
265 for (auto it = attrs_values_.begin(); it != attrs_values_.end(); ++it) {
266 pyattrs[py::str(it->first)] = py::str(it->second);
267 }
268
269 // Call python function id, shape, flat coefficient.
270 py::tuple fetch = PyCustomOpDefImpl::InvokePyFunction(obj_id_, pyinputs, pyattrs);
271 int64_t rid = fetch[0].cast<int64_t>();
272 assert(rid == obj_id_);
273
274 // Setup output.
275 for (size_t no = 0; no < n_outputs; ++no) {
276 auto dims = fetch[1 + no * 2].cast<std::vector<int64_t>>();
277 OrtValue* output = ort_.KernelContext_GetOutput(context, no, dims.data(), dims.size());
278 OrtTensorTypeAndShapeInfo* o_info = ort_.GetTensorTypeAndShape(output);
279 ONNXTensorElementDataType o_dtype = ort_.GetTensorElementType(o_info);
280 ort_.ReleaseTensorTypeAndShapeInfo(o_info);
281
282 if (o_dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
283 std::vector<std::string> retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
284 FillTensorDataString(api_, ort_, context, retval, output);
285 } else {
286 const void* Y = (const void*)ort_.GetTensorData<float>(output);
287 void* out = (void*)ort_.GetTensorMutableData<float>(output);
288 py::array retval = fetch[2 + no * 2].cast<py::array>();
289 if (element_size(o_dtype) != retval.itemsize()) {
290 switch (o_dtype) {
291 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
292 retval = fetch[2 + no * 2].cast<py::array_t<float>>();
293 break;
294 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
295 retval = fetch[2 + no * 2].cast<py::array_t<uint8_t>>();
296 break;
297 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
298 retval = fetch[2 + no * 2].cast<py::array_t<int8_t>>();
299 break;
300 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
301 retval = fetch[2 + no * 2].cast<py::array_t<uint16_t>>();
302 break;
303 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
304 retval = fetch[2 + no * 2].cast<py::array_t<int16_t>>();
305 break;
306 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
307 retval = fetch[2 + no * 2].cast<py::array_t<int32_t>>();
308 break;
309 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
310 retval = fetch[2 + no * 2].cast<py::array_t<int64_t>>();
311 break;
312 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
313 retval = fetch[2 + no * 2].cast<py::array_t<bool>>();
314 break;
315 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
316 throw std::runtime_error(MakeString(
317 "Type float16 not supported by python customops api"));
318 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
319 retval = fetch[2 + no * 2].cast<py::array_t<double>>();
320 break;
321 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
322 retval = fetch[2 + no * 2].cast<py::array_t<uint32_t>>();
323 break;
324 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
325 retval = fetch[2 + no * 2].cast<py::array_t<uint64_t>>();
326 break;
327 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
328 retval = fetch[2 + no * 2].cast<py::array_t<std::complex<float>>>();
329 break;
330 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
331 retval = fetch[2 + no * 2].cast<py::array_t<std::complex<double>>>();
332 break;
333 default:
334 throw std::runtime_error(MakeString(
335 "Type mismatch between declared output element size (",
336 element_size(o_dtype), ") and python element size (",
337 retval.itemsize(), ")"));
338 }
339 }
340 size_t size = element_size(o_dtype);
341 memcpy(out, retval.data(), size * retval.size());
342 }
343 }
344
345 py::gil_scoped_release release;
346
347 // TODO: the return value from the python callback function doesn't work in pybind11&numpy.
348 // py::gil_scoped_acquire acquire;
349 // int64_t rid = fetch[0].cast<int64_t>();
350 // assert(rid == obj_id_);
351 // size_t ntp = fetch.size() - 1;
352 // py::object ret_val = fetch[ntp];
353 // auto p2 = PyArray_FROM_O(ret_val.ptr());
354 // PyArrayObject* darray = reinterpret_cast<PyArrayObject*>(p2);
355 // std::vector<int64_t> dims;
356 // const int npy_type = PyArray_TYPE(darray);
357 // {
358 // int ndim = PyArray_NDIM(darray);
359 // const npy_intp* npy_dims = PyArray_DIMS(darray);
360 // dims.resize(ndim);
361 // std::copy(npy_dims, npy_dims+ndim, dims.begin());
362 // }
363
364 // auto element_type = PyCustomOpDefImpl::from_numpy(npy_type);
365 // OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dims.data(), dims.size());
366 // float* out = ort_.GetTensorMutableData<float>(output);
367 // const void* pyOut = PyData(darray, dd);
368
369 // memcpy(out, pyOut, PyArray_NBYTES((PyArrayObject*)NULL));
370 }
371}
372
373std::vector<PyCustomOpFactory>& PyCustomOpDef_python_operator_list() {
374 static std::vector<PyCustomOpFactory> lst_custom_opdef;
375 return lst_custom_opdef;
376}
377
378void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
379 // No need to protect against concurrent access, GIL is doing that.
380 PyCustomOpDef_python_operator_list().push_back(PyCustomOpFactory(cod));
381}
382
383const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count) {
384 if (!EnablePyCustomOps(true)) {
385 EnablePyCustomOps(false);
386 return nullptr;
387 }
388
389 // The result must stay alive
390 std::vector<PyCustomOpFactory>& copy = PyCustomOpDef_python_operator_list();
391 if (count < copy.size())
392 return &(copy[count]);
393 return nullptr;
394}
395
396const OrtCustomOp* FetchPyCustomOps(size_t& count) {
397 auto ptr = PyCustomOpDef_FetchPyCustomOps(count);
398 if (ptr == nullptr)
399 return nullptr;
400 return ptr;
401}
402
// Sets whether Python custom ops are enabled and returns the setting that was in effect before
// the call. The flag starts out enabled and is a plain static bool (no synchronization).
bool EnablePyCustomOps(bool enabled) {
  static bool s_enabled = true;
  const bool previous = s_enabled;
  s_enabled = enabled;
  return previous;
}
409
410// static std::ofstream logger;
411static int init_numpy() {
412 import_array();
413 // logger.open("./ggtest.log.txt", std::ofstream::out | std::ofstream::app);
414 // logger << "first line." << std::endl;
415 return 0;
416}
417
418uint64_t hash_64(const std::string& str, uint64_t num_buckets, bool fast) {
419 if (fast) {
420 return Hash64Fast(str.c_str(), str.size()) % static_cast<uint64_t>(num_buckets);
421 }
422 return Hash64(str.c_str(), str.size()) % static_cast<uint64_t>(num_buckets);
423}
424
425void AddGlobalMethods(pybind11::module& m) {
426 m.def("enable_custom_op", &EnablePyCustomOps, "Enable or disable pyop functions.");
427 m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); });
428 m.def("hash_64", &hash_64, "Computes a uint64 hash for a string (from tensorflow).");
429}
430
431void AddObjectMethods(pybind11::module& m) {
432 pybind11::class_<PyCustomOpDef>(m, "PyCustomOpDef")
433 .def(pybind11::init<>())
434 .def_readwrite("op_type", &PyCustomOpDef::op_type)
435 .def_readwrite("obj_id", &PyCustomOpDef::obj_id)
436 .def_readwrite("input_types", &PyCustomOpDef::input_types)
437 .def_readwrite("output_types", &PyCustomOpDef::output_types)
438 .def_readwrite("attrs", &PyCustomOpDef::attrs)
439 .def_static("install_hooker", [](py::object obj) { PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
440 .def_readonly_static("undefined", &PyCustomOpDef::undefined)
441 .def_readonly_static("dt_float", &PyCustomOpDef::dt_float)
442 .def_readonly_static("dt_uint8", &PyCustomOpDef::dt_uint8)
443 .def_readonly_static("dt_int8", &PyCustomOpDef::dt_int8)
444 .def_readonly_static("dt_uint16", &PyCustomOpDef::dt_uint16)
445 .def_readonly_static("dt_int16", &PyCustomOpDef::dt_int16)
446 .def_readonly_static("dt_int32", &PyCustomOpDef::dt_int32)
447 .def_readonly_static("dt_int64", &PyCustomOpDef::dt_int64)
448 .def_readonly_static("dt_string", &PyCustomOpDef::dt_string)
449 .def_readonly_static("dt_bool", &PyCustomOpDef::dt_bool)
450 .def_readonly_static("dt_float16", &PyCustomOpDef::dt_float16)
451 .def_readonly_static("dt_double", &PyCustomOpDef::dt_double)
452 .def_readonly_static("dt_uint32", &PyCustomOpDef::dt_uint32)
453 .def_readonly_static("dt_uint64", &PyCustomOpDef::dt_uint64)
454 .def_readonly_static("dt_complex64", &PyCustomOpDef::dt_complex64)
455 .def_readonly_static("dt_complex128", &PyCustomOpDef::dt_complex128)
456 .def_readonly_static("dt_bfloat16", &PyCustomOpDef::dt_bfloat16);
457}
458
459PYBIND11_MODULE(_ortcustomops, m) {
460 m.doc() = "pybind11 stateful interface to ORT Custom Ops library";
461 //TODO: RegisterExceptions(m);
462
463 init_numpy();
464 AddGlobalMethods(m);
465 AddObjectMethods(m);
466}
467