microsoft/qdk
Public · mirrored from https://github.com/microsoft/qdk · Available
compiler/qsc_frontend/src/closure.rs
347 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | use qsc_data_structures::{index_map::IndexMap, span::Span}; |
| 5 | use qsc_hir::{ |
| 6 | assigner::Assigner, |
| 7 | hir::{ |
| 8 | Block, CallableDecl, CallableKind, Expr, ExprKind, Ident, Mutability, NodeId, Pat, PatKind, |
| 9 | Res, SpecBody, SpecDecl, Stmt, StmtKind, |
| 10 | }, |
| 11 | mut_visit::{self, MutVisitor}, |
| 12 | ty::{Arrow, FunctorSetValue, Ty}, |
| 13 | visit::{self, Visitor}, |
| 14 | }; |
| 15 | use rustc_hash::{FxHashMap, FxHashSet}; |
| 16 | use std::iter; |
| 17 | |
| 18 | pub(super) struct Lambda { |
| 19 | pub(super) kind: CallableKind, |
| 20 | pub(super) functors: FunctorSetValue, |
| 21 | pub(super) input: Pat, |
| 22 | pub(super) body: Expr, |
| 23 | } |
| 24 | |
| 25 | pub(super) struct PartialApp { |
| 26 | pub(super) bindings: Vec<Stmt>, |
| 27 | pub(super) input: Pat, |
| 28 | } |
| 29 | |
| 30 | struct VarFinder { |
| 31 | bindings: FxHashSet<NodeId>, |
| 32 | uses: FxHashSet<NodeId>, |
| 33 | } |
| 34 | |
| 35 | impl VarFinder { |
| 36 | fn free_vars(&self) -> Vec<NodeId> { |
| 37 | let mut vars: Vec<_> = self.uses.difference(&self.bindings).copied().collect(); |
| 38 | vars.sort_unstable(); |
| 39 | vars |
| 40 | } |
| 41 | } |
| 42 | |
| 43 | impl Visitor<'_> for VarFinder { |
| 44 | fn visit_expr(&mut self, expr: &Expr) { |
| 45 | match &expr.kind { |
| 46 | ExprKind::Closure(args, _) => self.uses.extend(args.iter().copied()), |
| 47 | &ExprKind::Var(Res::Local(id), _) => { |
| 48 | self.uses.insert(id); |
| 49 | } |
| 50 | _ => visit::walk_expr(self, expr), |
| 51 | } |
| 52 | } |
| 53 | |
| 54 | fn visit_pat(&mut self, pat: &Pat) { |
| 55 | if let PatKind::Bind(name) = &pat.kind { |
| 56 | self.bindings.insert(name.id); |
| 57 | } else { |
| 58 | visit::walk_pat(self, pat); |
| 59 | } |
| 60 | } |
| 61 | } |
| 62 | |
| 63 | struct VarReplacer<'a> { |
| 64 | substitutions: &'a FxHashMap<NodeId, NodeId>, |
| 65 | } |
| 66 | |
| 67 | impl VarReplacer<'_> { |
| 68 | fn replace(&self, id: &mut NodeId) { |
| 69 | if let Some(&new_id) = self.substitutions.get(id) { |
| 70 | *id = new_id; |
| 71 | } |
| 72 | } |
| 73 | } |
| 74 | |
| 75 | impl MutVisitor for VarReplacer<'_> { |
| 76 | fn visit_expr(&mut self, expr: &mut Expr) { |
| 77 | match &mut expr.kind { |
| 78 | ExprKind::Closure(args, _) => args.iter_mut().for_each(|arg| self.replace(arg)), |
| 79 | ExprKind::Var(Res::Local(id), _) => self.replace(id), |
| 80 | _ => mut_visit::walk_expr(self, expr), |
| 81 | } |
| 82 | } |
| 83 | } |
| 84 | |
| 85 | pub(super) fn lift( |
| 86 | assigner: &mut Assigner, |
| 87 | locals: &IndexMap<NodeId, (Ident, Ty)>, |
| 88 | mut lambda: Lambda, |
| 89 | span: Span, |
| 90 | ) -> (Vec<NodeId>, CallableDecl) { |
| 91 | let mut finder = VarFinder { |
| 92 | bindings: FxHashSet::default(), |
| 93 | uses: FxHashSet::default(), |
| 94 | }; |
| 95 | finder.visit_pat(&lambda.input); |
| 96 | finder.visit_expr(&lambda.body); |
| 97 | |
| 98 | let free_vars = finder.free_vars(); |
| 99 | let substitutions: FxHashMap<_, _> = free_vars |
| 100 | .iter() |
| 101 | .map(|&id| (id, assigner.next_node())) |
| 102 | .collect(); |
| 103 | |
| 104 | VarReplacer { |
| 105 | substitutions: &substitutions, |
| 106 | } |
| 107 | .visit_expr(&mut lambda.body); |
| 108 | |
| 109 | let substituted_vars = free_vars.iter().filter_map(|&id| { |
| 110 | let &new_id = substitutions |
| 111 | .get(&id) |
| 112 | .expect("free variable should have substitution"); |
| 113 | locals |
| 114 | .get(id) |
| 115 | .map(|original_ident| (new_id, original_ident.clone())) |
| 116 | }); |
| 117 | |
| 118 | let mut input = closure_input(substituted_vars, lambda.input, span); |
| 119 | assigner.visit_pat(&mut input); |
| 120 | |
| 121 | let callable = CallableDecl { |
| 122 | id: assigner.next_node(), |
| 123 | span, |
| 124 | kind: lambda.kind, |
| 125 | name: Ident { |
| 126 | id: assigner.next_node(), |
| 127 | span, |
| 128 | name: "<lambda>".into(), |
| 129 | }, |
| 130 | generics: Vec::new(), |
| 131 | input, |
| 132 | output: lambda.body.ty.clone(), |
| 133 | functors: lambda.functors, |
| 134 | body: SpecDecl { |
| 135 | id: assigner.next_node(), |
| 136 | span: lambda.body.span, |
| 137 | body: SpecBody::Impl( |
| 138 | None, |
| 139 | Block { |
| 140 | id: assigner.next_node(), |
| 141 | span: lambda.body.span, |
| 142 | ty: lambda.body.ty.clone(), |
| 143 | stmts: vec![Stmt { |
| 144 | id: assigner.next_node(), |
| 145 | span: lambda.body.span, |
| 146 | kind: StmtKind::Expr(lambda.body), |
| 147 | }], |
| 148 | }, |
| 149 | ), |
| 150 | }, |
| 151 | adj: None, |
| 152 | ctl: None, |
| 153 | ctl_adj: None, |
| 154 | attrs: Vec::default(), |
| 155 | }; |
| 156 | |
| 157 | (free_vars, callable) |
| 158 | } |
| 159 | |
| 160 | pub(super) fn partial_app_block( |
| 161 | close: impl FnOnce(Lambda) -> ExprKind, |
| 162 | callee: Expr, |
| 163 | arg: Expr, |
| 164 | app: PartialApp, |
| 165 | arrow: Arrow, |
| 166 | span: Span, |
| 167 | ) -> Block { |
| 168 | let call = Expr { |
| 169 | id: NodeId::default(), |
| 170 | span, |
| 171 | ty: (*arrow.output).clone(), |
| 172 | kind: ExprKind::Call(Box::new(callee), Box::new(arg)), |
| 173 | }; |
| 174 | let lambda = Lambda { |
| 175 | kind: arrow.kind, |
| 176 | functors: arrow |
| 177 | .functors |
| 178 | .expect_value("lambda type should have concrete functors"), |
| 179 | input: app.input, |
| 180 | body: call, |
| 181 | }; |
| 182 | let closure = Expr { |
| 183 | id: NodeId::default(), |
| 184 | span, |
| 185 | ty: Ty::Arrow(Box::new(arrow.clone())), |
| 186 | kind: close(lambda), |
| 187 | }; |
| 188 | |
| 189 | let mut stmts = app.bindings; |
| 190 | stmts.push(Stmt { |
| 191 | id: NodeId::default(), |
| 192 | span, |
| 193 | kind: StmtKind::Expr(closure), |
| 194 | }); |
| 195 | Block { |
| 196 | id: NodeId::default(), |
| 197 | span, |
| 198 | ty: Ty::Arrow(Box::new(arrow)), |
| 199 | stmts, |
| 200 | } |
| 201 | } |
| 202 | |
| 203 | pub(super) fn partial_app_hole( |
| 204 | assigner: &mut Assigner, |
| 205 | locals: &mut IndexMap<NodeId, (Ident, Ty)>, |
| 206 | ty: Ty, |
| 207 | span: Span, |
| 208 | ) -> (Expr, PartialApp) { |
| 209 | let local_id = assigner.next_node(); |
| 210 | let ident = Ident { |
| 211 | id: local_id, |
| 212 | span, |
| 213 | name: "hole".into(), |
| 214 | }; |
| 215 | |
| 216 | locals.insert(local_id, (ident.clone(), ty.clone())); |
| 217 | |
| 218 | let app = PartialApp { |
| 219 | bindings: Vec::new(), |
| 220 | input: Pat { |
| 221 | id: assigner.next_node(), |
| 222 | span, |
| 223 | ty: ty.clone(), |
| 224 | kind: PatKind::Bind(ident), |
| 225 | }, |
| 226 | }; |
| 227 | |
| 228 | let var = Expr { |
| 229 | id: assigner.next_node(), |
| 230 | span, |
| 231 | ty, |
| 232 | kind: ExprKind::Var(Res::Local(local_id), Vec::new()), |
| 233 | }; |
| 234 | |
| 235 | (var, app) |
| 236 | } |
| 237 | |
| 238 | pub(super) fn partial_app_given( |
| 239 | assigner: &mut Assigner, |
| 240 | locals: &mut IndexMap<NodeId, (Ident, Ty)>, |
| 241 | arg: Expr, |
| 242 | ) -> (Expr, PartialApp) { |
| 243 | let local_id = assigner.next_node(); |
| 244 | let span = arg.span; |
| 245 | let ident = Ident { |
| 246 | id: local_id, |
| 247 | span, |
| 248 | name: "arg".into(), |
| 249 | }; |
| 250 | |
| 251 | locals.insert(local_id, (ident.clone(), arg.ty.clone())); |
| 252 | |
| 253 | let var = Expr { |
| 254 | id: assigner.next_node(), |
| 255 | span, |
| 256 | ty: arg.ty.clone(), |
| 257 | kind: ExprKind::Var(Res::Local(local_id), Vec::new()), |
| 258 | }; |
| 259 | |
| 260 | let binding_pat = Pat { |
| 261 | id: assigner.next_node(), |
| 262 | span, |
| 263 | ty: arg.ty.clone(), |
| 264 | kind: PatKind::Bind(ident), |
| 265 | }; |
| 266 | let binding_stmt = Stmt { |
| 267 | id: assigner.next_node(), |
| 268 | span, |
| 269 | kind: StmtKind::Local(Mutability::Immutable, binding_pat, arg), |
| 270 | }; |
| 271 | let app = PartialApp { |
| 272 | bindings: vec![binding_stmt], |
| 273 | input: Pat { |
| 274 | id: assigner.next_node(), |
| 275 | span, |
| 276 | ty: Ty::UNIT, |
| 277 | kind: PatKind::Tuple(Vec::new()), |
| 278 | }, |
| 279 | }; |
| 280 | |
| 281 | (var, app) |
| 282 | } |
| 283 | |
| 284 | pub(super) fn partial_app_tuple( |
| 285 | args: impl Iterator<Item = (Expr, PartialApp)>, |
| 286 | span: Span, |
| 287 | ) -> (Expr, PartialApp) { |
| 288 | let mut items = Vec::new(); |
| 289 | let mut bindings = Vec::new(); |
| 290 | let mut holes = Vec::new(); |
| 291 | for (arg, mut app) in args { |
| 292 | items.push(arg); |
| 293 | bindings.append(&mut app.bindings); |
| 294 | if !matches!(&app.input.kind, PatKind::Tuple(items) if items.is_empty()) { |
| 295 | holes.push(app.input); |
| 296 | } |
| 297 | } |
| 298 | |
| 299 | let input = if holes.len() == 1 { |
| 300 | holes.pop().expect("holes should have one element") |
| 301 | } else { |
| 302 | Pat { |
| 303 | id: NodeId::default(), |
| 304 | span, |
| 305 | ty: Ty::Tuple(holes.iter().map(|h| h.ty.clone()).collect()), |
| 306 | kind: PatKind::Tuple(holes), |
| 307 | } |
| 308 | }; |
| 309 | |
| 310 | let expr = Expr { |
| 311 | id: NodeId::default(), |
| 312 | span, |
| 313 | ty: Ty::Tuple(items.iter().map(|i| i.ty.clone()).collect()), |
| 314 | kind: ExprKind::Tuple(items), |
| 315 | }; |
| 316 | |
| 317 | (expr, PartialApp { bindings, input }) |
| 318 | } |
| 319 | |
| 320 | fn closure_input( |
| 321 | vars: impl IntoIterator<Item = (NodeId, (Ident, Ty))>, |
| 322 | input: Pat, |
| 323 | span: Span, |
| 324 | ) -> Pat { |
| 325 | let bindings: Vec<_> = vars |
| 326 | .into_iter() |
| 327 | .map(|(id, (ident, ty))| Pat { |
| 328 | id: NodeId::default(), |
| 329 | span: ident.span, |
| 330 | ty, |
| 331 | kind: PatKind::Bind(Ident { id, ..ident }), |
| 332 | }) |
| 333 | .collect(); |
| 334 | |
| 335 | let tys = bindings |
| 336 | .iter() |
| 337 | .map(|p| p.ty.clone()) |
| 338 | .chain(iter::once(input.ty.clone())) |
| 339 | .collect(); |
| 340 | |
| 341 | Pat { |
| 342 | id: NodeId::default(), |
| 343 | span, |
| 344 | ty: Ty::Tuple(tys), |
| 345 | kind: PatKind::Tuple(bindings.into_iter().chain(iter::once(input)).collect()), |
| 346 | } |
| 347 | } |
| 348 | |