microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/ta/test/cmpsearch.py
491 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import argparse |
| 5 | import asyncio |
| 6 | import builtins |
| 7 | from dataclasses import dataclass |
| 8 | import difflib |
| 9 | import json |
| 10 | import re |
| 11 | import sys |
| 12 | from typing import TypedDict |
| 13 | |
| 14 | import black |
| 15 | import colorama |
| 16 | import numpy as np |
| 17 | import typechat |
| 18 | |
| 19 | from typeagent.aitools import utils |
| 20 | from typeagent.aitools.embeddings import AsyncEmbeddingModel |
| 21 | from typeagent.knowpro.answer_response_schema import AnswerResponse |
| 22 | from typeagent.knowpro import answers |
| 23 | from typeagent.knowpro.importing import ConversationSettings |
| 24 | from typeagent.knowpro.convknowledge import create_typechat_model |
| 25 | from typeagent.knowpro.interfaces import ( |
| 26 | IConversation, |
| 27 | ScoredMessageOrdinal, |
| 28 | ScoredSemanticRefOrdinal, |
| 29 | ) |
| 30 | from typeagent.knowpro.search import ConversationSearchResult, SearchQueryExpr |
| 31 | from typeagent.knowpro.search_query_schema import SearchQuery |
| 32 | from typeagent.knowpro.serialization import deserialize_object |
| 33 | from typeagent.knowpro import searchlang |
| 34 | from typeagent.podcasts.podcast import Podcast |
| 35 | |
| 36 | |
@dataclass
class Context:
    """Shared state for one evaluation run, built once in real_main()."""

    # Conversation that is searched and used for answer generation.
    conversation: IConversation
    # Translator handed to the language search step.
    query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
    # Translator handed to the answer generation step.
    answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]
    # Model used by equality_score() to embed expected and actual answers.
    embedding_model: AsyncEmbeddingModel
    # Options passed through to the language search call.
    lang_search_options: searchlang.LanguageSearchOptions
    # Options passed through to answer generation.
    answer_options: answers.AnswerContextOptions
    # When true, prompt before each question and print debug output.
    interactive: bool
    # Reference records from SRFILE, keyed by their 'searchText' value.
    sr_index: dict[str, dict[str, object]]
    # When true, feed the reference search query from SRFILE into the search.
    use_search_query: bool
    # When true, feed the reference compiled query from SRFILE into the search.
    use_compiled_search_query: bool
| 49 | |
| 50 | |
def main():
    """Run real_main(), turning Ctrl-C into a quiet exit with status 1."""
    try:
        outcome = real_main()
    except KeyboardInterrupt:
        # Finish the interrupted output line before exiting.
        print("\n")
        raise SystemExit(1)
    else:
        return outcome
