microsoft/qdk

Public

mirrored from https://github.com/microsoft/qdk · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.25.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

source/pip/qsharp/_device/_atom/_decomp.py

511 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4from pyqir import (
5 FloatConstant,
6 const,
7 Function,
8 FunctionType,
9 Type,
10 qubit_type,
11 result_type,
12 result,
13 Context,
14 Linkage,
15 QirModuleVisitor,
16 required_num_results,
17)
18from math import pi
19from ._utils import TOLERANCE
20
21
class DecomposeMultiQubitToCZ(QirModuleVisitor):
    """
    Decomposes all multi-qubit gates to CZ gates and single qubit gates.

    Each handled gate call (CCX, CX, CY, RXX, RYY, RZZ, SWAP) is replaced by
    an equivalent sequence of H, S, S-adjoint, T, T-adjoint, Rz and CZ calls
    inserted directly before it, after which the original call is erased.
    """

    h_func: Function
    s_func: Function
    sadj_func: Function
    t_func: Function
    tadj_func: Function
    rz_func: Function
    cz_func: Function

    def _on_module(self, module):
        """
        Bind the QIS functions emitted by this pass, declaring any that are missing.

        Handles are resolved against ``module`` on every call instead of being
        kept from an earlier run. This means reusing one pass instance on
        several modules never emits calls to functions owned by a previously
        visited module.
        """
        void = Type.void(module.context)
        qubit_ty = qubit_type(module.context)
        self.double_ty = Type.double(module.context)
        # Snapshot the existing declarations before any new ones are added.
        existing = {func.name: func for func in module.functions}

        def find_or_declare(name, param_types):
            # Reuse the module's own declaration when present; otherwise add
            # an external void-returning declaration with the given parameters.
            if name in existing:
                return existing[name]
            return Function(
                FunctionType(void, param_types),
                Linkage.EXTERNAL,
                name,
                module,
            )

        self.h_func = find_or_declare("__quantum__qis__h__body", [qubit_ty])
        self.s_func = find_or_declare("__quantum__qis__s__body", [qubit_ty])
        self.sadj_func = find_or_declare("__quantum__qis__s__adj", [qubit_ty])
        self.t_func = find_or_declare("__quantum__qis__t__body", [qubit_ty])
        self.tadj_func = find_or_declare("__quantum__qis__t__adj", [qubit_ty])
        self.rz_func = find_or_declare(
            "__quantum__qis__rz__body", [self.double_ty, qubit_ty]
        )
        self.cz_func = find_or_declare(
            "__quantum__qis__cz__body", [qubit_ty, qubit_ty]
        )
        super()._on_module(module)

    def _emit_cx(self, ctrl, target):
        """Emit CX(ctrl -> target) as H(target), CZ(ctrl, target), H(target)."""
        self.builder.call(self.h_func, [target])
        self.builder.call(self.cz_func, [ctrl, target])
        self.builder.call(self.h_func, [target])

    def _emit_rxx(self, angle, target1, target2):
        """
        Emit Rxx(angle) on (target1, target2).

        The sequence is (H x H) . CX(target2 -> target1) . Rz(angle, target1)
        . CX(target2 -> target1) . (H x H), with each adjacent pair of
        H(target1) gates cancelled.
        """
        self.builder.call(self.h_func, [target2])
        self.builder.call(self.cz_func, [target2, target1])
        self.builder.call(self.h_func, [target1])
        self.builder.call(self.rz_func, [angle, target1])
        self.builder.call(self.h_func, [target1])
        self.builder.call(self.cz_func, [target2, target1])
        self.builder.call(self.h_func, [target2])

    def _on_qis_ccx(self, call, ctrl1, ctrl2, target):
        # Between the two H(target) gates, the CX/T network multiplies each
        # basis state |a, b, c> (ctrl1, ctrl2, target) by
        #   exp(-i*pi/4 * (a + b + c + a^b^c - a^b - a^c - b^c)) = (-1)^(a*b*c),
        # i.e. it is a CCZ; the H conjugation on the target turns it into CCX.
        self.builder.insert_before(call)
        self.builder.call(self.h_func, [target])
        self.builder.call(self.tadj_func, [ctrl1])
        self.builder.call(self.tadj_func, [ctrl2])
        self._emit_cx(target, ctrl1)
        self.builder.call(self.t_func, [ctrl1])
        self._emit_cx(ctrl2, target)
        self._emit_cx(ctrl2, ctrl1)
        self.builder.call(self.t_func, [target])
        self.builder.call(self.tadj_func, [ctrl1])
        self._emit_cx(ctrl2, target)
        self._emit_cx(target, ctrl1)
        self.builder.call(self.tadj_func, [target])
        self.builder.call(self.t_func, [ctrl1])
        self._emit_cx(ctrl2, ctrl1)
        self.builder.call(self.h_func, [target])
        call.erase()

    def _on_qis_cx(self, call, ctrl, target):
        self.builder.insert_before(call)
        self._emit_cx(ctrl, target)
        call.erase()

    def _on_qis_cy(self, call, ctrl, target):
        # Y = S . X . S-adjoint, so CY = S(target) . CX . S-adjoint(target).
        self.builder.insert_before(call)
        self.builder.call(self.sadj_func, [target])
        self._emit_cx(ctrl, target)
        self.builder.call(self.s_func, [target])
        call.erase()

    def _on_qis_rxx(self, call, angle, target1, target2):
        self.builder.insert_before(call)
        self._emit_rxx(angle, target1, target2)
        call.erase()

    def _on_qis_ryy(self, call, angle, target1, target2):
        # Y = S . X . S-adjoint on each qubit, so
        # Ryy = (S x S) . Rxx . (S-adjoint x S-adjoint).
        self.builder.insert_before(call)
        self.builder.call(self.sadj_func, [target1])
        self.builder.call(self.sadj_func, [target2])
        self._emit_rxx(angle, target1, target2)
        self.builder.call(self.s_func, [target2])
        self.builder.call(self.s_func, [target1])
        call.erase()

    def _on_qis_rzz(self, call, angle, target1, target2):
        # Rzz = CX(target2 -> target1) . Rz(angle, target1) . CX(target2 -> target1)
        self.builder.insert_before(call)
        self._emit_cx(target2, target1)
        self.builder.call(self.rz_func, [angle, target1])
        self._emit_cx(target2, target1)
        call.erase()

    def _on_qis_swap(self, call, target1, target2):
        # SWAP as three alternating CX gates.
        self.builder.insert_before(call)
        self._emit_cx(target1, target2)
        self._emit_cx(target2, target1)
        self._emit_cx(target1, target2)
        call.erase()
