microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2f06d474b6949d65d73ee95868ac29c3f185992e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/test/cmpsearch.py

491 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4import argparse
5import asyncio
6import builtins
7from dataclasses import dataclass
8import difflib
9import json
10import re
11import sys
12from typing import TypedDict
13
14import black
15import colorama
16import numpy as np
17import typechat
18
19from typeagent.aitools import utils
20from typeagent.aitools.embeddings import AsyncEmbeddingModel
21from typeagent.knowpro.answer_response_schema import AnswerResponse
22from typeagent.knowpro import answers
23from typeagent.knowpro.importing import ConversationSettings
24from typeagent.knowpro.convknowledge import create_typechat_model
25from typeagent.knowpro.interfaces import (
26 IConversation,
27 ScoredMessageOrdinal,
28 ScoredSemanticRefOrdinal,
29)
30from typeagent.knowpro.search import ConversationSearchResult, SearchQueryExpr
31from typeagent.knowpro.search_query_schema import SearchQuery
32from typeagent.knowpro.serialization import deserialize_object
33from typeagent.knowpro import searchlang
34from typeagent.podcasts.podcast import Podcast
35
36
@dataclass
class Context:
    """Shared state for one evaluation run, passed to every per-question call.

    Field order matters: instances may be constructed positionally.
    """

    # Conversation (podcast index) that questions are searched against.
    conversation: IConversation
    # Translates a natural-language question into a SearchQuery.
    query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
    # Produces an AnswerResponse from search results.
    answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]
    # Embeds expected/actual answers when scoring them.
    embedding_model: AsyncEmbeddingModel
    # Options for searchlang.search_conversation_with_language().
    lang_search_options: searchlang.LanguageSearchOptions
    # Options for answers.generate_answers().
    answer_options: answers.AnswerContextOptions
    # When True, emit debug output and pause before each question.
    interactive: bool
    # SRFILE records keyed by their 'searchText' value.
    sr_index: dict[str, dict[str, object]]
    # Inject the recorded search query from SRFILE instead of asking the LLM.
    use_search_query: bool
    # Inject the recorded compiled query expressions from SRFILE.
    use_compiled_search_query: bool
49
50
def main():
    """Script entry point: run real_main(), exiting with status 1 on Ctrl-C."""
    try:
        outcome = real_main()
    except KeyboardInterrupt:
        # Move past any partially printed line before exiting.
        print("\n")
        sys.exit(1)
    return outcome
57
58
# Answers whose equality_score() reaches this value are counted as correct.
_GOOD_ENOUGH_SCORE = 0.97


def _parse_args() -> argparse.Namespace:
    """Parse the command line, folding --start into --offset/--limit.

    Exits with status 2 if both --start and --offset are given.
    """
    default_qafile = "testdata/Episode_53_Answer_results.json"
    default_srfile = "testdata/Episode_53_Search_results.json"
    default_podcast_file = "testdata/Episode_53_AdrianTchaikovsky_index"

    explanation = "a list of objects with 'question' and 'answer' keys"
    # compare_actual_to_expected() also reads 'compiledQueryExpr' from each record.
    explanation_sr = (
        "a list of objects with 'searchText', 'searchQueryExpr', "
        "'compiledQueryExpr' and 'results' keys"
    )
    parser = argparse.ArgumentParser(
        description="Run queries from QAFILE and compare answers to expectations"
    )
    parser.add_argument(
        "--qafile",
        type=str,
        default=default_qafile,
        help=f"Path to the Answer_results.json file ({explanation})",
    )
    parser.add_argument(
        "--srfile",
        type=str,
        default=default_srfile,
        help=f"Path to the Search_results.json file ({explanation_sr})",
    )
    parser.add_argument(
        "--use-search-query",
        action="store_true",
        default=False,
        help="Use search query from SRFILE",
    )
    parser.add_argument(
        "--use-compiled-search-query",
        action="store_true",
        default=False,
        help="Use compiled search query from SRFILE",
    )
    parser.add_argument(
        "--podcast",
        type=str,
        default=default_podcast_file,
        help="Path to the podcast index files (excluding the '_index.json' suffix)",
    )
    parser.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Number of initial Q/A pairs to skip",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Number of Q/A pairs to print (0 means all)",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Do just this question (similar to --offset START-1 --limit 1)",
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        default=False,
        help="Run in interactive mode, waiting for user input before each question",
    )
    args = parser.parse_args()
    if args.start:
        if args.offset != 0:
            print("Error: --start and --offset can't be both set", file=sys.stderr)
            sys.exit(2)
        # --start is 1-based: question N means skipping N-1 questions.
        args.offset = args.start - 1
        if args.limit == 0:
            args.limit = 1
    return args


