microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/ta/test/cmpsearch.py
490 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import argparse |
| 5 | import asyncio |
| 6 | import builtins |
| 7 | from dataclasses import dataclass |
| 8 | import difflib |
| 9 | import json |
| 10 | import re |
| 11 | import sys |
| 12 | from typing import Any, TypedDict |
| 13 | |
| 14 | import black |
| 15 | import colorama |
| 16 | import numpy as np |
| 17 | import typechat |
| 18 | |
| 19 | from typeagent.aitools import utils |
| 20 | from typeagent.aitools.embeddings import AsyncEmbeddingModel |
| 21 | from typeagent.knowpro.answer_response_schema import AnswerResponse |
| 22 | from typeagent.knowpro import answers |
| 23 | from typeagent.knowpro.importing import ConversationSettings |
| 24 | from typeagent.knowpro.convknowledge import create_typechat_model |
| 25 | from typeagent.knowpro.interfaces import ( |
| 26 | IConversation, |
| 27 | ScoredMessageOrdinal, |
| 28 | ScoredSemanticRefOrdinal, |
| 29 | ) |
| 30 | from typeagent.knowpro.search import ConversationSearchResult, SearchQueryExpr |
| 31 | from typeagent.knowpro.search_query_schema import SearchQuery |
| 32 | from typeagent.knowpro.serialization import deserialize_object |
| 33 | from typeagent.knowpro import searchlang |
| 34 | from typeagent.podcasts.podcast import Podcast |
| 35 | |
| 36 | |
@dataclass
class Context:
    """Shared state for one evaluation run, built once in main().

    main() constructs this partly positionally, so field order matters.
    """

    # The conversation (podcast index) that questions are searched against.
    conversation: IConversation
    # Translator producing a SearchQuery; passed to
    # searchlang.search_conversation_with_language().
    query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
    # Translator producing an AnswerResponse; passed to answers.generate_answers().
    answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]
    # Used by equality_score() to embed expected vs. actual answers.
    embedding_model: AsyncEmbeddingModel
    # Options for the language search step.
    lang_search_options: searchlang.LanguageSearchOptions
    # Options for the answer generation step.
    answer_options: answers.AnswerContextOptions
    # When True: prompt before each question and print debug output.
    interactive: bool
    # SRFILE records keyed by their "searchText" (the question text).
    sr_index: dict[str, dict[str, Any]]
    # Feed the recorded search query from SRFILE into the search.
    use_search_query: bool
    # Additionally feed the recorded compiled query expressions from SRFILE.
    use_compiled_search_query: bool