203
204
class DecomposeSingleRotationToRz(QirModuleVisitor):
    """
    Decomposes all single qubit rotations to Rz gates.

    Rx and Ry calls are rewritten as an Rz with the same angle, conjugated by
    Clifford gates (H for Rx; S-adjoint/H and H/S for Ry).
    """

    h_func: Function
    s_func: Function
    sadj_func: Function
    rz_func: Function

    def _on_module(self, module):
        """
        Bind the QIS functions emitted by this pass, declaring any that are missing.

        Handles are resolved against ``module`` on every call, so a pass
        instance reused on another module never keeps stale handles that
        belong to a previously visited module.
        """
        void = Type.void(module.context)
        qubit_ty = qubit_type(module.context)
        self.double_ty = Type.double(module.context)
        # Snapshot the existing declarations before any new ones are added.
        existing = {func.name: func for func in module.functions}

        def find_or_declare(name, param_types):
            # Reuse the module's own declaration when present; otherwise add
            # an external void-returning declaration with the given parameters.
            if name in existing:
                return existing[name]
            return Function(
                FunctionType(void, param_types),
                Linkage.EXTERNAL,
                name,
                module,
            )

        self.h_func = find_or_declare("__quantum__qis__h__body", [qubit_ty])
        self.s_func = find_or_declare("__quantum__qis__s__body", [qubit_ty])
        self.sadj_func = find_or_declare("__quantum__qis__s__adj", [qubit_ty])
        self.rz_func = find_or_declare(
            "__quantum__qis__rz__body", [self.double_ty, qubit_ty]
        )
        super()._on_module(module)

    def _on_qis_rx(self, call, angle, target):
        # Rx(angle) = H . Rz(angle) . H
        self.builder.insert_before(call)
        self.builder.call(self.h_func, [target])
        self.builder.call(self.rz_func, [angle, target])
        self.builder.call(self.h_func, [target])
        call.erase()

    def _on_qis_ry(self, call, angle, target):
        # Y = S . X . S-adjoint, so Ry(angle) = S . Rx(angle) . S-adjoint.
        self.builder.insert_before(call)
        self.builder.call(self.sadj_func, [target])
        self.builder.call(self.h_func, [target])
        self.builder.call(self.rz_func, [angle, target])
        self.builder.call(self.h_func, [target])
        self.builder.call(self.s_func, [target])
        call.erase()