def _load_eval_data(
    qafile: str, srfile: str
) -> tuple[list[dict], dict[str, dict[str, object]]]:
    """Load Q/A pairs from *qafile* and index the SRFILE records in *srfile*.

    Returns (qa_pairs, sr_index), where sr_index maps each SRFILE record's
    'searchText' to the record itself.
    """
    with open(qafile, "r", encoding="utf-8") as file:
        data = json.load(file)
    with open(srfile, "r", encoding="utf-8") as file:
        srdata = json.load(file)
    assert isinstance(data, list), "Expected a list of Q/A pairs"
    assert len(data) > 0, "Expected non-empty Q/A data"
    assert all(
        isinstance(qa_pair, dict) and "question" in qa_pair and "answer" in qa_pair
        for qa_pair in data
    ), "Expected each Q/A pair to be a dict with 'question' and 'answer' keys"
    sr_index = {item["searchText"]: item for item in srdata}
    return data, sr_index


def _print_score_summary(all_scores: list[tuple[float, int]]) -> None:
    """Print (score, question number) pairs, best first, as Good and Bad groups.

    Pairs are split at _GOOD_ENOUGH_SCORE and printed ten per line as
    "score(number)".
    """
    ranked = sorted(all_scores, reverse=True)
    groups = [
        ("Good", [pair for pair in ranked if pair[0] >= _GOOD_ENOUGH_SCORE]),
        ("Bad", [pair for pair in ranked if pair[0] < _GOOD_ENOUGH_SCORE]),
    ]
    for label, pairs in groups:
        print(f"{label} scores ({len(pairs)}):")
        for i in range(0, len(pairs), 10):
            print(
                ", ".join(
                    f"{score:.3f}({counter})" for score, counter in pairs[i : i + 10]
                )
            )


