microsoft/qdk

Public

mirrored fromhttps://github.com/microsoft/qdkAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
alex/testHarnessIntegTests

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

compiler/qsc_frontend/src/closure.rs

347lines · modecode

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use qsc_data_structures::{index_map::IndexMap, span::Span};
5use qsc_hir::{
6 assigner::Assigner,
7 hir::{
8 Block, CallableDecl, CallableKind, Expr, ExprKind, Ident, Mutability, NodeId, Pat, PatKind,
9 Res, SpecBody, SpecDecl, Stmt, StmtKind,
10 },
11 mut_visit::{self, MutVisitor},
12 ty::{Arrow, FunctorSetValue, Ty},
13 visit::{self, Visitor},
14};
15use rustc_hash::{FxHashMap, FxHashSet};
16use std::iter;
17
18pub(super) struct Lambda {
19 pub(super) kind: CallableKind,
20 pub(super) functors: FunctorSetValue,
21 pub(super) input: Pat,
22 pub(super) body: Expr,
23}
24
25pub(super) struct PartialApp {
26 pub(super) bindings: Vec<Stmt>,
27 pub(super) input: Pat,
28}
29
30struct VarFinder {
31 bindings: FxHashSet<NodeId>,
32 uses: FxHashSet<NodeId>,
33}
34
35impl VarFinder {
36 fn free_vars(&self) -> Vec<NodeId> {
37 let mut vars: Vec<_> = self.uses.difference(&self.bindings).copied().collect();
38 vars.sort_unstable();
39 vars
40 }
41}
42
43impl Visitor<'_> for VarFinder {
44 fn visit_expr(&mut self, expr: &Expr) {
45 match &expr.kind {
46 ExprKind::Closure(args, _) => self.uses.extend(args.iter().copied()),
47 &ExprKind::Var(Res::Local(id), _) => {
48 self.uses.insert(id);
49 }
50 _ => visit::walk_expr(self, expr),
51 }
52 }
53
54 fn visit_pat(&mut self, pat: &Pat) {
55 if let PatKind::Bind(name) = &pat.kind {
56 self.bindings.insert(name.id);
57 } else {
58 visit::walk_pat(self, pat);
59 }
60 }
61}
62
63struct VarReplacer<'a> {
64 substitutions: &'a FxHashMap<NodeId, NodeId>,
65}
66
67impl VarReplacer<'_> {
68 fn replace(&self, id: &mut NodeId) {
69 if let Some(&new_id) = self.substitutions.get(id) {
70 *id = new_id;
71 }
72 }
73}
74
75impl MutVisitor for VarReplacer<'_> {
76 fn visit_expr(&mut self, expr: &mut Expr) {
77 match &mut expr.kind {
78 ExprKind::Closure(args, _) => args.iter_mut().for_each(|arg| self.replace(arg)),
79 ExprKind::Var(Res::Local(id), _) => self.replace(id),
80 _ => mut_visit::walk_expr(self, expr),
81 }
82 }
83}
84
85pub(super) fn lift(
86 assigner: &mut Assigner,
87 locals: &IndexMap<NodeId, (Ident, Ty)>,
88 mut lambda: Lambda,
89 span: Span,
90) -> (Vec<NodeId>, CallableDecl) {
91 let mut finder = VarFinder {
92 bindings: FxHashSet::default(),
93 uses: FxHashSet::default(),
94 };
95 finder.visit_pat(&lambda.input);
96 finder.visit_expr(&lambda.body);
97
98 let free_vars = finder.free_vars();
99 let substitutions: FxHashMap<_, _> = free_vars
100 .iter()
101 .map(|&id| (id, assigner.next_node()))
102 .collect();
103
104 VarReplacer {
105 substitutions: &substitutions,
106 }
107 .visit_expr(&mut lambda.body);
108
109 let substituted_vars = free_vars.iter().filter_map(|&id| {
110 let &new_id = substitutions
111 .get(&id)
112 .expect("free variable should have substitution");
113 locals
114 .get(id)
115 .map(|original_ident| (new_id, original_ident.clone()))
116 });
117
118 let mut input = closure_input(substituted_vars, lambda.input, span);
119 assigner.visit_pat(&mut input);
120
121 let callable = CallableDecl {
122 id: assigner.next_node(),
123 span,
124 kind: lambda.kind,
125 name: Ident {
126 id: assigner.next_node(),
127 span,
128 name: "<lambda>".into(),
129 },
130 generics: Vec::new(),
131 input,
132 output: lambda.body.ty.clone(),
133 functors: lambda.functors,
134 body: SpecDecl {
135 id: assigner.next_node(),
136 span: lambda.body.span,
137 body: SpecBody::Impl(
138 None,
139 Block {
140 id: assigner.next_node(),
141 span: lambda.body.span,
142 ty: lambda.body.ty.clone(),
143 stmts: vec![Stmt {
144 id: assigner.next_node(),
145 span: lambda.body.span,
146 kind: StmtKind::Expr(lambda.body),
147 }],
148 },
149 ),
150 },
151 adj: None,
152 ctl: None,
153 ctl_adj: None,
154 attrs: Vec::default(),
155 };
156
157 (free_vars, callable)
158}
159
160pub(super) fn partial_app_block(
161 close: impl FnOnce(Lambda) -> ExprKind,
162 callee: Expr,
163 arg: Expr,
164 app: PartialApp,
165 arrow: Arrow,
166 span: Span,
167) -> Block {
168 let call = Expr {
169 id: NodeId::default(),
170 span,
171 ty: (*arrow.output).clone(),
172 kind: ExprKind::Call(Box::new(callee), Box::new(arg)),
173 };
174 let lambda = Lambda {
175 kind: arrow.kind,
176 functors: arrow
177 .functors
178 .expect_value("lambda type should have concrete functors"),
179 input: app.input,
180 body: call,
181 };
182 let closure = Expr {
183 id: NodeId::default(),
184 span,
185 ty: Ty::Arrow(Box::new(arrow.clone())),
186 kind: close(lambda),
187 };
188
189 let mut stmts = app.bindings;
190 stmts.push(Stmt {
191 id: NodeId::default(),
192 span,
193 kind: StmtKind::Expr(closure),
194 });
195 Block {
196 id: NodeId::default(),
197 span,
198 ty: Ty::Arrow(Box::new(arrow)),
199 stmts,
200 }
201}
202
203pub(super) fn partial_app_hole(
204 assigner: &mut Assigner,
205 locals: &mut IndexMap<NodeId, (Ident, Ty)>,
206 ty: Ty,
207 span: Span,
208) -> (Expr, PartialApp) {
209 let local_id = assigner.next_node();
210 let ident = Ident {
211 id: local_id,
212 span,
213 name: "hole".into(),
214 };
215
216 locals.insert(local_id, (ident.clone(), ty.clone()));
217
218 let app = PartialApp {
219 bindings: Vec::new(),
220 input: Pat {
221 id: assigner.next_node(),
222 span,
223 ty: ty.clone(),
224 kind: PatKind::Bind(ident),
225 },
226 };
227
228 let var = Expr {
229 id: assigner.next_node(),
230 span,
231 ty,
232 kind: ExprKind::Var(Res::Local(local_id), Vec::new()),
233 };
234
235 (var, app)
236}
237
238pub(super) fn partial_app_given(
239 assigner: &mut Assigner,
240 locals: &mut IndexMap<NodeId, (Ident, Ty)>,
241 arg: Expr,
242) -> (Expr, PartialApp) {
243 let local_id = assigner.next_node();
244 let span = arg.span;
245 let ident = Ident {
246 id: local_id,
247 span,
248 name: "arg".into(),
249 };
250
251 locals.insert(local_id, (ident.clone(), arg.ty.clone()));
252
253 let var = Expr {
254 id: assigner.next_node(),
255 span,
256 ty: arg.ty.clone(),
257 kind: ExprKind::Var(Res::Local(local_id), Vec::new()),
258 };
259
260 let binding_pat = Pat {
261 id: assigner.next_node(),
262 span,
263 ty: arg.ty.clone(),
264 kind: PatKind::Bind(ident),
265 };
266 let binding_stmt = Stmt {
267 id: assigner.next_node(),
268 span,
269 kind: StmtKind::Local(Mutability::Immutable, binding_pat, arg),
270 };
271 let app = PartialApp {
272 bindings: vec![binding_stmt],
273 input: Pat {
274 id: assigner.next_node(),
275 span,
276 ty: Ty::UNIT,
277 kind: PatKind::Tuple(Vec::new()),
278 },
279 };
280
281 (var, app)
282}
283
284pub(super) fn partial_app_tuple(
285 args: impl Iterator<Item = (Expr, PartialApp)>,
286 span: Span,
287) -> (Expr, PartialApp) {
288 let mut items = Vec::new();
289 let mut bindings = Vec::new();
290 let mut holes = Vec::new();
291 for (arg, mut app) in args {
292 items.push(arg);
293 bindings.append(&mut app.bindings);
294 if !matches!(&app.input.kind, PatKind::Tuple(items) if items.is_empty()) {
295 holes.push(app.input);
296 }
297 }
298
299 let input = if holes.len() == 1 {
300 holes.pop().expect("holes should have one element")
301 } else {
302 Pat {
303 id: NodeId::default(),
304 span,
305 ty: Ty::Tuple(holes.iter().map(|h| h.ty.clone()).collect()),
306 kind: PatKind::Tuple(holes),
307 }
308 };
309
310 let expr = Expr {
311 id: NodeId::default(),
312 span,
313 ty: Ty::Tuple(items.iter().map(|i| i.ty.clone()).collect()),
314 kind: ExprKind::Tuple(items),
315 };
316
317 (expr, PartialApp { bindings, input })
318}
319
320fn closure_input(
321 vars: impl IntoIterator<Item = (NodeId, (Ident, Ty))>,
322 input: Pat,
323 span: Span,
324) -> Pat {
325 let bindings: Vec<_> = vars
326 .into_iter()
327 .map(|(id, (ident, ty))| Pat {
328 id: NodeId::default(),
329 span: ident.span,
330 ty,
331 kind: PatKind::Bind(Ident { id, ..ident }),
332 })
333 .collect();
334
335 let tys = bindings
336 .iter()
337 .map(|p| p.ty.clone())
338 .chain(iter::once(input.ty.clone()))
339 .collect();
340
341 Pat {
342 id: NodeId::default(),
343 span,
344 ty: Ty::Tuple(tys),
345 kind: PatKind::Tuple(bindings.into_iter().chain(iter::once(input)).collect()),
346 }
347}
348