| 49 | |
| 50 | |
def main():
    """Replay recorded Q/A pairs and compare actual answers to expected ones.

    Loads the expected answers (--qafile), the reference search results
    (--srfile) and the podcast index (--podcast). For each distinct question,
    it runs compare_actual_to_expected() and prints the score (plus details
    for poor matches or in interactive mode). At the end it prints all scores
    split into good and bad.
    """
    colorama.init()  # So timelog behaves with non-tty stdout.

    # An answer scoring at least this much counts as matching the expectation.
    good_enough_threshold = 0.97

    args = _parse_args()

    # Read evaluation data.

    # Explicit UTF-8 so the JSON test data reads the same regardless of locale.
    with open(args.qafile, "r", encoding="utf-8") as file:
        data = json.load(file)
    with open(args.srfile, "r", encoding="utf-8") as file:
        srdata = json.load(file)
    sr_index = {item["searchText"]: item for item in srdata}
    assert isinstance(data, list), "Expected a list of Q/A pairs"
    assert len(data) > 0, "Expected non-empty Q/A data"
    assert all(
        isinstance(qa_pair, dict) and "question" in qa_pair and "answer" in qa_pair
        for qa_pair in data
    ), "Expected each Q/A pair to be a dict with 'question' and 'answer' keys"

    # Read podcast data.

    utils.load_dotenv()
    settings = ConversationSettings()
    with utils.timelog("Loading podcast data"):
        conversation = Podcast.read_from_file(args.podcast, settings)
    assert conversation is not None, f"Failed to load podcast from {args.podcast!r}"

    # Create translators.

    model = create_typechat_model()
    query_translator = utils.create_translator(model, SearchQuery)
    answer_translator = utils.create_translator(model, AnswerResponse)

    # Create context.

    context = Context(
        conversation,
        query_translator,
        answer_translator,
        AsyncEmbeddingModel(),
        lang_search_options=searchlang.LanguageSearchOptions(
            compile_options=searchlang.LanguageQueryCompileOptions(
                exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
            ),
            exact_match=False,
            max_message_matches=25,
        ),
        answer_options=answers.AnswerContextOptions(
            entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
        ),
        interactive=args.interactive,
        sr_index=sr_index,
        use_search_query=args.use_search_query,
        use_compiled_search_query=args.use_compiled_search_query,
    )
    utils.pretty_print(context.lang_search_options)
    utils.pretty_print(context.answer_options)

    # Loop over eval data, skipping duplicate questions
    # (Those differ in 'cmd' value, which we don't support yet.)
    # Only consecutive duplicates are skipped, since we compare with last_q.

    offset = args.offset
    limit = args.limit
    last_q = ""
    counter = 0
    all_scores: list[tuple[float, int]] = []  # [(score, counter), ...]
    for qa_pair in data:
        question = qa_pair.get("question")
        answer = qa_pair.get("answer")
        if question:
            question = question.strip()
        if answer:
            answer = answer.strip()
        if not (question and answer) or question == last_q:
            continue
        counter += 1
        last_q = question

        # Process offset if specified.
        if offset > 0:
            offset -= 1
            continue

        # Wait for user input before continuing.
        print("-" * 25, counter, "-" * 25)
        if context.interactive:
            try:
                input("Press Enter to continue... ")
            except (EOFError, KeyboardInterrupt):
                print()
                break

        # Compare the given answer with the actual answer for the question.
        actual_answer, score = asyncio.run(compare_actual_to_expected(context, qa_pair))
        all_scores.append((score, counter))
        good_enough = score >= good_enough_threshold
        print(f"Score: {score:.3f}; Question: {question}", flush=True)
        if context.interactive or not good_enough:
            cmd = qa_pair.get("cmd")
            if cmd and cmd != f'@kpAnswer --query "{question}"':
                print(f"Command: {cmd}")
            if qa_pair.get("hasNoAnswer"):
                answer = f"Failure: {answer}"
            print(f"Expected answer:\n{answer}")
            print("-" * 20)
            print(f"Actual answer:\n{actual_answer}", flush=True)

        # Process limit if specified.
        if limit > 0:
            limit -= 1
            if limit == 0:
                break

    _print_score_summary(all_scores, good_enough_threshold)


