microsoft/qdk

Public

mirrored from https://github.com/microsoft/qdk

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.19.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

source/compiler/qsc_frontend/src/closure.rs

348 lines · mode: code

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use qsc_data_structures::{index_map::IndexMap, span::Span};
5use qsc_hir::{
6 assigner::Assigner,
7 hir::{
8 Block, CallableDecl, CallableKind, Expr, ExprKind, Ident, Mutability, NodeId, Pat, PatKind,
9 Res, SpecBody, SpecDecl, Stmt, StmtKind,
10 },
11 mut_visit::{self, MutVisitor},
12 ty::{Arrow, FunctorSetValue, Ty},
13 visit::{self, Visitor},
14};
15use rustc_hash::{FxHashMap, FxHashSet};
16use std::{iter, rc::Rc};
17
18pub(super) struct Lambda {
19 pub(super) kind: CallableKind,
20 pub(super) functors: FunctorSetValue,
21 pub(super) input: Pat,
22 pub(super) body: Expr,
23}
24
25pub(super) struct PartialApp {
26 pub(super) bindings: Vec<Stmt>,
27 pub(super) input: Pat,
28}
29
30struct VarFinder {
31 bindings: FxHashSet<NodeId>,
32 uses: FxHashSet<NodeId>,
33}
34
35impl VarFinder {
36 fn free_vars(&self) -> Vec<NodeId> {
37 let mut vars: Vec<_> = self.uses.difference(&self.bindings).copied().collect();
38 vars.sort_unstable();
39 vars
40 }
41}
42
43impl Visitor<'_> for VarFinder {
44 fn visit_expr(&mut self, expr: &Expr) {
45 match &expr.kind {
46 ExprKind::Closure(args, _) => self.uses.extend(args.iter().copied()),
47 &ExprKind::Var(Res::Local(id), _) => {
48 self.uses.insert(id);
49 }
50 _ => visit::walk_expr(self, expr),
51 }
52 }
53
54 fn visit_pat(&mut self, pat: &Pat) {
55 if let PatKind::Bind(name) = &pat.kind {
56 self.bindings.insert(name.id);
57 } else {
58 visit::walk_pat(self, pat);
59 }
60 }
61}
62
63struct VarReplacer<'a> {
64 substitutions: &'a FxHashMap<NodeId, NodeId>,
65}
66
67impl VarReplacer<'_> {
68 fn replace(&self, id: &mut NodeId) {
69 if let Some(&new_id) = self.substitutions.get(id) {
70 *id = new_id;
71 }
72 }
73}
74
75impl MutVisitor for VarReplacer<'_> {
76 fn visit_expr(&mut self, expr: &mut Expr) {
77 match &mut expr.kind {
78 ExprKind::Closure(args, _) => args.iter_mut().for_each(|arg| self.replace(arg)),
79 ExprKind::Var(Res::Local(id), _) => self.replace(id),
80 _ => mut_visit::walk_expr(self, expr),
81 }
82 }
83}
84
85pub(super) fn lift(
86 assigner: &mut Assigner,
87 locals: &IndexMap<NodeId, (Ident, Ty)>,
88 mut lambda: Lambda,
89 span: Span,
90) -> (Vec<NodeId>, CallableDecl) {
91 let mut finder = VarFinder {
92 bindings: FxHashSet::default(),
93 uses: FxHashSet::default(),
94 };
95 finder.visit_pat(&lambda.input);
96 finder.visit_expr(&lambda.body);
97
98 let free_vars = finder.free_vars();
99 let substitutions: FxHashMap<_, _> = free_vars
100 .iter()
101 .map(|&id| (id, assigner.next_node()))
102 .collect();
103
104 VarReplacer {
105 substitutions: &substitutions,
106 }
107 .visit_expr(&mut lambda.body);
108
109 let substituted_vars = free_vars.iter().filter_map(|&id| {
110 let &new_id = substitutions
111 .get(&id)
112 .expect("free variable should have substitution");
113 locals
114 .get(id)
115 .map(|original_ident| (new_id, original_ident.clone()))
116 });
117
118 let mut input = closure_input(substituted_vars, lambda.input, span);
119 assigner.visit_pat(&mut input);
120
121 let callable = CallableDecl {
122 id: assigner.next_node(),
123 span,
124 kind: lambda.kind,
125 name: Ident {
126 id: assigner.next_node(),
127 span,
128 name: "<lambda>".into(),
129 },
130 generics: Vec::new(),
131 input,
132 output: lambda.body.ty.clone(),
133 functors: lambda.functors,
134 body: SpecDecl {
135 id: assigner.next_node(),
136 span: lambda.body.span,
137 body: SpecBody::Impl(
138 None,
139 Block {
140 id: assigner.next_node(),
141 span: lambda.body.span,
142 ty: lambda.body.ty.clone(),
143 stmts: vec![Stmt {
144 id: assigner.next_node(),
145 span: lambda.body.span,
146 kind: StmtKind::Expr(lambda.body),
147 }],
148 },
149 ),
150 },
151 adj: None,
152 ctl: None,
153 ctl_adj: None,
154 attrs: Vec::default(),
155 };
156
157 (free_vars, callable)
158}
159
160pub(super) fn partial_app_block(
161 close: impl FnOnce(Lambda) -> ExprKind,
162 callee: Expr,
163 arg: Expr,
164 app: PartialApp,
165 arrow: Rc<Arrow>,
166 span: Span,
167) -> Block {
168 let call = Expr {
169 id: NodeId::default(),
170 span,
171 ty: arrow.output.borrow().clone(),
172 kind: ExprKind::Call(Box::new(callee), Box::new(arg)),
173 };
174 let lambda = Lambda {
175 kind: arrow.kind,
176 functors: arrow
177 .functors
178 .borrow()
179 .expect_value("lambda type should have concrete functors"),
180 input: app.input,
181 body: call,
182 };
183 let closure = Expr {
184 id: NodeId::default(),
185 span,
186 ty: Ty::Arrow(arrow.clone()),
187 kind: close(lambda),
188 };
189
190 let mut stmts = app.bindings;
191 stmts.push(Stmt {
192 id: NodeId::default(),
193 span,
194 kind: StmtKind::Expr(closure),
195 });
196 Block {
197 id: NodeId::default(),
198 span,
199 ty: Ty::Arrow(arrow),
200 stmts,
201 }
202}
203
204pub(super) fn partial_app_hole(
205 assigner: &mut Assigner,
206 locals: &mut IndexMap<NodeId, (Ident, Ty)>,
207 ty: Ty,
208 span: Span,
209) -> (Expr, PartialApp) {
210 let local_id = assigner.next_node();
211 let ident = Ident {
212 id: local_id,
213 span,
214 name: "hole".into(),
215 };
216
217 locals.insert(local_id, (ident.clone(), ty.clone()));
218
219 let app = PartialApp {
220 bindings: Vec::new(),
221 input: Pat {
222 id: assigner.next_node(),
223 span,
224 ty: ty.clone(),
225 kind: PatKind::Bind(ident),
226 },
227 };
228
229 let var = Expr {
230 id: assigner.next_node(),
231 span,
232 ty,
233 kind: ExprKind::Var(Res::Local(local_id), Vec::new()),
234 };
235
236 (var, app)
237}
238
239pub(super) fn partial_app_given(
240 assigner: &mut Assigner,
241 locals: &mut IndexMap<NodeId, (Ident, Ty)>,
242 arg: Expr,
243) -> (Expr, PartialApp) {
244 let local_id = assigner.next_node();
245 let span = arg.span;
246 let ident = Ident {
247 id: local_id,
248 span,
249 name: "arg".into(),
250 };
251
252 locals.insert(local_id, (ident.clone(), arg.ty.clone()));
253
254 let var = Expr {
255 id: assigner.next_node(),
256 span,
257 ty: arg.ty.clone(),
258 kind: ExprKind::Var(Res::Local(local_id), Vec::new()),
259 };
260
261 let binding_pat = Pat {
262 id: assigner.next_node(),
263 span,
264 ty: arg.ty.clone(),
265 kind: PatKind::Bind(ident),
266 };
267 let binding_stmt = Stmt {
268 id: assigner.next_node(),
269 span,
270 kind: StmtKind::Local(Mutability::Immutable, binding_pat, arg),
271 };
272 let app = PartialApp {
273 bindings: vec![binding_stmt],
274 input: Pat {
275 id: assigner.next_node(),
276 span,
277 ty: Ty::UNIT,
278 kind: PatKind::Tuple(Vec::new()),
279 },
280 };
281
282 (var, app)
283}
284
285pub(super) fn partial_app_tuple(
286 args: impl Iterator<Item = (Expr, PartialApp)>,
287 span: Span,
288) -> (Expr, PartialApp) {
289 let mut items = Vec::new();
290 let mut bindings = Vec::new();
291 let mut holes = Vec::new();
292 for (arg, mut app) in args {
293 items.push(arg);
294 bindings.append(&mut app.bindings);
295 if !matches!(&app.input.kind, PatKind::Tuple(items) if items.is_empty()) {
296 holes.push(app.input);
297 }
298 }
299
300 let input = if holes.len() == 1 {
301 holes.pop().expect("holes should have one element")
302 } else {
303 Pat {
304 id: NodeId::default(),
305 span,
306 ty: Ty::Tuple(holes.iter().map(|h| h.ty.clone()).collect()),
307 kind: PatKind::Tuple(holes),
308 }
309 };
310
311 let expr = Expr {
312 id: NodeId::default(),
313 span,
314 ty: Ty::Tuple(items.iter().map(|i| i.ty.clone()).collect()),
315 kind: ExprKind::Tuple(items),
316 };
317
318 (expr, PartialApp { bindings, input })
319}
320
321fn closure_input(
322 vars: impl IntoIterator<Item = (NodeId, (Ident, Ty))>,
323 input: Pat,
324 span: Span,
325) -> Pat {
326 let bindings: Vec<_> = vars
327 .into_iter()
328 .map(|(id, (ident, ty))| Pat {
329 id: NodeId::default(),
330 span: ident.span,
331 ty,
332 kind: PatKind::Bind(Ident { id, ..ident }),
333 })
334 .collect();
335
336 let tys = bindings
337 .iter()
338 .map(|p| p.ty.clone())
339 .chain(iter::once(input.ty.clone()))
340 .collect();
341
342 Pat {
343 id: NodeId::default(),
344 span,
345 ty: Ty::Tuple(tys),
346 kind: PatKind::Tuple(bindings.into_iter().chain(iter::once(input)).collect()),
347 }
348}
349