microsoft/qdk

Public

mirrored from https://github.com/microsoft/qdk — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.22.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

source/pip/tests/test_generic_estimator.py

265 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4import pytest
5import qsharp
6
7
class SampleAlgorithm:
    """Minimal algorithm model passed to ``qsharp.estimate_custom``.

    Provides the three methods the tests expect the estimator to call:
    ``logical_qubits``, ``logical_depth`` and ``num_magic_states``. Each one
    simply returns a value fixed at construction time.
    """

    def __init__(self, *, qubits=10, depth=20, magic_states=None):
        """Create a sample algorithm.

        Args:
            qubits: Value returned by ``logical_qubits``.
            depth: Value returned by ``logical_depth``.
            magic_states: Sequence indexed by ``num_magic_states``, one entry
                per magic-state index. Defaults to ``[100]``.
        """
        self.qubits = qubits
        self.depth = depth
        # Build the default per instance instead of using a list default
        # argument, which would be a single object shared by all instances.
        self.magic_states = [100] if magic_states is None else magic_states

    def logical_qubits(self):
        """Return the configured number of logical qubits."""
        return self.qubits

    def logical_depth(self, budget):
        """Return the configured logical depth; ``budget`` is ignored."""
        return self.depth

    def num_magic_states(self, budget, index):
        """Return the magic-state count at ``index``; ``budget`` is ignored."""
        return self.magic_states[index]
22
23
class SampleCode:
    """Toy error-correcting-code model used by the estimator tests.

    Every method is a closed-form function of the code parameter ``param``
    and, where needed, of the physical qubit dict (keys ``gate_time`` and
    ``error_rate``).
    """

    def physical_qubits(self, param):
        """Physical qubits per code block: twice the parameter squared."""
        squared = param**2
        return 2 * squared

    def logical_qubits(self, param):
        """Every code block encodes a single logical qubit."""
        return 1

    def logical_cycle_time(self, qubit, param):
        """Logical cycle time: six physical gate times per unit of ``param``."""
        gate_time = qubit["gate_time"]
        return 6 * gate_time * param

    def logical_error_rate(self, qubit, param):
        """Return ``0.03 * (error_rate / 0.01) ** ((param + 1) // 2)``."""
        exponent = (param + 1) // 2
        ratio = qubit["error_rate"] / 0.01
        return 0.03 * ratio**exponent

    def code_parameter_range(self):
        """Candidate code parameters, smallest first."""
        return list(range(1, 4))

    def code_parameter_cmp(self, qubit, p1, p2):
        """Three-way compare two code parameters, returning -1, 0 or 1."""
        return (p1 > p2) - (p1 < p2)
42
43
class SampleFactory:
    """Magic-state factory stub that always offers the same single factory.

    ``find_factories`` also checks that it receives a ``SampleCode`` and a
    qubit dict equal to ``sample_qubit()``.
    """

    def find_factories(self, code, qubit, target_error_rate):
        """Validate the inputs and return a one-element list of factories.

        ``target_error_rate`` is accepted for interface compatibility but is
        not consulted.
        """
        # Sanity-check what the estimator forwards from the test call site.
        assert isinstance(code, SampleCode)
        assert isinstance(qubit, dict)
        assert qubit == sample_qubit()

        only_factory = dict(physical_qubits=100, duration=1000)
        return [only_factory]
51
52
class SampleFactoryBuilder:
    """Factory builder that offers a single distillation unit.

    The unit consumes 15 input states. Its cost callables return constants,
    and its error callables are simple polynomials of the input error rate.
    """

    def __init__(self):
        # Name of the key in the qubit dict holding the magic-gate error rate.
        self.gate_error = "error_rate"
        # NOTE(review): presumably the maximum number of distillation rounds
        # the estimator may stack; confirm against estimate_custom.
        self.max_rounds = 3

    def distillation_units(self, code, qubit, max_code_parameter):
        """Return the list of available distillation units (just one here).

        ``code``, ``qubit`` and ``max_code_parameter`` are accepted for
        interface compatibility and are not used. The ``physical_qubits`` and
        ``duration`` callables ignore their argument.
        """

        def output_error_rate(input_error_rate):
            return 35 * input_error_rate**3

        def failure_probability(input_error_rate):
            return 15 * input_error_rate

        unit = {
            "num_input_states": 15,
            "physical_qubits": lambda _: 50,
            "duration": lambda _: 500,
            "output_error_rate": output_error_rate,
            "failure_probability": failure_probability,
        }
        return [unit]