def real_main():
    """Answer the selected questions from QAFILE and print a score summary.

    Each question is answered by searching the podcast index and scored
    against its expected answer. Details are printed in interactive mode or
    when the score is below _GOOD_ENOUGH_SCORE. A question identical to the
    previously processed one is skipped.
    """
    colorama.init()  # So timelog behaves with non-tty stdout.

    args = _parse_args()

    # Read evaluation data.
    data, sr_index = _load_eval_data(args.qafile, args.srfile)

    # Read podcast data.
    utils.load_dotenv()
    settings = ConversationSettings()
    with utils.timelog("Loading podcast data"):
        conversation = Podcast.read_from_file(args.podcast, settings)
    assert conversation is not None, f"Failed to load podcast from {args.podcast!r}"

    # Create translators.
    model = create_typechat_model()
    query_translator = utils.create_translator(model, SearchQuery)
    answer_translator = utils.create_translator(model, AnswerResponse)

    # Create context.
    context = Context(
        conversation,
        query_translator,
        answer_translator,
        AsyncEmbeddingModel(),
        lang_search_options=searchlang.LanguageSearchOptions(
            compile_options=searchlang.LanguageQueryCompileOptions(
                exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
            ),
            exact_match=False,
            max_message_matches=25,
        ),
        answer_options=answers.AnswerContextOptions(
            entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
        ),
        interactive=args.interactive,
        sr_index=sr_index,
        use_search_query=args.use_search_query,
        use_compiled_search_query=args.use_compiled_search_query,
    )
    utils.pretty_print(context.lang_search_options)
    utils.pretty_print(context.answer_options)

    # Loop over eval data, skipping duplicate questions
    # (Those differ in 'cmd' value, which we don't support yet.)
    offset = args.offset
    limit = args.limit
    last_q = ""
    counter = 0
    all_scores: list[tuple[float, int]] = []  # [(score, counter), ...]
    for qa_pair in data:
        question = qa_pair.get("question")
        answer = qa_pair.get("answer")
        if question:
            question = question.strip()
        if answer:
            answer = answer.strip()
        if not (question and answer) or question == last_q:
            continue
        counter += 1
        last_q = question

        # Process offset if specified.
        if offset > 0:
            offset -= 1
            continue

        # Wait for user input before continuing.
        print("-" * 25, counter, "-" * 25)
        if context.interactive:
            try:
                input("Press Enter to continue... ")
            except (EOFError, KeyboardInterrupt):
                print()
                break

        # Compare the given answer with the actual answer for the question.
        actual_answer, score = asyncio.run(compare_actual_to_expected(context, qa_pair))
        all_scores.append((score, counter))
        good_enough = score >= _GOOD_ENOUGH_SCORE
        print(f"Score: {score:.3f}; Question: {question}", flush=True)
        if context.interactive or not good_enough:
            cmd = qa_pair.get("cmd")
            if cmd and cmd != f'@kpAnswer --query "{question}"':
                print(f"Command: {cmd}")
            if qa_pair.get("hasNoAnswer"):
                answer = f"Failure: {answer}"
            print(f"Expected answer:\n{answer}")
            print("-" * 20)
            print(f"Actual answer:\n{actual_answer}", flush=True)

        # Process limit if specified.
        if limit > 0:
            limit -= 1
            if limit == 0:
                break

    print("=" * 50)
    _print_score_summary(all_scores)
259
260
class _RawSearchResultRequired(TypedDict):
    """Keys every RawSearchResult must carry."""

    # Matched message ordinals; compare_results() indexes this key directly.
    messageMatches: list[int]


class RawSearchResult(_RawSearchResultRequired, total=False):
    """One entry of an SRFILE record's 'results' list, as loaded from JSON.

    Only 'messageMatches' is required. The per-knowledge-type keys are
    optional because compare_results() reads them with an empty-list default.
    """

    # SemanticRef ordinals of matched entities.
    entityMatches: list[int]
    # SemanticRef ordinals of matched topics.
    topicMatches: list[int]
    # SemanticRef ordinals of matched actions.
    actionMatches: list[int]
266
267
def _compare_with_reference(
    record: dict[str, object],
    results: list[ConversationSearchResult],
    debug_context: searchlang.LanguageSearchDebugContext,
) -> None:
    """Warn where this run's query, compiled query or results differ from *record*.

    *record* is one SRFILE entry; its 'searchQueryExpr', 'compiledQueryExpr'
    and 'results' keys are read.
    """
    qx = deserialize_object(SearchQuery, record["searchQueryExpr"])
    qc = deserialize_object(list[SearchQueryExpr], record["compiledQueryExpr"])
    qr: list[RawSearchResult] = record["results"]
    if debug_context.search_query:
        compare_and_print_diff(
            qx,
            debug_context.search_query,
            "Search query from LLM does not match reference.",
        )
    if debug_context.search_query_expr:
        compare_and_print_diff(
            qc,
            debug_context.search_query_expr,
            "Compiled search query expression from LLM does not match reference.",
        )
    compare_results(results, qr, "Search results from index do not match reference.")