def _parse_args() -> argparse.Namespace:
    """Parse command-line options, folding --start into --offset/--limit."""
    default_qafile = "testdata/Episode_53_Answer_results.json"
    default_srfile = "testdata/Episode_53_Search_results.json"
    default_podcast_file = "testdata/Episode_53_AdrianTchaikovsky_index"

    explanation = "a list of objects with 'question' and 'answer' keys"
    explanation_sr = (
        "a list of objects with 'searchText', 'searchQueryExpr' and 'results' keys"
    )
    parser = argparse.ArgumentParser(
        description="Run queries from QAFILE and compare answers to expectations"
    )
    parser.add_argument(
        "--qafile",
        type=str,
        default=default_qafile,
        help=f"Path to the Answer_results.json file ({explanation})",
    )
    parser.add_argument(
        "--srfile",
        type=str,
        default=default_srfile,
        help=f"Path to the Search_results.json file ({explanation_sr})",
    )
    parser.add_argument(
        "--use-search-query",
        action="store_true",
        default=False,
        help="Use search query from SRFILE",
    )
    parser.add_argument(
        "--use-compiled-search-query",
        action="store_true",
        default=False,
        help="Use compiled search query from SRFILE",
    )
    parser.add_argument(
        "--podcast",
        type=str,
        default=default_podcast_file,
        help="Path to the podcast index files (excluding the '_index.json' suffix)",
    )
    parser.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Number of initial Q/A pairs to skip",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Number of Q/A pairs to print (0 means all)",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Do just this question (similar to --offset START-1 --limit 1)",
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        default=False,
        help="Run in interactive mode, waiting for user input before each question",
    )
    args = parser.parse_args()
    if args.start:
        if args.offset != 0:
            print("Error: --start and --offset can't be both set", file=sys.stderr)
            sys.exit(2)
        # --start is 1-based; the offset counts questions to skip.
        args.offset = args.start - 1
        if args.limit == 0:
            args.limit = 1
    return args


def _print_score_summary(
    all_scores: list[tuple[float, int]], threshold: float
) -> None:
    """Print good (>= threshold) and bad scores, best first, 10 per line."""
    print("=" * 50)
    ranked = sorted(all_scores, reverse=True)
    good_scores = [(score, counter) for score, counter in ranked if score >= threshold]
    # "not >=" instead of "<" so a NaN score is reported as bad rather than
    # silently vanishing from both lists.
    bad_scores = [
        (score, counter) for score, counter in ranked if not score >= threshold
    ]
    for label, pairs in [("Good", good_scores), ("Bad", bad_scores)]:
        print(f"{label} scores ({len(pairs)}):")
        for i in range(0, len(pairs), 10):
            print(
                ", ".join(
                    f"{score:.3f}({counter})" for score, counter in pairs[i : i + 10]
                )
            )
