microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
f74770feed077546874ed7e66d1aba9e2509fea9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/op_equal_impl.hpp

177 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3#pragma once
4#include <vector>
5#include <string>
6#include "string_utils.h"
7#include "string_tensor.h"
8
9template <typename T1, typename T2, typename T3>
10class BroadcastIteratorRight {
11 public:
12 BroadcastIteratorRight(const std::vector<int64_t>& shape1,
13 const std::vector<int64_t>& shape2,
14 const T1* p1, const T2* p2, T3* p3) : p1_(p1), p2_(p2), p3_(p3), shape1_(shape1) {
15 if (shape2.size() > shape1.size())
16 throw std::runtime_error("shape2 must have less dimensions than shape1");
17 shape2_.resize(shape1_.size());
18 cum_shape2_.resize(shape1_.size());
19 total_ = 1;
20 for (size_t i = 0; i < shape1_.size(); ++i) {
21 total_ *= shape1[i];
22 if (i >= shape2.size()) {
23 shape2_[i] = 1;
24 continue;
25 } else {
26 shape2_[i] = shape2[i];
27 }
28 if (shape2[i] != 1 && shape1[i] != shape2[i]) {
29 throw std::runtime_error(MakeString(
30 "Cannot broadcast dimension ", i, " left:", shape1[i], " right:", shape2[i]));
31 }
32 }
33 cum_shape2_[shape2_.size() - 1] = 1;
34 for (size_t i = 1; i < shape1_.size(); ++i) {
35 cum_shape2_[shape1_.size() - i - 1] = cum_shape2_[shape1_.size() - i] * shape2_[shape1_.size() - i];
36 }
37 }
38
39 struct BroadcastIteratorRightState {
40 const BroadcastIteratorRight<T1, T2, T3>* parent;
41 std::vector<int64_t> index1;
42 const T1* p1;
43 const T1* end_;
44 const T2* p2;
45 T3* p3;
46 size_t last;
47 int dim;
48
49 void init(const BroadcastIteratorRight<T1, T2, T3>& p) {
50 parent = &p;
51 p1 = p.p1_;
52 p2 = p.p2_;
53 p3 = p.p3_;
54 end_ = p.p1_ + p.total_;
55 index1.resize(p.shape1_.size(), 0);
56 last = index1.size() - 1;
57 }
58
59 bool end() {
60 return p1 == end_;
61 }
62
63 void next() {
64 ++index1[last];
65 ++p1;
66 ++p3;
67 if (parent->shape2_[last] != 1) {
68 ++p2;
69 }
70 dim = static_cast<int>(last);
71 while (dim > 0 && index1[dim] >= parent->shape1_[dim]) {
72 index1[dim] = 0;
73 if (parent->shape2_[dim] != 1) {
74 p2 -= parent->cum_shape2_[dim] * parent->shape2_[dim];
75 }
76 --dim;
77 ++index1[dim];
78 if (parent->shape2_[dim] != 1) {
79 p2 += parent->cum_shape2_[dim];
80 }
81 }
82 }
83
84 template <typename TCMP>
85 void loop(TCMP& cmp, BroadcastIteratorRightState& it, int64_t pos = 0) {
86 if (pos != 0)
87 throw std::runtime_error("Not implemented yet.");
88 while (!end()) {
89 *p3 = cmp(*p1, *p2);
90 next();
91 }
92 }
93 };
94
95 protected:
96 std::vector<int64_t> shape1_;
97 std::vector<int64_t> shape2_;
98 std::vector<int64_t> cum_shape2_;
99 int64_t total_;
100 const T1* p1_;
101 const T2* p2_;
102 T3* p3_;
103};
104
// Element-wise equality functor consumed by KernelEqual_Compute.
// Only the call operator's declaration is generic: a definition has to be
// supplied for each element type used (this header defines std::string's).
template <class T>
class Compare {
 public:
  // Returns true when `lhs` and `rhs` are considered equal.
  inline bool operator()(const T& lhs, const T& rhs) const;
};

// Two strings match when they have the same length and the same characters.
template <>
inline bool Compare<std::string>::operator()(const std::string& lhs, const std::string& rhs) const {
  return lhs == rhs;
}
115
116template <typename T>
117void KernelEqual_Compute(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelContext* context) {
118 // Setup inputs
119 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
120 const T* X = ort_.GetTensorData<T>(input_X);
121 const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
122 const T* Y = ort_.GetTensorData<T>(input_Y);
123
124 // Setup output
125 OrtTensorDimensions dimensions_x(ort_, input_X);
126 OrtTensorDimensions dimensions_y(ort_, input_Y);
127 Compare<T> cmp;
128
129 typename BroadcastIteratorRight<T, T, bool>::BroadcastIteratorRightState state;
130 if (dimensions_x.Size() >= dimensions_y.Size()) {
131 OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_x.data(), dimensions_x.size());
132 bool* out = ort_.GetTensorMutableData<bool>(v);
133 BroadcastIteratorRight<T, T, bool> iter(dimensions_x, dimensions_y, X, Y, out);
134 state.init(iter);
135 state.loop(cmp, state);
136 } else {
137 // Operator Equal is commutative.
138 OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_y.data(), dimensions_y.size());
139 bool* out = ort_.GetTensorMutableData<bool>(v);
140 BroadcastIteratorRight<T, T, bool> iter(dimensions_y, dimensions_x, Y, X, out);
141 state.init(iter);
142 state.loop(cmp, state);
143 }
144}
145
146template <>
147void KernelEqual_Compute<std::string>(const OrtApi& api, Ort::CustomOpApi& ort_, OrtKernelContext* context) {
148 // Setup inputs
149 const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0);
150 const OrtValue* input_Y = ort_.KernelContext_GetInput(context, 1);
151 std::vector<std::string> X, Y;
152 GetTensorMutableDataString(api, ort_, context, input_X, X);
153 GetTensorMutableDataString(api, ort_, context, input_Y, Y);
154
155 // Setup output
156 OrtTensorDimensions dimensions_x(ort_, input_X);
157 OrtTensorDimensions dimensions_y(ort_, input_Y);
158 Compare<std::string> cmp;
159
160 typename BroadcastIteratorRight<std::string, std::string, bool>::BroadcastIteratorRightState state;
161 if (dimensions_x.Size() >= dimensions_y.Size()) {
162 OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_x.data(), dimensions_x.size());
163 bool* out = ort_.GetTensorMutableData<bool>(v);
164 BroadcastIteratorRight<std::string, std::string, bool> iter(
165 dimensions_x, dimensions_y, X.data(), Y.data(), out);
166 state.init(iter);
167 state.loop(cmp, state);
168 } else {
169 // Operator Equal is commutative.
170 OrtValue* v = ort_.KernelContext_GetOutput(context, 0, dimensions_y.data(), dimensions_y.size());
171 bool* out = ort_.GetTensorMutableData<bool>(v);
172 BroadcastIteratorRight<std::string, std::string, bool> iter(
173 dimensions_y, dimensions_x, Y.data(), X.data(), out);
174 state.init(iter);
175 state.loop(cmp, state);
176 }
177}
178