microsoft/qdk
Public · mirrored from https://github.com/microsoft/qdk · Available
source/language_service/src/code_action/wrapper_refactor.rs
326 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | // Wrapper refactor code action logic: generates a zero-parameter wrapper operation |
| 5 | // that supplies default / placeholder values for an existing operation's parameters. |
| 6 | |
| 7 | #[cfg(test)] |
| 8 | mod tests; |
| 9 | |
| 10 | use qsc::hir::{ |
| 11 | CallableKind, ItemKind, PatKind, |
| 12 | ty::{Prim, Ty}, |
| 13 | }; |
| 14 | use qsc::{ |
| 15 | Span, |
| 16 | line_column::{Encoding, Range}, |
| 17 | }; |
| 18 | |
| 19 | use crate::{ |
| 20 | compilation::Compilation, |
| 21 | protocol::{CodeAction, CodeActionKind, TextEdit, WorkspaceEdit}, |
| 22 | }; |
| 23 | |
| 24 | pub(crate) fn operation_refactors( |
| 25 | compilation: &Compilation, |
| 26 | source_name: &str, |
| 27 | span: Span, |
| 28 | encoding: Encoding, |
| 29 | ) -> Vec<CodeAction> { |
| 30 | let mut code_actions = Vec::new(); |
| 31 | let user_unit = compilation.user_unit(); |
| 32 | let package = &user_unit.package; |
| 33 | let source_map = &user_unit.sources; |
| 34 | let source = source_map |
| 35 | .find_by_name(source_name) |
| 36 | .expect("source should exist"); |
| 37 | let source_span = compilation.package_span_of_source(source_name); |
| 38 | |
| 39 | for (_, item) in package.items.iter() { |
| 40 | if !source_span.contains(item.span.lo) || span.intersection(&item.span).is_none() { |
| 41 | continue; |
| 42 | } |
| 43 | if let ItemKind::Callable(decl) = &item.kind { |
| 44 | if decl.kind != CallableKind::Operation |
| 45 | || decl.input.ty == Ty::UNIT |
| 46 | || decl.name.name.as_ref() == "<lambda>" |
| 47 | { |
| 48 | continue; // only operations with non-empty params |
| 49 | } |
| 50 | |
| 51 | // Determine indentation using source-local offset (package offset minus source base). |
| 52 | let local_lo = item.span.lo - source.offset; |
| 53 | let indent = line_indentation(&source.contents, local_lo); |
| 54 | let body_indent = if indent.contains('\t') { |
| 55 | format!("{indent}\t") |
| 56 | } else { |
| 57 | format!("{indent} ") |
| 58 | }; |
| 59 | |
| 60 | let original_name = decl.name.name.as_ref(); |
| 61 | let wrapper_name = generate_unique_wrapper_name(package, original_name); |
| 62 | |
| 63 | let (decl_lines, call_args) = build_param_decls_and_call_args(&decl.input); |
| 64 | |
| 65 | let call_args_joined = if call_args.is_empty() { |
| 66 | String::new() |
| 67 | } else { |
| 68 | call_args.join(", ") |
| 69 | }; |
| 70 | |
| 71 | let return_ty = decl.output.display(); |
| 72 | let return_is_unit = decl.output == Ty::UNIT; |
| 73 | |
| 74 | let call_line = if return_is_unit { |
| 75 | format!("{body_indent}{original_name}({call_args_joined});") |
| 76 | } else { |
| 77 | format!("{body_indent}return {original_name}({call_args_joined});") |
| 78 | }; |
| 79 | |
| 80 | let mut body_lines = Vec::new(); |
| 81 | if !decl_lines.is_empty() { |
| 82 | body_lines.push(format!( |
| 83 | "{body_indent}// TODO: Fill out the values for the parameters" |
| 84 | )); |
| 85 | body_lines.extend(decl_lines.iter().map(|decl| format!("{body_indent}{decl}"))); |
| 86 | body_lines.push(String::new()); // blank line |
| 87 | } |
| 88 | body_lines.push(format!("{body_indent}// Call original operation")); |
| 89 | body_lines.push(call_line); |
| 90 | |
| 91 | // We intentionally do NOT prefix the first line with `indent` because the insertion point |
| 92 | // inherits the existing line's leading whitespace. We DO append `{indent}` after the blank line |
| 93 | // so that the original operation keeps its indentation after the inserted block. |
| 94 | let newline = detect_newline(&source.contents, local_lo as usize); |
| 95 | let wrapper_text = format!( |
| 96 | "operation {wrapper_name}() : {return_ty} {{{newline}{}{newline}{indent}}}{newline}{newline}{indent}", |
| 97 | body_lines.join(newline) |
| 98 | ); |
| 99 | |
| 100 | // Insert immediately above the original operation: use zero-length span at item.span.lo |
| 101 | let insert_span = Span { |
| 102 | lo: local_lo, |
| 103 | hi: local_lo, |
| 104 | }; |
| 105 | let edit_range = Range::from_span(encoding, &source.contents, &insert_span); |
| 106 | |
| 107 | code_actions.push(CodeAction { |
| 108 | title: format!("Generate wrapper with default arguments for {original_name}"), |
| 109 | edit: Some(WorkspaceEdit { |
| 110 | changes: vec![( |
| 111 | source_name.to_string(), |
| 112 | vec![TextEdit { |
| 113 | new_text: wrapper_text, |
| 114 | range: edit_range, |
| 115 | }], |
| 116 | )], |
| 117 | }), |
| 118 | kind: Some(CodeActionKind::Refactor), |
| 119 | is_preferred: None, |
| 120 | }); |
| 121 | } |
| 122 | } |
| 123 | code_actions |
| 124 | } |
| 125 | |
| 126 | // Generate a wrapper name that does not clash with existing items in the same package (simple heuristic). |
| 127 | fn generate_unique_wrapper_name(package: &qsc::hir::Package, base: &str) -> String { |
| 128 | // New naming convention: <BaseName>WithDefaults, with numeric suffixes if needed. |
| 129 | let mut candidate = format!("{base}WithDefaults"); |
| 130 | let mut counter = 2; |
| 131 | while package.items.iter().any(|(_, item)| match &item.kind { |
| 132 | ItemKind::Callable(decl) => decl.name.name.as_ref() == candidate, |
| 133 | _ => false, |
| 134 | }) { |
| 135 | candidate = format!("{base}WithDefaults{counter}"); |
| 136 | counter += 1; |
| 137 | } |
| 138 | candidate |
| 139 | } |
| 140 | |
| 141 | // Build declarations and call arguments preserving tuple structure. |
| 142 | // Returns (declaration lines, call argument expressions list at top-level) |
| 143 | fn build_param_decls_and_call_args(pat: &qsc::hir::Pat) -> (Vec<String>, Vec<String>) { |
| 144 | let mut decls = Vec::new(); |
| 145 | let call_args = match &pat.kind { |
| 146 | PatKind::Tuple(items) => { |
| 147 | let mut args = Vec::new(); |
| 148 | for item in items { |
| 149 | args.push(build_pattern_expr(item, &mut decls)); |
| 150 | } |
| 151 | args |
| 152 | } |
| 153 | _ => vec![build_pattern_expr(pat, &mut decls)], |
| 154 | }; |
| 155 | (decls, call_args) |
| 156 | } |
| 157 | |
| 158 | // Recursively build an expression for a pattern, pushing any needed declarations (let/use) into decls. |
| 159 | fn build_pattern_expr(pat: &qsc::hir::Pat, decls: &mut Vec<String>) -> String { |
| 160 | match &pat.kind { |
| 161 | PatKind::Err | PatKind::Discard => "_".to_string(), |
| 162 | PatKind::Tuple(items) => { |
| 163 | let parts: Vec<String> = items.iter().map(|p| build_pattern_expr(p, decls)).collect(); |
| 164 | format!("({})", parts.join(", ")) |
| 165 | } |
| 166 | PatKind::Bind(ident) => build_binding_expr(ident.name.as_ref(), &pat.ty, decls), |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | fn build_binding_expr(name: &str, ty: &Ty, decls: &mut Vec<String>) -> String { |
| 171 | match ty { |
| 172 | Ty::Prim(Prim::Qubit) => { |
| 173 | decls.push(format!("use {name} = Qubit();")); |
| 174 | name.to_string() |
| 175 | } |
| 176 | Ty::Array(inner) if matches!(**inner, Ty::Prim(Prim::Qubit)) => { |
| 177 | decls.push(format!("use {name} = Qubit[1];")); |
| 178 | name.to_string() |
| 179 | } |
| 180 | Ty::Tuple(items) => { |
| 181 | let mut qubit_counter = 0u32; |
| 182 | let mut qubit_reg_counter = 0u32; |
| 183 | let mut deferred_todos = Vec::new(); |
| 184 | let tuple_expr = build_tuple_literal( |
| 185 | name, |
| 186 | items, |
| 187 | decls, |
| 188 | &mut qubit_counter, |
| 189 | &mut qubit_reg_counter, |
| 190 | &mut deferred_todos, |
| 191 | ); |
| 192 | // Place any deferred TODO comments before the binding so they aren't interleaved with allocations. |
| 193 | decls.extend(deferred_todos); |
| 194 | decls.push(format!("let {name} = {tuple_expr};")); |
| 195 | name.to_string() |
| 196 | } |
| 197 | _ => { |
| 198 | let (default_expr, comment) = default_value_for_type(ty); |
| 199 | if let Some(expr) = default_expr { |
| 200 | decls.push(format!("let {name} = {expr};")); |
| 201 | name.to_string() |
| 202 | } else { |
| 203 | decls.push(format!("// TODO: provide value for {name} ({comment})")); |
| 204 | "_".to_string() |
| 205 | } |
| 206 | } |
| 207 | } |
| 208 | } |
| 209 | |
| 210 | // Build a tuple literal expression for a list of types, adding declarations for qubits / complex components. |
| 211 | fn build_tuple_literal( |
| 212 | base: &str, |
| 213 | items: &[Ty], |
| 214 | decls: &mut Vec<String>, |
| 215 | qubit_counter: &mut u32, |
| 216 | qubit_reg_counter: &mut u32, |
| 217 | deferred_todos: &mut Vec<String>, |
| 218 | ) -> String { |
| 219 | if items.is_empty() { |
| 220 | return "()".to_string(); |
| 221 | } |
| 222 | let mut parts = Vec::new(); |
| 223 | for ty in items { |
| 224 | match ty { |
| 225 | Ty::Prim(Prim::Qubit) => { |
| 226 | let v = format!("{base}_q{qubit_counter}"); |
| 227 | *qubit_counter += 1; |
| 228 | decls.push(format!("use {v} = Qubit();")); |
| 229 | parts.push(v); |
| 230 | } |
| 231 | Ty::Array(inner) if matches!(**inner, Ty::Prim(Prim::Qubit)) => { |
| 232 | let v = format!("{base}_qs{qubit_reg_counter}"); |
| 233 | *qubit_reg_counter += 1; |
| 234 | decls.push(format!("use {v} = Qubit[1];")); |
| 235 | parts.push(v); |
| 236 | } |
| 237 | Ty::Tuple(sub) => { |
| 238 | let nested = build_tuple_literal( |
| 239 | base, |
| 240 | sub, |
| 241 | decls, |
| 242 | qubit_counter, |
| 243 | qubit_reg_counter, |
| 244 | deferred_todos, |
| 245 | ); |
| 246 | parts.push(nested); |
| 247 | } |
| 248 | _ => { |
| 249 | let (default_expr, comment) = default_value_for_type(ty); |
| 250 | if let Some(expr) = default_expr { |
| 251 | parts.push(expr); |
| 252 | } else { |
| 253 | deferred_todos.push(format!( |
| 254 | "// TODO: provide value for tuple component of {base} ({comment})" |
| 255 | )); |
| 256 | parts.push("_".to_string()); |
| 257 | } |
| 258 | } |
| 259 | } |
| 260 | } |
| 261 | if parts.len() == 1 { |
| 262 | format!("({},)", parts[0]) |
| 263 | } else { |
| 264 | format!("({})", parts.join(", ")) |
| 265 | } |
| 266 | } |
| 267 | |
| 268 | fn default_value_for_type(ty: &Ty) -> (Option<String>, String) { |
| 269 | match ty { |
| 270 | Ty::Prim(p) => match p { |
| 271 | Prim::Int => (Some("0".to_string()), "Int".to_string()), |
| 272 | Prim::Bool => (Some("false".to_string()), "Bool".to_string()), |
| 273 | Prim::Double => (Some("0.0".to_string()), "Double".to_string()), |
| 274 | Prim::Result => (Some("Zero".to_string()), "Result".to_string()), |
| 275 | Prim::Pauli => (Some("PauliI".to_string()), "Pauli".to_string()), |
| 276 | Prim::BigInt => (Some("0L".to_string()), "BigInt".to_string()), |
| 277 | Prim::String => (Some("\"\"".to_string()), "String".to_string()), |
| 278 | Prim::Qubit => (None, "Qubit - allocate with 'use'".to_string()), |
| 279 | Prim::Range | Prim::RangeTo | Prim::RangeFrom | Prim::RangeFull => { |
| 280 | (Some("0..1".to_string()), "Range".to_string()) |
| 281 | } |
| 282 | }, |
| 283 | Ty::Array(_) => (Some("[]".to_string()), "Array".to_string()), |
| 284 | Ty::Tuple(_) => (None, "Tuple".to_string()), |
| 285 | Ty::Param { name, .. } => (None, format!("Generic parameter {name}")), |
| 286 | Ty::Udt(name, _) => (None, format!("UDT {name}")), |
| 287 | Ty::Arrow(_) => (None, "Callable type".to_string()), |
| 288 | Ty::Infer(_) | Ty::Err => (None, "Unknown".to_string()), |
| 289 | } |
| 290 | } |
| 291 | |
/// Returns the run of spaces and tabs at the start of the line that contains byte `offset`.
///
/// Only text between the line start and `offset` is examined. Any non-whitespace character
/// before `offset` on that line ends the run.
///
/// If `offset` is past the end of `contents`, or does not fall on a UTF-8 character
/// boundary, an empty string is returned instead of panicking on the slice.
fn line_indentation(contents: &str, offset: u32) -> String {
    // `u32 -> usize` is lossless on the 32/64-bit targets this crate builds for.
    let Some(prefix) = contents.get(..offset as usize) else {
        return String::new();
    };
    // The byte after a '\n' is always a char boundary, so this slice cannot panic.
    let line_start = prefix.rfind('\n').map_or(0, |idx| idx + 1);
    prefix[line_start..]
        .chars()
        .take_while(|c| matches!(c, ' ' | '\t'))
        .collect()
}
| 302 | |
/// Detects the newline sequence to use when inserting text at byte `op_offset`.
///
/// Preference order:
/// 1. The last newline sequence before `op_offset`.
/// 2. The first newline sequence in the file, if there is none before `op_offset` or
///    `op_offset` is past the end of `contents`.
/// 3. `"\n"` when the file contains no newline at all.
///
/// Only `"\n"` and `"\r\n"` are recognized. The search works on bytes: a `'\n'` byte never
/// occurs inside a multi-byte UTF-8 sequence. An `op_offset` that is not on a character
/// boundary therefore gives a correct answer instead of panicking on a `str` slice.
fn detect_newline(contents: &str, op_offset: usize) -> &'static str {
    let bytes = contents.as_bytes();
    // Classify the '\n' at `nl_idx` as part of a CRLF pair or as a lone LF.
    let classify = |nl_idx: usize| -> &'static str {
        if nl_idx > 0 && bytes[nl_idx - 1] == b'\r' {
            "\r\n"
        } else {
            "\n"
        }
    };

    bytes
        .get(..op_offset)
        .and_then(|prefix| prefix.iter().rposition(|&b| b == b'\n'))
        .or_else(|| bytes.iter().position(|&b| b == b'\n'))
        .map_or("\n", classify)
}
| 327 | |