281
282
class DecomposeSingleQubitToRzSX(QirModuleVisitor):
    """
    Decomposes all single qubit gates to Rz and Sx gates.

    H, S, S-adjoint, T, T-adjoint, X, Y and Z calls are each replaced by a
    sequence of constant-angle Rz and SX calls that is equivalent up to a
    global phase.
    """

    sx_func: Function
    rz_func: Function

    def _on_module(self, module):
        """
        Bind the QIS functions emitted by this pass, declaring any that are missing.

        Handles are resolved against ``module`` on every call, so a pass
        instance reused on another module never keeps stale handles that
        belong to a previously visited module.
        """
        void = Type.void(module.context)
        qubit_ty = qubit_type(module.context)
        self.double_ty = Type.double(module.context)
        # Snapshot the existing declarations before any new ones are added.
        existing = {func.name: func for func in module.functions}

        def find_or_declare(name, param_types):
            # Reuse the module's own declaration when present; otherwise add
            # an external void-returning declaration with the given parameters.
            if name in existing:
                return existing[name]
            return Function(
                FunctionType(void, param_types),
                Linkage.EXTERNAL,
                name,
                module,
            )

        self.sx_func = find_or_declare("__quantum__qis__sx__body", [qubit_ty])
        self.rz_func = find_or_declare(
            "__quantum__qis__rz__body", [self.double_ty, qubit_ty]
        )
        super()._on_module(module)

    def _emit_rz(self, theta, target):
        """Emit an Rz call on ``target`` with the constant double angle ``theta``."""
        self.builder.call(self.rz_func, [const(self.double_ty, theta), target])

    def _on_qis_h(self, call, target):
        # H = Rz(pi/2) . SX . Rz(pi/2) up to a global phase.
        self.builder.insert_before(call)
        self._emit_rz(pi / 2, target)
        self.builder.call(self.sx_func, [target])
        self._emit_rz(pi / 2, target)
        call.erase()

    def _on_qis_s(self, call, target):
        self.builder.insert_before(call)
        self._emit_rz(pi / 2, target)
        call.erase()

    def _on_qis_s_adj(self, call, target):
        self.builder.insert_before(call)
        self._emit_rz(-pi / 2, target)
        call.erase()

    def _on_qis_t(self, call, target):
        self.builder.insert_before(call)
        self._emit_rz(pi / 4, target)
        call.erase()

    def _on_qis_t_adj(self, call, target):
        self.builder.insert_before(call)
        self._emit_rz(-pi / 4, target)
        call.erase()

    def _on_qis_x(self, call, target):
        # SX is a square root of X, so two SX gates give X.
        self.builder.insert_before(call)
        self.builder.call(self.sx_func, [target])
        self.builder.call(self.sx_func, [target])
        call.erase()

    def _on_qis_y(self, call, target):
        # X (as SX . SX) followed by Rz(pi) gives Y.
        self.builder.insert_before(call)
        self.builder.call(self.sx_func, [target])
        self.builder.call(self.sx_func, [target])
        self._emit_rz(pi, target)
        call.erase()

    def _on_qis_z(self, call, target):
        self.builder.insert_before(call)
        self._emit_rz(pi, target)
        call.erase()