| 57 | |
| 58 | |
def real_main():
    """Run the Q/A comparison: parse arguments, load data, score, summarize.

    Reads Q/A pairs from ``--qafile`` and reference search records from
    ``--srfile``, loads the podcast index named by ``--podcast``, and for
    each selected question prints a similarity score between the expected
    answer and the answer produced by ``compare_actual_to_expected``.
    A summary of good and bad scores is printed at the end.

    Exits with status 2 when both ``--start`` and ``--offset`` are given.
    """
    colorama.init()  # So timelog behaves with non-tty stdout.

    # Scores at or above this value count as "good" throughout this function.
    good_threshold = 0.97

    # Parse arguments.

    default_qafile = "testdata/Episode_53_Answer_results.json"
    default_srfile = "testdata/Episode_53_Search_results.json"
    default_podcast_file = "testdata/Episode_53_AdrianTchaikovsky_index"

    explanation = "a list of objects with 'question' and 'answer' keys"
    explanation_sr = (
        "a list of objects with 'searchText', 'searchQueryExpr' and 'results' keys"
    )
    parser = argparse.ArgumentParser(
        description="Run queries from QAFILE and compare answers to expectations"
    )
    parser.add_argument(
        "--qafile",
        type=str,
        default=default_qafile,
        help=f"Path to the Answer_results.json file ({explanation})",
    )
    parser.add_argument(
        "--srfile",
        type=str,
        default=default_srfile,
        help=f"Path to the Search_results.json file ({explanation_sr})",
    )
    parser.add_argument(
        "--use-search-query",
        action="store_true",
        default=False,
        help="Use search query from SRFILE",
    )
    parser.add_argument(
        "--use-compiled-search-query",
        action="store_true",
        default=False,
        help="Use compiled search query from SRFILE",
    )
    parser.add_argument(
        "--podcast",
        type=str,
        default=default_podcast_file,
        help="Path to the podcast index files (excluding the '_index.json' suffix)",
    )
    parser.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Number of initial Q/A pairs to skip",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Number of Q/A pairs to print (0 means all)",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Do just this question (similar to --offset START-1 --limit 1)",
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        default=False,
        help="Run in interactive mode, waiting for user input before each question",
    )
    args = parser.parse_args()
    if args.start:
        # --start is shorthand for an offset plus (by default) a limit of one.
        if args.offset != 0:
            print("Error: --start and --offset can't be both set", file=sys.stderr)
            sys.exit(2)
        args.offset = args.start - 1
        if args.limit == 0:
            args.limit = 1

    # Read evaluation data.

    # Decode explicitly as UTF-8 so reading does not depend on the locale.
    with open(args.qafile, "r", encoding="utf-8") as file:
        data = json.load(file)
    with open(args.srfile, "r", encoding="utf-8") as file:
        srdata = json.load(file)
    sr_index = {item["searchText"]: item for item in srdata}
    assert isinstance(data, list), "Expected a list of Q/A pairs"
    assert len(data) > 0, "Expected non-empty Q/A data"
    assert all(
        isinstance(qa_pair, dict) and "question" in qa_pair and "answer" in qa_pair
        for qa_pair in data
    ), "Expected each Q/A pair to be a dict with 'question' and 'answer' keys"

    # Read podcast data.

    utils.load_dotenv()
    settings = ConversationSettings()
    with utils.timelog("Loading podcast data"):
        conversation = Podcast.read_from_file(args.podcast, settings)
    # Report the podcast path; 'file' at this point is the closed SRFILE handle.
    assert conversation is not None, f"Failed to load podcast from {args.podcast!r}"

    # Create translators.

    model = create_typechat_model()
    query_translator = utils.create_translator(model, SearchQuery)
    answer_translator = utils.create_translator(model, AnswerResponse)

    # Create context.

    context = Context(
        conversation,
        query_translator,
        answer_translator,
        AsyncEmbeddingModel(),
        lang_search_options=searchlang.LanguageSearchOptions(
            compile_options=searchlang.LanguageQueryCompileOptions(
                exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
            ),
            exact_match=False,
            max_message_matches=25,
        ),
        answer_options=answers.AnswerContextOptions(
            entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
        ),
        interactive=args.interactive,
        sr_index=sr_index,
        use_search_query=args.use_search_query,
        use_compiled_search_query=args.use_compiled_search_query,
    )
    utils.pretty_print(context.lang_search_options)
    utils.pretty_print(context.answer_options)

    # Loop over eval data, skipping duplicate questions
    # (Those differ in 'cmd' value, which we don't support yet.)

    offset = args.offset
    limit = args.limit
    last_q = ""
    counter = 0
    all_scores: list[tuple[float, int]] = []  # [(score, counter), ...]
    for qa_pair in data:
        question = qa_pair.get("question")
        answer = qa_pair.get("answer")
        if question:
            question = question.strip()
        if answer:
            answer = answer.strip()
        if not (question and answer) or question == last_q:
            continue
        # The counter numbers every distinct, complete question, including
        # the ones skipped by --offset below.
        counter += 1
        last_q = question

        # Process offset if specified.
        if offset > 0:
            offset -= 1
            continue

        # Wait for user input before continuing.
        print("-" * 25, counter, "-" * 25)
        if context.interactive:
            try:
                input("Press Enter to continue... ")
            except (EOFError, KeyboardInterrupt):
                print()
                break

        # Compare the given answer with the actual answer for the question.
        actual_answer, score = asyncio.run(compare_actual_to_expected(context, qa_pair))
        all_scores.append((score, counter))
        good_enough = score >= good_threshold
        print(f"Score: {score:.3f}; Question: {question}", flush=True)
        if context.interactive or not good_enough:
            cmd = qa_pair.get("cmd")
            # Only show the command when it is not the default one for the question.
            if cmd and cmd != f'@kpAnswer --query "{question}"':
                print(f"Command: {cmd}")
            if qa_pair.get("hasNoAnswer"):
                answer = f"Failure: {answer}"
            print(f"Expected answer:\n{answer}")
            print("-" * 20)
            print(f"Actual answer:\n{actual_answer}", flush=True)

        # Process limit if specified.
        if limit > 0:
            limit -= 1
            if limit == 0:
                break

    # Print the summary, best scores first, ten entries per line.
    print("=" * 50)
    all_scores.sort(reverse=True)
    good_scores = [
        (score, counter) for score, counter in all_scores if score >= good_threshold
    ]
    bad_scores = [
        (score, counter) for score, counter in all_scores if score < good_threshold
    ]
    for label, pairs in [("Good", good_scores), ("Bad", bad_scores)]:
        print(f"{label} scores ({len(pairs)}):")
        for i in range(0, len(pairs), 10):
            print(
                ", ".join(
                    f"{score:.3f}({counter})" for score, counter in pairs[i : i + 10]
                )
            )
