microsoft/qdk

Public

mirrored from https://github.com/microsoft/qdk · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
alex/pythontelem

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

compiler/qsc_frontend/src/closure.rs

346lines · modecode

1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT License.
3
4use qsc_data_structures::{index_map::IndexMap, span::Span};
5use qsc_hir::{
6 assigner::Assigner,
7 hir::{
8 Block, CallableDecl, CallableKind, Expr, ExprKind, Ident, Mutability, NodeId, Pat, PatKind,
9 Res, SpecBody, SpecDecl, Stmt, StmtKind,
10 },
11 mut_visit::{self, MutVisitor},
12 ty::{Arrow, FunctorSetValue, Ty},
13 visit::{self, Visitor},
14};
15use rustc_hash::{FxHashMap, FxHashSet};
16use std::iter;
17
18pub(super) struct Lambda {
19 pub(super) kind: CallableKind,
20 pub(super) functors: FunctorSetValue,
21 pub(super) input: Pat,
22 pub(super) body: Expr,
23}
24
25pub(super) struct PartialApp {
26 pub(super) bindings: Vec<Stmt>,
27 pub(super) input: Pat,
28}
29
30struct VarFinder {
31 bindings: FxHashSet<NodeId>,
32 uses: FxHashSet<NodeId>,
33}
34
35impl VarFinder {
36 fn free_vars(&self) -> Vec<NodeId> {
37 let mut vars: Vec<_> = self.uses.difference(&self.bindings).copied().collect();
38 vars.sort_unstable();
39 vars
40 }
41}
42
43impl Visitor<'_> for VarFinder {
44 fn visit_expr(&mut self, expr: &Expr) {
45 match &expr.kind {
46 ExprKind::Closure(args, _) => self.uses.extend(args.iter().copied()),
47 &ExprKind::Var(Res::Local(id), _) => {
48 self.uses.insert(id);
49 }
50 _ => visit::walk_expr(self, expr),
51 }
52 }
53
54 fn visit_pat(&mut self, pat: &Pat) {
55 if let PatKind::Bind(name) = &pat.kind {
56 self.bindings.insert(name.id);
57 } else {
58 visit::walk_pat(self, pat);
59 }
60 }
61}
62
63struct VarReplacer<'a> {
64 substitutions: &'a FxHashMap<NodeId, NodeId>,
65}
66
67impl VarReplacer<'_> {
68 fn replace(&self, id: &mut NodeId) {
69 if let Some(&new_id) = self.substitutions.get(id) {
70 *id = new_id;
71 }
72 }
73}
74
75impl MutVisitor for VarReplacer<'_> {
76 fn visit_expr(&mut self, expr: &mut Expr) {
77 match &mut expr.kind {
78 ExprKind::Closure(args, _) => args.iter_mut().for_each(|arg| self.replace(arg)),
79 ExprKind::Var(Res::Local(id), _) => self.replace(id),
80 _ => mut_visit::walk_expr(self, expr),
81 }
82 }
83}
84
85pub(super) fn lift(
86 assigner: &mut Assigner,
87 locals: &IndexMap<NodeId, (Ident, Ty)>,
88 mut lambda: Lambda,
89 span: Span,
90) -> (Vec<NodeId>, CallableDecl) {
91 let mut finder = VarFinder {
92 bindings: FxHashSet::default(),
93 uses: FxHashSet::default(),
94 };
95 finder.visit_pat(&lambda.input);
96 finder.visit_expr(&lambda.body);
97
98 let free_vars = finder.free_vars();
99 let substitutions: FxHashMap<_, _> = free_vars
100 .iter()
101 .map(|&id| (id, assigner.next_node()))
102 .collect();
103
104 VarReplacer {
105 substitutions: &substitutions,
106 }
107 .visit_expr(&mut lambda.body);
108
109 let substituted_vars = free_vars.iter().filter_map(|&id| {
110 let &new_id = substitutions
111 .get(&id)
112 .expect("free variable should have substitution");
113 locals
114 .get(id)
115 .map(|original_ident| (new_id, original_ident.clone()))
116 });
117
118 let mut input = closure_input(substituted_vars, lambda.input, span);
119 assigner.visit_pat(&mut input);
120
121 let callable = CallableDecl {
122 id: assigner.next_node(),
123 span,
124 kind: lambda.kind,
125 name: Ident {
126 id: assigner.next_node(),
127 span,
128 name: "lambda".into(),
129 },
130 generics: Vec::new(),
131 input,
132 output: lambda.body.ty.clone(),
133 functors: lambda.functors,
134 body: SpecDecl {
135 id: assigner.next_node(),
136 span: lambda.body.span,
137 body: SpecBody::Impl(
138 None,
139 Block {
140 id: assigner.next_node(),
141 span: lambda.body.span,
142 ty: lambda.body.ty.clone(),
143 stmts: vec![Stmt {
144 id: assigner.next_node(),
145 span: lambda.body.span,
146 kind: StmtKind::Expr(lambda.body),
147 }],
148 },
149 ),
150 },
151 adj: None,
152 ctl: None,
153 ctl_adj: None,
154 };
155
156 (free_vars, callable)
157}
158
159pub(super) fn partial_app_block(
160 close: impl FnOnce(Lambda) -> ExprKind,
161 callee: Expr,
162 arg: Expr,
163 app: PartialApp,
164 arrow: Arrow,
165 span: Span,
166) -> Block {
167 let call = Expr {
168 id: NodeId::default(),
169 span,
170 ty: (*arrow.output).clone(),
171 kind: ExprKind::Call(Box::new(callee), Box::new(arg)),
172 };
173 let lambda = Lambda {
174 kind: arrow.kind,
175 functors: arrow
176 .functors
177 .expect_value("lambda type should have concrete functors"),
178 input: app.input,
179 body: call,
180 };
181 let closure = Expr {
182 id: NodeId::default(),
183 span,
184 ty: Ty::Arrow(Box::new(arrow.clone())),
185 kind: close(lambda),
186 };
187
188 let mut stmts = app.bindings;
189 stmts.push(Stmt {
190 id: NodeId::default(),
191 span,
192 kind: StmtKind::Expr(closure),
193 });
194 Block {
195 id: NodeId::default(),
196 span,
197 ty: Ty::Arrow(Box::new(arrow)),
198 stmts,
199 }
200}
201
202pub(super) fn partial_app_hole(
203 assigner: &mut Assigner,
204 locals: &mut IndexMap<NodeId, (Ident, Ty)>,
205 ty: Ty,
206 span: Span,
207) -> (Expr, PartialApp) {
208 let local_id = assigner.next_node();
209 let ident = Ident {
210 id: local_id,
211 span,
212 name: "hole".into(),
213 };
214
215 locals.insert(local_id, (ident.clone(), ty.clone()));
216
217 let app = PartialApp {
218 bindings: Vec::new(),
219 input: Pat {
220 id: assigner.next_node(),
221 span,
222 ty: ty.clone(),
223 kind: PatKind::Bind(ident),
224 },
225 };
226
227 let var = Expr {
228 id: assigner.next_node(),
229 span,
230 ty,
231 kind: ExprKind::Var(Res::Local(local_id), Vec::new()),
232 };
233
234 (var, app)
235}
236
237pub(super) fn partial_app_given(
238 assigner: &mut Assigner,
239 locals: &mut IndexMap<NodeId, (Ident, Ty)>,
240 arg: Expr,
241) -> (Expr, PartialApp) {
242 let local_id = assigner.next_node();
243 let span = arg.span;
244 let ident = Ident {
245 id: local_id,
246 span,
247 name: "arg".into(),
248 };
249
250 locals.insert(local_id, (ident.clone(), arg.ty.clone()));
251
252 let var = Expr {
253 id: assigner.next_node(),
254 span,
255 ty: arg.ty.clone(),
256 kind: ExprKind::Var(Res::Local(local_id), Vec::new()),
257 };
258
259 let binding_pat = Pat {
260 id: assigner.next_node(),
261 span,
262 ty: arg.ty.clone(),
263 kind: PatKind::Bind(ident),
264 };
265 let binding_stmt = Stmt {
266 id: assigner.next_node(),
267 span,
268 kind: StmtKind::Local(Mutability::Immutable, binding_pat, arg),
269 };
270 let app = PartialApp {
271 bindings: vec![binding_stmt],
272 input: Pat {
273 id: assigner.next_node(),
274 span,
275 ty: Ty::UNIT,
276 kind: PatKind::Tuple(Vec::new()),
277 },
278 };
279
280 (var, app)
281}
282
283pub(super) fn partial_app_tuple(
284 args: impl Iterator<Item = (Expr, PartialApp)>,
285 span: Span,
286) -> (Expr, PartialApp) {
287 let mut items = Vec::new();
288 let mut bindings = Vec::new();
289 let mut holes = Vec::new();
290 for (arg, mut app) in args {
291 items.push(arg);
292 bindings.append(&mut app.bindings);
293 if !matches!(&app.input.kind, PatKind::Tuple(items) if items.is_empty()) {
294 holes.push(app.input);
295 }
296 }
297
298 let input = if holes.len() == 1 {
299 holes.pop().expect("holes should have one element")
300 } else {
301 Pat {
302 id: NodeId::default(),
303 span,
304 ty: Ty::Tuple(holes.iter().map(|h| h.ty.clone()).collect()),
305 kind: PatKind::Tuple(holes),
306 }
307 };
308
309 let expr = Expr {
310 id: NodeId::default(),
311 span,
312 ty: Ty::Tuple(items.iter().map(|i| i.ty.clone()).collect()),
313 kind: ExprKind::Tuple(items),
314 };
315
316 (expr, PartialApp { bindings, input })
317}
318
319fn closure_input(
320 vars: impl IntoIterator<Item = (NodeId, (Ident, Ty))>,
321 input: Pat,
322 span: Span,
323) -> Pat {
324 let bindings: Vec<_> = vars
325 .into_iter()
326 .map(|(id, (ident, ty))| Pat {
327 id: NodeId::default(),
328 span: ident.span,
329 ty,
330 kind: PatKind::Bind(Ident { id, ..ident }),
331 })
332 .collect();
333
334 let tys = bindings
335 .iter()
336 .map(|p| p.ty.clone())
337 .chain(iter::once(input.ty.clone()))
338 .collect();
339
340 Pat {
341 id: NodeId::default(),
342 span,
343 ty: Ty::Tuple(tys),
344 kind: PatKind::Tuple(bindings.into_iter().chain(iter::once(input)).collect()),
345 }
346}
347