386
387
class DecomposeRzAnglesToCliffordGates(QirModuleVisitor):
    """
    Ensure that the module only contains Clifford gates instead of rotation angles.

    Every Rz call must have a constant angle that is (within ``TOLERANCE``) a
    Clifford angle. Each such call is replaced by S, S-adjoint or Z, or is
    dropped when it is the identity. Any other angle raises ``ValueError``.
    """

    THREE_PI_OVER_2 = 3 * pi / 2
    PI_OVER_2 = pi / 2
    TWO_PI = 2 * pi

    z_func: Function
    s_func: Function
    sadj_func: Function

    def _on_module(self, module):
        """
        Bind the QIS functions emitted by this pass, declaring any that are missing.

        Handles are resolved against ``module`` on every call, so a pass
        instance reused on another module never keeps stale handles that
        belong to a previously visited module.
        """
        void = Type.void(module.context)
        qubit_ty = qubit_type(module.context)
        self.double_ty = Type.double(module.context)
        # Snapshot the existing declarations before any new ones are added.
        existing = {func.name: func for func in module.functions}

        def find_or_declare(name):
            # Reuse the module's own declaration when present; otherwise add
            # an external void(qubit) declaration.
            if name in existing:
                return existing[name]
            return Function(
                FunctionType(void, [qubit_ty]),
                Linkage.EXTERNAL,
                name,
                module,
            )

        self.s_func = find_or_declare("__quantum__qis__s__body")
        self.sadj_func = find_or_declare("__quantum__qis__s__adj")
        self.z_func = find_or_declare("__quantum__qis__z__body")

        super()._on_module(module)

    def _on_qis_rz(self, call, angle, target):
        """
        Replace a constant-angle Rz with the equivalent Clifford gate.

        Angles are matched within ``TOLERANCE``. The equivalences hold up to a
        global phase:
          -pi/2, 3*pi/2  -> S-adjoint
          -pi, pi        -> Z
          pi/2, -3*pi/2  -> S
          0, -2*pi, 2*pi -> identity (the call is simply removed)

        Raises:
            ValueError: if ``angle`` is not a ``FloatConstant``, or is not one
                of the angles listed above.
        """
        if not isinstance(angle, FloatConstant):
            raise ValueError("Angle used in RZ must be a constant")
        angle = angle.value

        def near(*candidates):
            return any(abs(angle - c) < TOLERANCE for c in candidates)

        self.builder.insert_before(call)

        if near(self.THREE_PI_OVER_2, -self.PI_OVER_2):
            self.builder.call(self.sadj_func, [target])
        elif near(pi, -pi):
            self.builder.call(self.z_func, [target])
        elif near(self.PI_OVER_2, -self.THREE_PI_OVER_2):
            self.builder.call(self.s_func, [target])
        elif near(0.0, self.TWO_PI, -self.TWO_PI):
            # I, drop it. Compare |angle| against the tolerance: a one-sided
            # `angle < TOLERANCE` would also accept every negative angle and
            # silently drop non-Clifford rotations such as Rz(-pi/4).
            pass
        else:
            raise ValueError(
                f"Angle {angle} used in RZ is not a Clifford compatible rotation angle"
            )

        call.erase()
471
472
class ReplaceResetWithMResetZ(QirModuleVisitor):
    """
    Replaces all reset operations with a call to mresetz using a new, ignored result identifier.

    Result identifiers are allocated per function, starting at the function's
    ``required_num_results`` (or 0 when it has none).
    """

    context: Context
    mresetz_func: Function
    next_result_id: int

    def _on_module(self, module):
        """
        Bind (or declare) the mresetz intrinsic for ``module``, then visit it.

        The handle is resolved against ``module`` on every call, so a pass
        instance reused on another module never keeps a stale handle that
        belongs to a previously visited module.
        """
        self.context = module.context
        void = Type.void(self.context)
        qubit_ty = qubit_type(self.context)
        result_ty = result_type(self.context)
        name = "__quantum__qis__mresetz__body"
        # Find or create the intrinsic mresetz function (last match wins).
        matches = [func for func in module.functions if func.name == name]
        if matches:
            self.mresetz_func = matches[-1]
        else:
            self.mresetz_func = Function(
                FunctionType(void, [qubit_ty, result_ty]),
                Linkage.EXTERNAL,
                name,
                module,
            )
        super()._on_module(module)

    def _on_function(self, function):
        # Start allocating past the results the function already declares it needs.
        # NOTE(review): nothing here updates the function's required_num_results
        # attribute after extra ids are allocated; confirm downstream consumers
        # tolerate result ids beyond the declared count.
        self.next_result_id = required_num_results(function) or 0
        super()._on_function(function)

    def _on_qis_reset(self, call, target):
        self.builder.insert_before(call)
        # Create a new result identifier to ignore the measurement result
        result_id = result(self.context, self.next_result_id)
        self.next_result_id += 1
        self.builder.call(self.mresetz_func, [target, result_id])
        call.erase()
512