| 259 | |
| 260 | |
class RawSearchResult(TypedDict):
    """Shape of one entry in a reference record's 'results' list."""

    # Compared against the message ordinals of the actual search result.
    messageMatches: list[int]
    # Compared against the semantic ref ordinals of the 'entity' matches.
    entityMatches: list[int]
    # Compared against the semantic ref ordinals of the 'topic' matches.
    topicMatches: list[int]
    # Compared against the semantic ref ordinals of the 'action' matches.
    actionMatches: list[int]
| 266 | |
| 267 | |
async def compare_actual_to_expected(
    context: Context, qa_pair: dict[str, str | None]
) -> tuple[str | None, float]:
    """Answer one question and score the answer against the expected one.

    Runs the language search for the question, optionally seeded with the
    reference queries from ``context.sr_index``, prints warnings when the
    queries or search results differ from the reference record, generates
    an answer, and scores it.

    Returns ``(actual_answer, score)``.  The result is ``(None, 0.0)`` when
    the question or expected answer is missing or when the search fails.
    """
    the_answer: str | None = None
    score = 0.0

    question = qa_pair.get("question")
    answer = qa_pair.get("answer")
    failed = qa_pair.get("hasNoAnswer")
    cmd = qa_pair.get("cmd")
    if question:
        question = question.strip()
    if answer:
        answer = answer.strip()
    if not (question and answer):
        return None, score

    if not context.interactive:
        log = lambda *args, **kwds: None  # Disable debug output in non-interactive mode
    else:
        log = print

    log()
    log("=" * 40)
    if cmd:
        log(f"Command: {cmd}")
    log(f"Question: {question}")
    log(f"Answer: {answer}")
    log("-" * 40)

    # Optionally hand the reference queries from SRFILE to the search via
    # the debug context.
    debug_context = searchlang.LanguageSearchDebugContext()
    if context.use_search_query:
        record = context.sr_index.get(question)
        if record:
            print("Using search query from SRFILE")
            debug_context.use_search_query = deserialize_object(
                SearchQuery, record["searchQueryExpr"]
            )
    if context.use_compiled_search_query:
        record = context.sr_index.get(question)
        if record:
            print("Using compiled search query from SRFILE")
            debug_context.use_compiled_search_query_exprs = deserialize_object(
                list[SearchQueryExpr], record["compiledQueryExpr"]
            )

    result = await searchlang.search_conversation_with_language(
        context.conversation,
        context.query_translator,
        question,
        context.lang_search_options,
        debug_context=debug_context,
    )
    log("-" * 40)
    if not isinstance(result, typechat.Success):
        # Search failed: report it and fall through with no answer and score 0.0.
        print("Error:", result.message)
    else:
        # Compare what the search produced with the reference record, if any.
        record = context.sr_index.get(question)
        if record:
            qx = deserialize_object(SearchQuery, record["searchQueryExpr"])
            qc = deserialize_object(
                list[searchlang.SearchQueryExpr], record["compiledQueryExpr"]
            )
            qr: list[RawSearchResult] = record["results"]
            if debug_context.search_query:
                compare_and_print_diff(
                    qx,
                    debug_context.search_query,
                    "Search query from LLM does not match reference.",
                )
            if debug_context.search_query_expr:
                compare_and_print_diff(
                    qc,
                    debug_context.search_query_expr,
                    "Compiled search query expression from LLM does not match reference.",
                )
            compare_results(
                result.value, qr, "Search results from index do not match reference,"
            )
        all_answers, combined_answer = await answers.generate_answers(
            context.answer_translator,
            result.value,
            context.conversation,
            question,
            options=context.answer_options,
        )
        log("-" * 40)
        if combined_answer.type == "NoAnswer":
            # TODO: Compare failure messages.
            # A "no answer" outcome is a full match only when the reference
            # also marks the question as having no answer.
            if failed:
                score = 1.0
            the_answer = f"Failure: {combined_answer.whyNoAnswer}"
            log(the_answer)
            log("All answers:")
            if context.interactive:
                utils.pretty_print(all_answers)
        else:
            assert combined_answer.answer is not None, "Expected an answer"
            the_answer = combined_answer.answer
            # An answer where the reference expects none scores zero.
            if failed:
                score = 0.0
            else:
                score = await equality_score(context, answer, the_answer)
            log(the_answer)
            log("Correctness score:", score)
    log("=" * 40)

    return the_answer, score