| 251 | |
| 252 | |
# One entry of an SRFILE record's "results" list: the reference ordinals
# recorded for each kind of match (message ordinals and SemanticRef ordinals).
RawSearchResult = TypedDict(
    "RawSearchResult",
    {
        "messageMatches": list[int],
        "entityMatches": list[int],
        "topicMatches": list[int],
        "actionMatches": list[int],
    },
)
| 258 | |
| 259 | |
async def compare_actual_to_expected(
    context: Context, qa_pair: dict[str, Any]
) -> tuple[str | None, float]:
    """Answer one question with the search + answer pipeline and score it.

    Args:
        context: Shared evaluation state (conversation, translators, options).
        qa_pair: One QAFILE entry. Reads "question", "answer", "hasNoAnswer"
            (whether the expected outcome is a failure) and "cmd".

    Returns:
        (actual_answer, score). actual_answer is None when the question or
        expected answer is missing, or when the search step fails (score 0.0).
        A "NoAnswer" result is returned as "Failure: <reason>" and scores 1.0
        only if the pair is marked hasNoAnswer. A real answer scores 0.0 if the
        pair expected no answer, otherwise its equality_score() against the
        expected answer.

    When SRFILE has a record for the question, the search query, compiled
    query expressions and search results are compared against it, and any
    differences are printed as warnings.
    """
    the_answer: str | None = None
    score = 0.0

    question = qa_pair.get("question")
    answer = qa_pair.get("answer")
    failed = qa_pair.get("hasNoAnswer")
    cmd = qa_pair.get("cmd")
    if question:
        question = question.strip()
    if answer:
        answer = answer.strip()
    if not (question and answer):
        return None, score

    def log(*args: Any, **kwds: Any) -> None:
        """Print debug output, but only in interactive mode."""
        if context.interactive:
            print(*args, **kwds)

    log()
    log("=" * 40)
    if cmd:
        log(f"Command: {cmd}")
    log(f"Question: {question}")
    log(f"Answer: {answer}")
    log("-" * 40)

    # Reference data recorded for this exact question text, if any.
    record = context.sr_index.get(question)

    debug_context = searchlang.LanguageSearchDebugContext()
    if record and (context.use_search_query or context.use_compiled_search_query):
        # Presumably makes the search use the recorded query instead of
        # translating the question afresh (see searchlang) -- confirm.
        print("Using search query from SRFILE")
        debug_context.use_search_query = deserialize_object(
            SearchQuery, record["searchQueryExpr"]
        )
        if context.use_compiled_search_query:
            print("Using compiled search query from SRFILE")
            debug_context.use_compiled_search_query_exprs = deserialize_object(
                list[SearchQueryExpr], record["compiledQueryExpr"]
            )

    result = await searchlang.search_conversation_with_language(
        context.conversation,
        context.query_translator,
        question,
        context.lang_search_options,
        debug_context=debug_context,
    )
    log("-" * 40)
    if not isinstance(result, typechat.Success):
        print("Error:", result.message)
    else:
        if record:
            # Compare each pipeline stage with the recorded reference.
            qx = deserialize_object(SearchQuery, record["searchQueryExpr"])
            qc = deserialize_object(
                list[searchlang.SearchQueryExpr], record["compiledQueryExpr"]
            )
            qr: list[RawSearchResult] = record["results"]
            if debug_context.search_query:
                compare_and_print_diff(
                    qx,
                    debug_context.search_query,
                    "Search query from LLM does not match reference.",
                )
            if debug_context.search_query_expr:
                compare_and_print_diff(
                    qc,
                    debug_context.search_query_expr,
                    "Compiled search query expression from LLM does not match reference.",
                )
            compare_results(
                result.value, qr, "Search results from index do not match reference."
            )
        all_answers, combined_answer = await answers.generate_answers(
            context.answer_translator,
            result.value,
            context.conversation,
            question,
            options=context.answer_options,
        )
        log("-" * 40)
        if combined_answer.type == "NoAnswer":
            # TODO: Compare failure messages.
            if failed:
                score = 1.0
            the_answer = f"Failure: {combined_answer.whyNoAnswer}"
            log(the_answer)
            log("All answers:")
            if context.interactive:
                utils.pretty_print(all_answers)
        else:
            assert combined_answer.answer is not None, "Expected an answer"
            the_answer = combined_answer.answer
            if failed:
                # An answer was produced where a failure was expected.
                score = 0.0
            else:
                score = await equality_score(context, answer, the_answer)
            log(the_answer)
        log("Correctness score:", score)
        log("=" * 40)

    return the_answer, score
| 366 | |
| 367 | |
async def equality_score(context: Context, a: str, b: str) -> float:
    """Return a similarity score between answers *a* and *b*.

    After stripping surrounding whitespace, identical strings score 1.0 and
    strings equal up to case score 0.999. Otherwise the score is the dot
    product of the two embeddings from ``context.embedding_model``. That is
    cosine similarity only if the model returns unit-length vectors
    (presumably it does -- confirm in AsyncEmbeddingModel).
    """
    a = a.strip()
    b = b.strip()
    if a == b:
        return 1.0
    if a.lower() == b.lower():
        return 0.999
    embeddings = await context.embedding_model.get_embeddings([a, b])
    assert embeddings.shape[0] == 2, "Expected two embeddings"
    # np.dot yields a NumPy scalar; convert it to honor the declared float type.
    return float(np.dot(embeddings[0], embeddings[1]))
