microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
9eef22cb81d762f6c093a4740c992582267a783f

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

pyop/pyfunc.cc

495 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include <fstream>
5#include <mutex>
6#include <complex>
7#include <memory>
8
9#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
10#define PY_ARRAY_UNIQUE_SYMBOL ocos_python_ARRAY_API
11#include <numpy/arrayobject.h>
12
13#include <pybind11/iostream.h>
14#include <pybind11/pybind11.h>
15#include <pybind11/stl.h>
16#include <pybind11/functional.h>
17#include <pybind11/numpy.h>
18#include <thread>
19#include "string_utils.h"
20#include "string_tensor.h"
21#include "pykernel.h"
22
23
24namespace py = pybind11;
25
26const int PyCustomOpDef::undefined = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
27const int PyCustomOpDef::dt_float = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; // maps to c type float
28const int PyCustomOpDef::dt_uint8 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; // maps to c type uint8_t
29const int PyCustomOpDef::dt_int8 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; // maps to c type int8_t
30const int PyCustomOpDef::dt_uint16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; // maps to c type uint16_t
31const int PyCustomOpDef::dt_int16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; // maps to c type int16_t
32const int PyCustomOpDef::dt_int32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; // maps to c type int32_t
33const int PyCustomOpDef::dt_int64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; // maps to c type int64_t
34const int PyCustomOpDef::dt_string = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; // maps to c++ type std::string
35const int PyCustomOpDef::dt_bool = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
36const int PyCustomOpDef::dt_float16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
37const int PyCustomOpDef::dt_double = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; // maps to c type double
38const int PyCustomOpDef::dt_uint32 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; // maps to c type uint32_t
39const int PyCustomOpDef::dt_uint64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; // maps to c type uint64_t
40// complex with float32 real and imaginary components
41const int PyCustomOpDef::dt_complex64 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
42// complex with float64 real and imaginary components
43const int PyCustomOpDef::dt_complex128 = ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
44// Non-IEEE floating-point format based on IEEE754 single-precision
45const int PyCustomOpDef::dt_bfloat16 = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
46
47static int to_numpy(ONNXTensorElementDataType dt) {
48 switch (dt) {
49 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
50 return NPY_FLOAT;
51 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
52 return NPY_UINT8;
53 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
54 return NPY_INT8;
55 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
56 return NPY_UINT16;
57 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
58 return NPY_INT16;
59 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
60 return NPY_INT32;
61 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
62 return NPY_INT64;
63 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
64 return NPY_BOOL;
65 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
66 return NPY_FLOAT16;
67 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
68 return NPY_DOUBLE;
69 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
70 return NPY_UINT32;
71 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
72 return NPY_UINT64;
73 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
74 return NPY_COMPLEX64;
75 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
76 return NPY_COMPLEX128;
77 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
78 return NPY_OBJECT;
79 default:
80 throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
81 }
82}
83
84static size_t element_size(ONNXTensorElementDataType dt) {
85 switch (dt) {
86 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
87 return sizeof(float);
88 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
89 return sizeof(uint8_t);
90 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
91 return sizeof(int8_t);
92 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
93 return sizeof(uint16_t);
94 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
95 return sizeof(int16_t);
96 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
97 return sizeof(int32_t);
98 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
99 return sizeof(int64_t);
100 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
101 return sizeof(bool);
102 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
103 return sizeof(uint16_t);
104 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
105 return sizeof(double);
106 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
107 return sizeof(uint32_t);
108 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
109 return sizeof(uint64_t);
110 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
111 return sizeof(std::complex<float>);
112 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
113 return sizeof(std::complex<double>);
114 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
115 throw std::runtime_error("OrtValue content cannot be casted into std::string*.");
116 default:
117 throw std::runtime_error("No corresponding Numpy data type/Tensor data Type.");
118 }
119}
120
121static ONNXTensorElementDataType from_numpy(int dt) {
122 switch (dt) {
123 case NPY_FLOAT:
124 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
125 case NPY_UINT8:
126 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
127 case NPY_INT8:
128 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
129 case NPY_UINT16:
130 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
131 case NPY_INT16:
132 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
133 case NPY_INT32:
134 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
135 case NPY_INT64:
136 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
137 case NPY_BOOL:
138 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
139 case NPY_FLOAT16:
140 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
141 case NPY_DOUBLE:
142 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
143 case NPY_UINT32:
144 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
145 case NPY_UINT64:
146 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
147 case NPY_COMPLEX64:
148 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64;
149 case NPY_COMPLEX128:
150 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128;
151 case NPY_OBJECT:
152 case NPY_STRING:
153 return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
154 default:
155 throw std::runtime_error("No corresponding ONNX data type/Tensor data Type.");
156 }
157}
158
159struct PyCustomOpDefImpl : public PyCustomOpDef {
160 typedef std::vector<int64_t> shape_t;
161 static int64_t calc_size_from_shape(const shape_t& sp) {
162 size_t c = 1;
163 for (auto it = sp.begin(); it != sp.end(); ++it) {
164 c *= *it;
165 }
166 return c;
167 }
168
169 static py::object BuildPyObjFromTensor(
170 OrtApi& api, Ort::CustomOpApi& ort, OrtKernelContext* context, const OrtValue* value,
171 const shape_t& shape, ONNXTensorElementDataType dtype) {
172 std::vector<npy_intp> npy_dims;
173 for (auto n : shape) {
174 npy_dims.push_back(n);
175 }
176 const int numpy_type = to_numpy(dtype);
177 py::object obj = py::reinterpret_steal<py::object>(PyArray_SimpleNew(
178 static_cast<int>(shape.size()), npy_dims.data(), numpy_type));
179 void* out_ptr = static_cast<void*>(
180 PyArray_DATA(reinterpret_cast<PyArrayObject*>(obj.ptr())));
181
182 if (dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
183 py::object* outObj = static_cast<py::object*>(out_ptr);
184 auto size = calc_size_from_shape(shape);
185 std::vector<std::string> src;
186 GetTensorMutableDataString(api, ort, context, value, src);
187 for (int i = 0; i < size; ++i) {
188 outObj[i] = py::cast(src[i]);
189 }
190 } else {
191 const void* p = (const void*)ort.GetTensorData<char>(value);
192 size_t size_type = element_size(dtype);
193 memcpy(out_ptr, p, size_type * calc_size_from_shape(shape));
194 }
195 return obj;
196 }
197
198 static py::object InvokePyFunction(uint64_t id, const py::object& feed, const py::object& attrs) {
199 return (*op_invoker)(id, feed, attrs);
200 }
201
202 using callback_t = std::function<py::object(uint64_t id, const py::object&, const py::object&)>;
203 static std::unique_ptr<callback_t> op_invoker;
204};
205
206std::unique_ptr<PyCustomOpDefImpl::callback_t> PyCustomOpDefImpl::op_invoker;
207typedef struct {
208 const OrtValue* input_X;
209 ONNXTensorElementDataType dtype;
210 std::vector<int64_t> dimensions;
211} InputInformation;
212
213PyCustomOpKernel::PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info,
214 uint64_t id, const std::vector<std::string>& attrs)
215 : api_(api),
216 ort_(api_),
217 obj_id_(id) {
218 size_t size;
219 for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
220 size = 0;
221 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
222 if ((status != nullptr) && api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
223 std::string error_message(api_.GetErrorMessage(status));
224 api_.ReleaseStatus(status);
225 throw std::runtime_error(MakeString(
226 "Unable to find attribute '", *it, "' due to '",
227 error_message, "'."));
228 }
229 if (status != nullptr) {
230 api_.ReleaseStatus(status);
231 }
232 attrs_values_[*it] = "";
233 attrs_values_[*it].resize(size);
234 status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
235 if ((status != nullptr) && (api_.GetErrorCode(status) != ORT_OK)) {
236 api_.ReleaseStatus(status);
237 throw std::runtime_error(MakeString(
238 "Unable to retrieve attribute '", *it, "' due to '",
239 api_.GetErrorMessage(status), "'."));
240 }
241 attrs_values_[*it].resize(size - 1);
242 if (status != nullptr) {
243 api_.ReleaseStatus(status);
244 }
245 }
246}
247
248void PyCustomOpKernel::Compute(OrtKernelContext* context) {
249 size_t n_inputs = ort_.KernelContext_GetInputCount(context);
250 size_t n_outputs = ort_.KernelContext_GetOutputCount(context);
251
252 // Setup inputs
253 std::vector<InputInformation> inputs;
254 inputs.reserve(n_inputs);
255 for (size_t index = 0; index < n_inputs; ++index) {
256 const OrtValue* input_X = ort_.KernelContext_GetInput(context, index);
257 std::vector<int64_t> i_dimensions;
258 OrtTensorTypeAndShapeInfo* i_info = ort_.GetTensorTypeAndShape(input_X);
259 i_dimensions = ort_.GetTensorShape(i_info);
260 ONNXTensorElementDataType i_dtype = ort_.GetTensorElementType(i_info);
261 ort_.ReleaseTensorTypeAndShapeInfo(i_info);
262 inputs.push_back(InputInformation{input_X, i_dtype, i_dimensions});
263 }
264
265 /* Acquire GIL before calling Python code, due to it was released in sess.run */
266 py::gil_scoped_acquire acquire;
267
268 {
269 py::list pyinputs;
270 for (auto it = inputs.begin(); it != inputs.end(); ++it) {
271 py::object input0 = PyCustomOpDefImpl::BuildPyObjFromTensor(
272 api_, ort_, context, it->input_X, it->dimensions, it->dtype);
273 pyinputs.append(input0);
274 }
275
276 py::dict pyattrs;
277 for (auto it = attrs_values_.begin(); it != attrs_values_.end(); ++it) {
278 pyattrs[py::str(it->first)] = py::str(it->second);
279 }
280
281 // Call python function id, shape, flat coefficient.
282 py::tuple fetch = PyCustomOpDefImpl::InvokePyFunction(obj_id_, pyinputs, pyattrs);
283 int64_t rid = fetch[0].cast<int64_t>();
284 assert(rid == obj_id_);
285
286 // Setup output.
287 for (size_t no = 0; no < n_outputs; ++no) {
288 auto dims = fetch[1 + no * 2].cast<std::vector<int64_t>>();
289 OrtValue* output = ort_.KernelContext_GetOutput(context, no, dims.data(), dims.size());
290 OrtTensorTypeAndShapeInfo* o_info = ort_.GetTensorTypeAndShape(output);
291 ONNXTensorElementDataType o_dtype = ort_.GetTensorElementType(o_info);
292 ort_.ReleaseTensorTypeAndShapeInfo(o_info);
293
294 if (o_dtype == ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
295 std::vector<std::string> retval = fetch[2 + no * 2].cast<std::vector<std::string>>();
296 FillTensorDataString(api_, ort_, context, retval, output);
297 } else {
298 const void* Y = (const void*)ort_.GetTensorData<float>(output);
299 void* out = (void*)ort_.GetTensorMutableData<float>(output);
300 py::array retval = fetch[2 + no * 2].cast<py::array>();
301 if (element_size(o_dtype) != retval.itemsize()) {
302 switch (o_dtype) {
303 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
304 retval = fetch[2 + no * 2].cast<py::array_t<float>>();
305 break;
306 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
307 retval = fetch[2 + no * 2].cast<py::array_t<uint8_t>>();
308 break;
309 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
310 retval = fetch[2 + no * 2].cast<py::array_t<int8_t>>();
311 break;
312 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
313 retval = fetch[2 + no * 2].cast<py::array_t<uint16_t>>();
314 break;
315 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
316 retval = fetch[2 + no * 2].cast<py::array_t<int16_t>>();
317 break;
318 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
319 retval = fetch[2 + no * 2].cast<py::array_t<int32_t>>();
320 break;
321 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
322 retval = fetch[2 + no * 2].cast<py::array_t<int64_t>>();
323 break;
324 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
325 retval = fetch[2 + no * 2].cast<py::array_t<bool>>();
326 break;
327 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
328 throw std::runtime_error(MakeString(
329 "Type float16 not supported by python customops api"));
330 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
331 retval = fetch[2 + no * 2].cast<py::array_t<double>>();
332 break;
333 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
334 retval = fetch[2 + no * 2].cast<py::array_t<uint32_t>>();
335 break;
336 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
337 retval = fetch[2 + no * 2].cast<py::array_t<uint64_t>>();
338 break;
339 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64:
340 retval = fetch[2 + no * 2].cast<py::array_t<std::complex<float>>>();
341 break;
342 case ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128:
343 retval = fetch[2 + no * 2].cast<py::array_t<std::complex<double>>>();
344 break;
345 default:
346 throw std::runtime_error(MakeString(
347 "Type mismatch between declared output element size (",
348 element_size(o_dtype), ") and python element size (",
349 retval.itemsize(), ")"));
350 }
351 }
352 size_t size = element_size(o_dtype);
353 memcpy(out, retval.data(), size * retval.size());
354 }
355 }
356
357 py::gil_scoped_release release;
358 }
359}
360
361
362std::map<std::string, std::vector<PyCustomOpFactory>>& PyOp_container() {
363 static std::map<std::string, std::vector<PyCustomOpFactory>> map_custom_opdef;
364 return map_custom_opdef;
365}
366
367void PyCustomOpDef::AddOp(const PyCustomOpDef* cod) {
368 // try to fetch the domain name from op_type firstly.
369 std::string op_domain = c_OpDomain;
370 std::string op = cod->op_type;
371 auto dm_pos = cod->op_type.find("::");
372 if (std::string::npos != dm_pos) {
373 op_domain = cod->op_type.substr(0, dm_pos);
374 op = cod->op_type.substr(dm_pos + 2, -1);
375 }
376
377 // No need to protect against concurrent access, GIL is doing that.
378 auto val = std::make_pair(op_domain, std::vector<PyCustomOpFactory>());
379 const auto [it_domain_op, success] = PyOp_container().insert(val);
380 assert(success || it_domain_op->second.length() > 0);
381 it_domain_op->second.emplace_back(PyCustomOpFactory(cod, op_domain, op));
382}
383
384const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t num) {
385 if (!EnablePyCustomOps(true)) {
386 EnablePyCustomOps(false);
387 return nullptr;
388 }
389
390 auto it = PyOp_container().find(c_OpDomain);
391 if (it != PyOp_container().end()) {
392 const std::vector<PyCustomOpFactory>& ref = it->second;
393 if (num < ref.size()) {
394 return ref.data() + num; }
395 }
396
397 return nullptr;
398}
399
400const OrtCustomOp* FetchPyCustomOps(size_t& num) {
401 auto ptr = PyCustomOpDef_FetchPyCustomOps(num);
402 if (ptr == nullptr) // For the breakpoint in debugging.
403 return nullptr;
404 return ptr;
405}
406
// Sets whether Python custom ops are enabled and returns the setting it replaced.
// Python ops start out enabled.
bool EnablePyCustomOps(bool enabled) {
  static bool s_enabled = true;
  const bool previous = s_enabled;
  s_enabled = enabled;
  return previous;
}
413
414OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions* options, const OrtApi* ortApi){
415 OrtCustomOpDomain* domain = nullptr;
416 OrtStatus* status = nullptr;
417
418 for (auto const& val_pair: PyOp_container()) {
419 if (val_pair.first == c_OpDomain) {
420 continue; // Register this domain in the second iteration.
421 }
422
423 if (status = ortApi->CreateCustomOpDomain(val_pair.first.c_str(), &domain)) {
424 return status;
425 }
426
427 for (auto const& cop: val_pair.second) {
428 if (status = ortApi->CustomOpDomain_Add(domain, &cop)) {
429 return status;
430 }
431 }
432
433 if (status = ortApi->AddCustomOpDomain(options, domain)) {
434 return status;
435 }
436 }
437
438 return status;
439}
440
441static int init_numpy() {
442 import_array1(0);
443 return 0;
444}
445
446uint64_t hash_64(const std::string& str, uint64_t num_buckets, bool fast) {
447 if (fast) {
448 return Hash64Fast(str.c_str(), str.size()) % static_cast<uint64_t>(num_buckets);
449 }
450 return Hash64(str.c_str(), str.size()) % static_cast<uint64_t>(num_buckets);
451}
452
453void AddGlobalMethods(pybind11::module& m) {
454 m.def("hash_64", &hash_64, "Computes a uint64 hash for a string (from tensorflow).");
455 m.def("enable_py_op", &EnablePyCustomOps, "Enable or disable pyop functions.");
456 m.def("add_custom_op", [](const PyCustomOpDef& cod) { PyCustomOpDef::AddOp(&cod); }, "Add a PyOp Python object.");
457 m.def("default_opset_domain", []{return std::string(c_OpDomain);}, "return the default opset domain name.");
458}
459
460void AddObjectMethods(pybind11::module& m) {
461 pybind11::class_<PyCustomOpDef>(m, "PyCustomOpDef")
462 .def(pybind11::init<>())
463 .def_readwrite("op_type", &PyCustomOpDef::op_type)
464 .def_readwrite("obj_id", &PyCustomOpDef::obj_id)
465 .def_readwrite("input_types", &PyCustomOpDef::input_types)
466 .def_readwrite("output_types", &PyCustomOpDef::output_types)
467 .def_readwrite("attrs", &PyCustomOpDef::attrs)
468 .def_static("install_hooker", [](py::object obj) {
469 PyCustomOpDefImpl::op_invoker = std::make_unique<PyCustomOpDefImpl::callback_t>(obj); })
470 .def_readonly_static("undefined", &PyCustomOpDef::undefined)
471 .def_readonly_static("dt_float", &PyCustomOpDef::dt_float)
472 .def_readonly_static("dt_uint8", &PyCustomOpDef::dt_uint8)
473 .def_readonly_static("dt_int8", &PyCustomOpDef::dt_int8)
474 .def_readonly_static("dt_uint16", &PyCustomOpDef::dt_uint16)
475 .def_readonly_static("dt_int16", &PyCustomOpDef::dt_int16)
476 .def_readonly_static("dt_int32", &PyCustomOpDef::dt_int32)
477 .def_readonly_static("dt_int64", &PyCustomOpDef::dt_int64)
478 .def_readonly_static("dt_string", &PyCustomOpDef::dt_string)
479 .def_readonly_static("dt_bool", &PyCustomOpDef::dt_bool)
480 .def_readonly_static("dt_float16", &PyCustomOpDef::dt_float16)
481 .def_readonly_static("dt_double", &PyCustomOpDef::dt_double)
482 .def_readonly_static("dt_uint32", &PyCustomOpDef::dt_uint32)
483 .def_readonly_static("dt_uint64", &PyCustomOpDef::dt_uint64)
484 .def_readonly_static("dt_complex64", &PyCustomOpDef::dt_complex64)
485 .def_readonly_static("dt_complex128", &PyCustomOpDef::dt_complex128)
486 .def_readonly_static("dt_bfloat16", &PyCustomOpDef::dt_bfloat16);
487}
488
489PYBIND11_MODULE(_ortcustomops, m) {
490 m.doc() = "pybind11 stateful interface to ONNXRuntime-Extensions";
491
492 init_numpy();
493 AddGlobalMethods(m);
494 AddObjectMethods(m);
495}
496