| 376 | |
| 377 | |
async def equality_score(context: Context, a: str, b: str) -> float:
    """Return a similarity score for two answer strings.

    Both strings are stripped first.  Identical strings score 1.0 and
    strings that differ only in case score 0.999.  Otherwise the score is
    the dot product of the two embeddings obtained from
    ``context.embedding_model``.

    Always returns a plain Python ``float``.
    """
    a = a.strip()
    b = b.strip()
    if a == b:
        return 1.0
    if a.lower() == b.lower():
        return 0.999
    embeddings = await context.embedding_model.get_embeddings([a, b])
    assert embeddings.shape[0] == 2, "Expected two embeddings"
    # NOTE(review): presumably the embeddings are unit-normalized, making this
    # a cosine similarity -- confirm against the embedding model.
    # np.dot returns a NumPy scalar; convert it so the declared return type holds.
    return float(np.dot(embeddings[0], embeddings[1]))
| 388 | |
| 389 | |
def compare_results(
    results: list[ConversationSearchResult],
    matches_records: list[RawSearchResult],
    message: str,
) -> bool:
    """Check actual search results against the reference records, pairwise.

    Prints a warning and returns False when the two lists differ in length.
    Otherwise every pair is compared for message, entity, action and topic
    ordinals (mismatches are printed by the helpers), and the function
    returns True only if everything matched.
    """
    actual_count = len(results)
    expected_count = len(matches_records)
    if actual_count != expected_count:
        print(
            f"Warning: {message} "
            f"(Result sizes mismatch, {actual_count} != {expected_count})"
        )
        return False

    # (knowledge type, key in the reference record), in reporting order.
    knowledge_kinds = (
        ("entity", "entityMatches"),
        ("action", "actionMatches"),
        ("topic", "topicMatches"),
    )
    all_match = True
    for result, record in zip(results, matches_records):
        if not compare_message_ordinals(
            result.message_matches, record["messageMatches"]
        ):
            all_match = False
        for kind, record_key in knowledge_kinds:
            # A knowledge type absent from the result counts as no matches.
            if kind in result.knowledge_matches:
                actual_refs = result.knowledge_matches[kind].semantic_ref_matches
            else:
                actual_refs = []
            expected_refs = record.get(record_key, [])
            if not compare_semantic_ref_ordinals(actual_refs, expected_refs, kind):
                all_match = False
    return all_match
| 438 | |
| 439 | |
def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Return True if the actual message ordinals equal the expected ones.

    Order is ignored.  On mismatch a diff of expected versus actual is
    printed and False is returned.
    """
    actual = [scored.message_ordinal for scored in aa]
    if sorted(actual) != sorted(b):
        print("Message ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
| 447 | |
| 448 | |
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Return True if the actual semantic ref ordinals equal the expected ones.

    Order is ignored.  On mismatch a diff of expected versus actual is
    printed under a heading built from ``label``, and False is returned.
    """
    actual = [scored.semantic_ref_ordinal for scored in aa]
    if sorted(actual) != sorted(b):
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
| 458 | |
| 459 | |
def compare_and_print_diff(
    a: object, b: object, message: str
) -> bool:  # True if unequal
    """Print a unified diff of two objects' reprs when they differ.

    The objects' repr() must be a valid Python expression, because the
    reprs are formatted with black before diffing.  ``a`` is shown as
    "expected" and ``b`` as "actual".  Returns False when the objects are
    equal, their reprs are equal, or the reprs become equal once floats
    are shortened; otherwise prints a warning plus the diff and returns True.
    """
    if a == b:
        return False

    def shorten_floats(text: str) -> str:
        # Shorten floats so slight differences in score etc. don't cause
        # false positives.
        return re.sub(r"\b\d\.\d\d+", lambda m: f"{float(m.group()):.1f}", text)

    expected_text = builtins.repr(a)
    actual_text = builtins.repr(b)
    if expected_text == actual_text:
        return False

    expected_text = shorten_floats(expected_text)
    actual_text = shorten_floats(actual_text)
    if expected_text == actual_text:
        return False

    print("Warning:", message)
    expected_lines = black.format_str(
        expected_text, mode=black.FileMode()
    ).splitlines(True)
    actual_lines = black.format_str(actual_text, mode=black.FileMode()).splitlines(
        True
    )
    for line in difflib.unified_diff(
        expected_lines,
        actual_lines,
        fromfile="expected",
        tofile="actual",
        n=2,
    ):
        print(line, end="")
    return True
| 488 | |
| 489 | |
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()
| 492 | |