69
70
def sample_qubit():
    """Return a fresh physical-qubit description shared by these tests.

    ``gate_time`` is 50 (the test comments read this as nanoseconds) and
    ``error_rate`` is 1e-4. A new dict is built on every call, so callers
    may modify the result without affecting other tests.
    """
    qubit = dict(gate_time=50, error_rate=1e-4)
    return qubit
73
74
def test_wrong_input():
    """estimate_custom rejects malformed algorithm and code models.

    Each sub-case patches a sample class inside a temporary
    ``pytest.MonkeyPatch`` context. The class is restored even when an
    expectation fails, so later tests never see a half-patched class.
    """

    def _estimate():
        # Fresh arguments on every call, built from the (possibly patched)
        # sample classes.
        return qsharp.estimate_custom(SampleAlgorithm(), sample_qubit(), SampleCode())

    # A plain integer has none of the required algorithm methods.
    with pytest.raises(AttributeError):
        qsharp.estimate_custom(42, sample_qubit(), SampleCode())

    # Catches missing methods in SampleAlgorithm
    for method_name in ["logical_qubits", "logical_depth", "num_magic_states"]:
        with pytest.MonkeyPatch.context() as mp:
            mp.delattr(SampleAlgorithm, method_name)
            with pytest.raises(AttributeError):
                _estimate()

    # Catches missing methods in SampleCode
    for method_name in [
        "physical_qubits",
        "logical_qubits",
        "logical_cycle_time",
        "logical_error_rate",
        "code_parameter_range",
        "code_parameter_cmp",
    ]:
        with pytest.MonkeyPatch.context() as mp:
            mp.delattr(SampleCode, method_name)
            with pytest.raises(AttributeError):
                _estimate()

    # Catches wrong type for method
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(SampleAlgorithm, "logical_qubits", "not a method")
        with pytest.raises(TypeError):
            _estimate()

    # Catches wrong signature for method: one parameter too many, then one
    # parameter too few.
    for bad_logical_depth in (
        lambda self, budget, extra: 20,
        lambda self: 20,
    ):
        with pytest.MonkeyPatch.context() as mp:
            mp.setattr(SampleAlgorithm, "logical_depth", bad_logical_depth)
            with pytest.raises(RuntimeError):
                _estimate()
144
145
def test_estimate_without_factories():
    """Without factories, no factory parts or magic-state counts are reported.

    The runtime and physical-qubit totals checked below correspond to code
    parameter 3 of ``SampleCode``:
    - runtime: 20 (depth) * 6 * 50 (gate time) * 3
    - physical qubits: 10 (logical qubits) * 2 * 3**2
    """
    algorithm, qubit, code = SampleAlgorithm(), sample_qubit(), SampleCode()
    result = qsharp.estimate_custom(algorithm, qubit, code)

    factory_parts = result["factoryParts"]
    assert len(factory_parts) == 0
    magic_state_counts = result["layoutOverhead"]["numMagicStates"]
    assert len(magic_state_counts) == 0
    assert result["runtime"] == 18000
    assert result["physicalQubits"] == 180

    assert "executionStats" in result
    stats = result["executionStats"]
    for timing_key in ("timeAlgorithm", "timeEstimation"):
        assert timing_key in stats
157
158
def test_with_single_factory():
    """One plain factory yields one factory part and one magic-state entry."""
    algorithm, qubit, code = SampleAlgorithm(), sample_qubit(), SampleCode()
    factories = [SampleFactory()]
    result = qsharp.estimate_custom(algorithm, qubit, code, factories)

    assert len(result["factoryParts"]) == 1
    assert len(result["layoutOverhead"]["numMagicStates"]) == 1

    # The reported factory keeps the keys SampleFactory supplied.
    factory = result["factoryParts"][0]["factory"]
    for key in ("physical_qubits", "duration"):
        assert key in factory
168
169
def test_with_multiple_factories():
    """Three magic-state counts with three factories give three factory parts."""
    algorithm = SampleAlgorithm(magic_states=[50, 100, 200])
    # The list holds the same SampleFactory instance three times.
    factories = [SampleFactory()] * 3
    result = qsharp.estimate_custom(algorithm, sample_qubit(), SampleCode(), factories)

    assert len(result["factoryParts"]) == 3
    assert len(result["layoutOverhead"]["numMagicStates"]) == 3

    for part in result["factoryParts"]:
        factory = part["factory"]
        assert "physical_qubits" in factory
        assert "duration" in factory