async def compare_actual_to_expected(
    context: Context, qa_pair: dict[str, str | None]
) -> tuple[str | None, float]:
    """Answer one eval question and score the answer against the expected one.

    Runs language search on the question. The search query and/or compiled
    query recorded in SRFILE are injected when the corresponding Context
    flags are set. When SRFILE has a record for the question, the
    intermediate query and results are diffed against it. An answer is then
    generated and scored.

    Returns:
        (actual_answer, score). actual_answer is None when the question or
        expected answer is missing, or when the search fails; it is a
        "Failure: ..." string when no answer was produced. score is 1.0
        when an expected failure indeed produced no answer, 0.0 when
        answer/no-answer status disagrees or search fails, and otherwise
        equality_score() of the expected and actual answers.
    """
    the_answer: str | None = None
    score = 0.0

    question = qa_pair.get("question")
    answer = qa_pair.get("answer")
    failed = qa_pair.get("hasNoAnswer")
    cmd = qa_pair.get("cmd")
    if question:
        question = question.strip()
    if answer:
        answer = answer.strip()
    if not (question and answer):
        return None, score

    if not context.interactive:
        log = lambda *args, **kwds: None  # Disable debug output in non-interactive mode
    else:
        log = print

    log()
    log("=" * 40)
    if cmd:
        log(f"Command: {cmd}")
    log(f"Question: {question}")
    log(f"Answer: {answer}")
    log("-" * 40)

    # Look up the reference record once; it drives both query injection and
    # the post-search comparison.
    record = context.sr_index.get(question)

    debug_context = searchlang.LanguageSearchDebugContext()
    if record:
        if context.use_search_query:
            print("Using search query from SRFILE")
            debug_context.use_search_query = deserialize_object(
                SearchQuery, record["searchQueryExpr"]
            )
        if context.use_compiled_search_query:
            print("Using compiled search query from SRFILE")
            debug_context.use_compiled_search_query_exprs = deserialize_object(
                list[SearchQueryExpr], record["compiledQueryExpr"]
            )

    result = await searchlang.search_conversation_with_language(
        context.conversation,
        context.query_translator,
        question,
        context.lang_search_options,
        debug_context=debug_context,
    )
    log("-" * 40)
    if not isinstance(result, typechat.Success):
        print("Error:", result.message)
    else:
        if record:
            _compare_with_reference(record, result.value, debug_context)
        all_answers, combined_answer = await answers.generate_answers(
            context.answer_translator,
            result.value,
            context.conversation,
            question,
            options=context.answer_options,
        )
        log("-" * 40)
        if combined_answer.type == "NoAnswer":
            # TODO: Compare failure messages.
            if failed:
                score = 1.0
            the_answer = f"Failure: {combined_answer.whyNoAnswer}"
            log(the_answer)
            log("All answers:")
            if context.interactive:
                utils.pretty_print(all_answers)
        else:
            assert combined_answer.answer is not None, "Expected an answer"
            the_answer = combined_answer.answer
            if failed:
                score = 0.0
            else:
                score = await equality_score(context, answer, the_answer)
            log(the_answer)
            log("Correctness score:", score)
    log("=" * 40)

    return the_answer, score
376
377
async def equality_score(context: Context, a: str, b: str) -> float:
    """Score how alike two answers are, with 1.0 meaning identical.

    Surrounding whitespace is ignored. An exact match scores 1.0 and a
    case-insensitive match 0.999. Otherwise the score is the dot product of
    the two texts' embeddings from context.embedding_model.
    """
    a = a.strip()
    b = b.strip()
    if a == b:
        return 1.0
    if a.lower() == b.lower():
        return 0.999
    embeddings = await context.embedding_model.get_embeddings([a, b])
    assert embeddings.shape[0] == 2, "Expected two embeddings"
    # The dot product equals cosine similarity only if the model returns
    # unit-length vectors (assumed here; TODO confirm for AsyncEmbeddingModel).
    # np.dot yields a NumPy scalar; convert it to honor the declared float type.
    return float(np.dot(embeddings[0], embeddings[1]))