| 378 | |
| 379 | |
def compare_results(
    results: list[ConversationSearchResult],
    matches_records: list[RawSearchResult],
    message: str,
) -> bool:
    """Compare search results with reference records; return True if all match.

    If the two lists differ in length, prints "Warning: <message> (...)" and
    returns False without comparing items. Otherwise, each result/record pair
    is compared on message ordinals, then entity, action and topic
    SemanticRef ordinals. Every mismatch prints its own diff; the comparison
    does not stop at the first one.
    """
    if len(results) != len(matches_records):
        print(
            f"Warning: {message} "
            f"(Result sizes mismatch, {len(results)} != {len(matches_records)})"
        )
        return False

    # Knowledge kinds to check, in reporting order, with their record keys.
    knowledge_keys = (
        ("entity", "entityMatches"),
        ("action", "actionMatches"),
        ("topic", "topicMatches"),
    )

    all_match = True
    for result, record in zip(results, matches_records):
        # Use .get() for every key so a record missing any list is treated
        # as "no matches" instead of raising KeyError.
        if not compare_message_ordinals(
            result.message_matches, record.get("messageMatches", [])
        ):
            all_match = False
        for kind, record_key in knowledge_keys:
            actual = (
                result.knowledge_matches[kind].semantic_ref_matches
                if kind in result.knowledge_matches
                else []
            )
            if not compare_semantic_ref_ordinals(
                actual, record.get(record_key, []), kind
            ):
                all_match = False
    return all_match
| 428 | |
| 429 | |
# Podcast-specific special case: each of these message ordinals is just
# Kevin saying "Yeah", so a difference confined to them still counts as a match.
NOISE_MESSAGES = frozenset([42, 46, 52, 68, 70])
| 433 | |
| 434 | |
def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Check actual message ordinals *aa* against expected ordinals *b*.

    Treated as a set comparison. Ordinals present on only one side are
    tolerated if they all belong to NOISE_MESSAGES. On a real mismatch, prints
    a diff and returns False.
    """
    actual = [scored.message_ordinal for scored in aa]
    differing = set(actual).symmetric_difference(b)
    if not differing.issubset(NOISE_MESSAGES):
        print("Message ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
| 442 | |
| 443 | |
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Check actual SemanticRef ordinals *aa* against expected ordinals *b*.

    Order is ignored but multiplicity is not (sorted-list equality). On a
    mismatch, prints a diff headed by the capitalized *label* and returns
    False.
    """
    actual = [scored.semantic_ref_ordinal for scored in aa]
    if sorted(actual) != sorted(b):
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
| 453 | |
| 454 | |
def compare_and_print_diff(
    a: object, b: object, message: str
) -> bool:  # True if unequal
    """Print a unified diff of *a* (expected) vs *b* (actual) when they differ.

    Both objects' repr() must be valid Python, because the reprs are
    formatted with black before diffing. Floats with two or more decimals
    are rounded to three places before comparing, so tiny score differences
    are ignored. Returns True (after printing "Warning: <message>" and the
    diff) if the objects still differ, otherwise False.
    """
    # Cheapest checks first: real equality, then identical reprs.
    if a == b:
        return False
    expected_text, actual_text = builtins.repr(a), builtins.repr(b)
    if expected_text == actual_text:
        return False

    # Shorten floats so slight differences in score etc. don't cause false positives.
    def round_floats(text: str) -> str:
        return re.sub(r"\b\d\.\d\d+", lambda m: f"{float(m.group()):.3f}", text)

    expected_text = round_floats(expected_text)
    actual_text = round_floats(actual_text)
    if expected_text == actual_text:
        return False

    print("Warning:", message)
    mode = black.FileMode()
    expected_lines = black.format_str(expected_text, mode=mode).splitlines(True)
    actual_lines = black.format_str(actual_text, mode=mode).splitlines(True)
    for line in difflib.unified_diff(
        expected_lines, actual_lines, fromfile="expected", tofile="actual", n=2
    ):
        print(line, end="")
    return True
| 483 | |
| 484 | |
if __name__ == "__main__":
    # On Ctrl-C, finish the current output line and exit with status 1
    # instead of printing a traceback.
    interrupted = False
    try:
        main()
    except KeyboardInterrupt:
        interrupted = True
    if interrupted:
        print()
        sys.exit(1)