183
184
def test_with_factory_builder():
    """A factory builder is turned into exactly one factory part."""
    algorithm, qubit, code = SampleAlgorithm(), sample_qubit(), SampleCode()
    builders = [SampleFactoryBuilder()]
    result = qsharp.estimate_custom(algorithm, qubit, code, builders)

    assert len(result["factoryParts"]) == 1
    assert len(result["layoutOverhead"]["numMagicStates"]) == 1

    factory = result["factoryParts"][0]["factory"]
    for key in ("physical_qubits", "duration"):
        assert key in factory
198
199
def test_with_trivial_factory_unit():
    """A ``trivial_distillation_unit`` override is used when error is low.

    With a physical error rate of 1e-6, the factory is built from the
    builder's regular distillation unit unless the builder also defines
    ``trivial_distillation_unit``. Once it does, the estimator reports a
    factory whose duration is a single gate time.
    """

    def _trivial_distillation_unit(self, code, qubit, max_code_parameter):
        # One input state, one physical qubit and one gate time. The error
        # rate passes through unchanged and the unit never fails.
        return {
            "num_input_states": 1,
            "physical_qubits": lambda _: 1,
            "duration": lambda _: qubit["gate_time"],
            "output_error_rate": lambda input_error_rate: input_error_rate,
            "failure_probability": lambda _: 0.0,
        }

    def _estimate():
        # Fresh low-error qubit dict and builder on every call.
        return qsharp.estimate_custom(
            SampleAlgorithm(),
            {**sample_qubit(), "error_rate": 1e-6},
            SampleCode(),
            [SampleFactoryBuilder()],
        )

    # No override for the special case: factory duration is 500 ns.
    result = _estimate()
    assert result["factoryParts"][0]["factory"]["duration"] == 500

    # Apply override to return T gate directly if error rate is low enough.
    # MonkeyPatch removes the attribute again on exit, even if the assertion
    # fails, so other tests never see the override.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(
            SampleFactoryBuilder,
            "trivial_distillation_unit",
            _trivial_distillation_unit,
            raising=False,  # the attribute does not exist beforehand
        )
        result = _estimate()

        # With override, factory duration is one gate time (50 ns).
        assert result["factoryParts"][0]["factory"]["duration"] == 50
234
235
def test_prune_error_budget():
    """The error budget is split evenly unless the algorithm prunes it.

    By default the total budget (0.01, or the ``error_budget`` argument) is
    divided into three equal parts. When the algorithm defines
    ``prune_error_budget``, that hook may adjust the parts. Here it moves
    the rotation share equally onto the logical and magic-state parts.
    """

    def _assert_budget(result, logical, rotations, magic_states):
        # Compare each budget part with an absolute tolerance of 1e-6.
        budget = result["errorBudget"]
        assert abs(budget["logical"] - logical) < 1e-6
        assert abs(budget["rotations"] - rotations) < 1e-6
        assert abs(budget["magic_states"] - magic_states) < 1e-6

    def _prune_error_budget(self, budget, _):
        # Redistribute the rotation share, mutating the budget dict in place.
        rotations = budget["rotations"]
        budget["logical"] += rotations / 2
        budget["magic_states"] += rotations / 2
        budget["rotations"] = 0.0

    result = qsharp.estimate_custom(SampleAlgorithm(), sample_qubit(), SampleCode())
    _assert_budget(result, 0.01 / 3, 0.01 / 3, 0.01 / 3)

    result = qsharp.estimate_custom(
        SampleAlgorithm(), sample_qubit(), SampleCode(), error_budget=0.1
    )
    _assert_budget(result, 0.1 / 3, 0.1 / 3, 0.1 / 3)

    # MonkeyPatch removes the hook again on exit, even if an assertion fails,
    # so no other test sees a SampleAlgorithm with prune_error_budget.
    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(
            SampleAlgorithm,
            "prune_error_budget",
            _prune_error_budget,
            raising=False,  # the attribute does not exist beforehand
        )
        result = qsharp.estimate_custom(
            SampleAlgorithm(), sample_qubit(), SampleCode()
        )
        _assert_budget(result, 0.01 / 2, 0.0, 0.01 / 2)
266