388
389
def compare_results(
    results: list[ConversationSearchResult],
    matches_records: list[RawSearchResult],
    message: str,
) -> bool:
    """Check search results against reference records; return True if all agree.

    If the two lists differ in length, print a warning containing *message*
    and return False at once. Otherwise compare each result, in order, on
    message ordinals and on entity, action and topic SemanticRef ordinals. A
    missing knowledge type counts as an empty list on either side. All
    checks run even after a failure, and each mismatch is printed by its
    comparison helper.
    """
    if len(results) != len(matches_records):
        print(
            f"Warning: {message} "
            f"(Result sizes mismatch, {len(results)} != {len(matches_records)})"
        )
        return False

    # (knowledge type, key holding the reference ordinals in a record)
    knowledge_kinds = (
        ("entity", "entityMatches"),
        ("action", "actionMatches"),
        ("topic", "topicMatches"),
    )
    all_match = True
    for actual, expected in zip(results, matches_records):
        if not compare_message_ordinals(
            actual.message_matches, expected["messageMatches"]
        ):
            all_match = False
        for kind, key in knowledge_kinds:
            knowledge = actual.knowledge_matches
            found = knowledge[kind].semantic_ref_matches if kind in knowledge else []
            if not compare_semantic_ref_ordinals(found, expected.get(key, []), kind):
                all_match = False
    return all_match
438
439
def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Return True if the message ordinals in *aa* equal *b*, ignoring order.

    On mismatch, print a header and an expected (*b*) versus actual listing
    via utils.list_diff.
    """
    actual = [scored.message_ordinal for scored in aa]
    if sorted(actual) != sorted(b):
        print("Message ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
447
448
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Return True if the SemanticRef ordinals in *aa* equal *b*, ignoring order.

    On mismatch, print a header naming the capitalized *label*, followed by
    an expected (*b*) versus actual listing via utils.list_diff.
    """
    actual = [scored.semantic_ref_ordinal for scored in aa]
    matched = sorted(actual) == sorted(b)
    if not matched:
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
    return matched
458
459
# Float literals with a single integer digit and two or more decimals.
_LONG_FLOAT_RE = re.compile(r"\b\d\.\d\d+")


def _shorten_floats(text: str) -> str:
    """Round every _LONG_FLOAT_RE match in *text* to one decimal place."""
    return _LONG_FLOAT_RE.sub(lambda m: f"{float(m.group()):.1f}", text)


def _format_repr(text: str) -> str:
    """Pretty-print *text* with black, or return it newline-terminated if black can't parse it."""
    try:
        return black.format_str(text, mode=black.FileMode())
    except ValueError:  # black.InvalidInput subclasses ValueError.
        return text if text.endswith("\n") else text + "\n"


def compare_and_print_diff(
    a: object, b: object, message: str
) -> bool:  # True if unequal
    """Print a unified diff of *a* (expected) and *b* (actual) if they differ.

    The objects are compared with ==, then by repr(), then by repr() after
    rounding float literals (see _shorten_floats) so that slight score
    differences are not reported. If they still differ, print
    "Warning: <message>" followed by a diff of the reprs. The reprs are
    black-formatted when they parse as Python; otherwise the raw reprs are
    diffed.

    Returns True if a difference was reported, else False.
    """
    if a == b:
        return False
    a_repr = builtins.repr(a)
    b_repr = builtins.repr(b)
    if a_repr == b_repr:
        return False
    # Shorten floats so slight differences in score etc. don't cause false positives.
    a_repr = _shorten_floats(a_repr)
    b_repr = _shorten_floats(b_repr)
    if a_repr == b_repr:
        return False
    print("Warning:", message)
    diff = difflib.unified_diff(
        _format_repr(a_repr).splitlines(True),
        _format_repr(b_repr).splitlines(True),
        fromfile="expected",
        tofile="actual",
        n=2,
    )
    for line in diff:
        print(line, end="")
    return True
488
489
if __name__ == "__main__":
    # Runs only when executed as a script, not when imported.
    main()
492