microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
330c2dd62a6eb315f45da14ab4ce3987184d9860

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/test/cmpsearch.py

490 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4import argparse
5import asyncio
6import builtins
7from dataclasses import dataclass
8import difflib
9import json
10import re
11import sys
12from typing import Any, TypedDict
13
14import black
15import colorama
16import numpy as np
17import typechat
18
19from typeagent.aitools import utils
20from typeagent.aitools.embeddings import AsyncEmbeddingModel
21from typeagent.knowpro.answer_response_schema import AnswerResponse
22from typeagent.knowpro import answers
23from typeagent.knowpro.importing import ConversationSettings
24from typeagent.knowpro.convknowledge import create_typechat_model
25from typeagent.knowpro.interfaces import (
26 IConversation,
27 ScoredMessageOrdinal,
28 ScoredSemanticRefOrdinal,
29)
30from typeagent.knowpro.search import ConversationSearchResult, SearchQueryExpr
31from typeagent.knowpro.search_query_schema import SearchQuery
32from typeagent.knowpro.serialization import deserialize_object
33from typeagent.knowpro import searchlang
34from typeagent.podcasts.podcast import Podcast
35
36
@dataclass
class Context:
    """State shared by every question in one evaluation run.

    Field order is significant: the first four fields are passed
    positionally when the context is constructed.
    """

    # The loaded conversation (podcast) that questions are asked about.
    conversation: IConversation
    # Turns a natural-language question into a SearchQuery.
    query_translator: typechat.TypeChatJsonTranslator[SearchQuery]
    # Produces an AnswerResponse from search results.
    answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse]
    # Embeds expected/actual answers to score their similarity.
    embedding_model: AsyncEmbeddingModel
    # Options for language search (query compilation, match limits).
    lang_search_options: searchlang.LanguageSearchOptions
    # Options controlling how much context goes into answer generation.
    answer_options: answers.AnswerContextOptions
    # When True: pause before each question and emit debug output.
    interactive: bool
    # Reference search records from SRFILE, keyed by their 'searchText'.
    sr_index: dict[str, dict[str, Any]]
    # Use SRFILE's reference search query instead of asking the LLM.
    use_search_query: bool
    # Additionally use SRFILE's reference compiled query expressions.
    use_compiled_search_query: bool
49
50
# Minimum equality score for an actual answer to count as matching the expected one.
GOOD_ENOUGH_SCORE = 0.97


def main():
    """Run the Q/A evaluation: answer each question and score it.

    Parses the command line, loads the expected Q/A pairs (QAFILE), the
    reference search results (SRFILE) and the podcast index, then answers
    every non-duplicate question and prints how close each actual answer is
    to the expected one, followed by a summary of good and bad scores.
    """
    colorama.init()  # So timelog behaves with non-tty stdout.

    args = _parse_args()

    # Read evaluation data. JSON is UTF-8 by spec; don't depend on the locale.

    with open(args.qafile, "r", encoding="utf-8") as file:
        data = json.load(file)
    with open(args.srfile, "r", encoding="utf-8") as file:
        srdata = json.load(file)
    sr_index = {item["searchText"]: item for item in srdata}
    assert isinstance(data, list), "Expected a list of Q/A pairs"
    assert len(data) > 0, "Expected non-empty Q/A data"
    assert all(
        isinstance(qa_pair, dict) and "question" in qa_pair and "answer" in qa_pair
        for qa_pair in data
    ), "Expected each Q/A pair to be a dict with 'question' and 'answer' keys"

    # Read podcast data.

    utils.load_dotenv()
    settings = ConversationSettings()
    with utils.timelog("Loading podcast data"):
        conversation = Podcast.read_from_file(args.podcast, settings)
    # Report the podcast path (previously this wrongly showed the SRFILE handle).
    assert conversation is not None, f"Failed to load podcast from {args.podcast!r}"

    # Create translators.

    model = create_typechat_model()
    query_translator = utils.create_translator(model, SearchQuery)
    answer_translator = utils.create_translator(model, AnswerResponse)

    # Create context.

    context = Context(
        conversation,
        query_translator,
        answer_translator,
        AsyncEmbeddingModel(),
        lang_search_options=searchlang.LanguageSearchOptions(
            compile_options=searchlang.LanguageQueryCompileOptions(
                exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
            ),
            exact_match=False,
            max_message_matches=25,
        ),
        answer_options=answers.AnswerContextOptions(
            entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
        ),
        interactive=args.interactive,
        sr_index=sr_index,
        use_search_query=args.use_search_query,
        use_compiled_search_query=args.use_compiled_search_query,
    )
    utils.pretty_print(context.lang_search_options)
    utils.pretty_print(context.answer_options)

    # Loop over eval data, skipping duplicate questions
    # (Those differ in 'cmd' value, which we don't support yet.)
    # Only consecutive duplicates are detected (compared against last_q).

    offset = args.offset
    limit = args.limit
    last_q = ""
    counter = 0
    all_scores: list[tuple[float, int]] = []  # [(score, counter), ...]
    for qa_pair in data:
        question = qa_pair.get("question")
        answer = qa_pair.get("answer")
        if question:
            question = question.strip()
        if answer:
            answer = answer.strip()
        if not (question and answer) or question == last_q:
            continue
        counter += 1
        last_q = question

        # Process offset if specified.
        if offset > 0:
            offset -= 1
            continue

        # Wait for user input before continuing.
        print("-" * 25, counter, "-" * 25)
        if context.interactive:
            try:
                input("Press Enter to continue... ")
            except (EOFError, KeyboardInterrupt):
                print()
                break

        # Compare the given answer with the actual answer for the question.
        actual_answer, score = asyncio.run(
            compare_actual_to_expected(context, qa_pair)
        )
        all_scores.append((score, counter))
        good_enough = score >= GOOD_ENOUGH_SCORE
        print(f"Score: {score:.3f}; Question: {question}", flush=True)
        if context.interactive or not good_enough:
            cmd = qa_pair.get("cmd")
            if cmd and cmd != f'@kpAnswer --query "{question}"':
                print(f"Command: {cmd}")
            if qa_pair.get("hasNoAnswer"):
                answer = f"Failure: {answer}"
            print(f"Expected answer:\n{answer}")
            print("-" * 20)
            print(f"Actual answer:\n{actual_answer}", flush=True)

        # Process limit if specified.
        if limit > 0:
            limit -= 1
            if limit == 0:
                break

    print("=" * 50)
    _print_score_summary(all_scores)


def _parse_args() -> argparse.Namespace:
    """Parse the command line; fold --start into --offset/--limit.

    Exits with status 2 if both --start and --offset are given.
    """
    default_qafile = "testdata/Episode_53_Answer_results.json"
    default_srfile = "testdata/Episode_53_Search_results.json"
    default_podcast_file = "testdata/Episode_53_AdrianTchaikovsky_index"

    explanation = "a list of objects with 'question' and 'answer' keys"
    explanation_sr = (
        "a list of objects with 'searchText', 'searchQueryExpr' and 'results' keys"
    )
    parser = argparse.ArgumentParser(
        description="Run queries from QAFILE and compare answers to expectations"
    )
    parser.add_argument(
        "--qafile",
        type=str,
        default=default_qafile,
        help=f"Path to the Answer_results.json file ({explanation})",
    )
    parser.add_argument(
        "--srfile",
        type=str,
        default=default_srfile,
        help=f"Path to the Search_results.json file ({explanation_sr})",
    )
    parser.add_argument(
        "--use-search-query",
        action="store_true",
        default=False,
        help="Use search query from SRFILE",
    )
    parser.add_argument(
        "--use-compiled-search-query",
        action="store_true",
        default=False,
        help="Use compiled search query from SRFILE",
    )
    parser.add_argument(
        "--podcast",
        type=str,
        default=default_podcast_file,
        help="Path to the podcast index files (excluding the '_index.json' suffix)",
    )
    parser.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Number of initial Q/A pairs to skip",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Number of Q/A pairs to print (0 means all)",
    )
    parser.add_argument(
        "--start",
        type=int,
        default=0,
        help="Do just this question (similar to --offset START-1 --limit 1)",
    )
    parser.add_argument(
        "--interactive",
        "-i",
        action="store_true",
        default=False,
        help="Run in interactive mode, waiting for user input before each question",
    )
    args = parser.parse_args()
    if args.start:
        if args.offset != 0:
            print("Error: --start and --offset can't be both set", file=sys.stderr)
            sys.exit(2)
        args.offset = args.start - 1
        if args.limit == 0:
            args.limit = 1
    return args


def _print_score_summary(all_scores: list[tuple[float, int]]) -> None:
    """Print (score, question-number) pairs, best first, as Good and Bad lists.

    A score is good when it is >= GOOD_ENOUGH_SCORE. Everything else,
    including NaN, is listed as bad, so no question disappears from the
    summary. Pairs are printed ten per line.
    """
    ranked = sorted(all_scores, reverse=True)
    good_scores = [pair for pair in ranked if pair[0] >= GOOD_ENOUGH_SCORE]
    bad_scores = [pair for pair in ranked if not pair[0] >= GOOD_ENOUGH_SCORE]
    for label, pairs in [("Good", good_scores), ("Bad", bad_scores)]:
        print(f"{label} scores ({len(pairs)}):")
        for i in range(0, len(pairs), 10):
            print(
                ", ".join(
                    f"{score:.3f}({counter})" for score, counter in pairs[i : i + 10]
                )
            )
251
252
class RawSearchResult(TypedDict):
    """One reference search result as stored in SRFILE's 'results' list.

    Each list holds ordinals (message ordinals or SemanticRef ordinals)
    expected from the search.
    """

    messageMatches: list[int]  # expected message ordinals
    entityMatches: list[int]  # expected entity SemanticRef ordinals
    topicMatches: list[int]  # expected topic SemanticRef ordinals
    actionMatches: list[int]  # expected action SemanticRef ordinals
258
259
async def compare_actual_to_expected(
    context: Context, qa_pair: dict[str, str | None]
) -> tuple[str | None, float]:
    """Answer one evaluation question and score it against the expected answer.

    Runs language search for the question, optionally substituting the
    reference (compiled) search query from SRFILE. It prints warnings when
    the LLM's search query, the compiled query or the search results differ
    from the SRFILE reference, then generates an answer and scores it.

    Args:
        context: Shared evaluation state.
        qa_pair: One Q/A record. Reads the "question", "answer",
            "hasNoAnswer" and "cmd" keys.

    Returns:
        (actual_answer, score). actual_answer is None, with score 0.0, when
        the pair lacks a question or answer, or when search fails. If the
        model produces no answer, actual_answer is "Failure: <reason>" and
        the score is 1.0 only if the pair is marked "hasNoAnswer". Otherwise
        the score is 0.0 if "hasNoAnswer" is set, else equality_score() of
        the expected vs. actual answer.
    """
    the_answer: str | None = None
    score = 0.0

    question = qa_pair.get("question")
    answer = qa_pair.get("answer")
    failed = qa_pair.get("hasNoAnswer")
    cmd = qa_pair.get("cmd")
    if question:
        question = question.strip()
    if answer:
        answer = answer.strip()
    if not (question and answer):
        return None, score

    def log(*args: Any, **kwds: Any) -> None:
        """Like print(), but only in interactive mode (debug output)."""
        if context.interactive:
            print(*args, **kwds)

    log()
    log("=" * 40)
    if cmd:
        log(f"Command: {cmd}")
    log(f"Question: {question}")
    log(f"Answer: {answer}")
    log("-" * 40)

    # Reference record for this question from SRFILE, if any (looked up once).
    record = context.sr_index.get(question)

    debug_context = searchlang.LanguageSearchDebugContext()
    if context.use_search_query or context.use_compiled_search_query:
        if record:
            print("Using search query from SRFILE")
            debug_context.use_search_query = deserialize_object(
                SearchQuery, record["searchQueryExpr"]
            )
            if context.use_compiled_search_query:
                print("Using compiled search query from SRFILE")
                debug_context.use_compiled_search_query_exprs = deserialize_object(
                    list[SearchQueryExpr], record["compiledQueryExpr"]
                )
        else:
            # Make the fallback to LLM query translation visible instead of silent.
            print("Warning: No SRFILE record for this question; using the LLM")

    result = await searchlang.search_conversation_with_language(
        context.conversation,
        context.query_translator,
        question,
        context.lang_search_options,
        debug_context=debug_context,
    )
    log("-" * 40)
    if not isinstance(result, typechat.Success):
        print("Error:", result.message)
    else:
        if record:
            qx = deserialize_object(SearchQuery, record["searchQueryExpr"])
            qc = deserialize_object(
                list[searchlang.SearchQueryExpr], record["compiledQueryExpr"]
            )
            qr: list[RawSearchResult] = record["results"]
            if debug_context.search_query:
                compare_and_print_diff(
                    qx,
                    debug_context.search_query,
                    "Search query from LLM does not match reference.",
                )
            if debug_context.search_query_expr:
                compare_and_print_diff(
                    qc,
                    debug_context.search_query_expr,
                    "Compiled search query expression from LLM does not match reference.",
                )
            compare_results(
                result.value, qr, "Search results from index do not match reference,"
            )
        all_answers, combined_answer = await answers.generate_answers(
            context.answer_translator,
            result.value,
            context.conversation,
            question,
            options=context.answer_options,
        )
        log("-" * 40)
        if combined_answer.type == "NoAnswer":
            # TODO: Compare failure messages.
            if failed:
                score = 1.0
            the_answer = f"Failure: {combined_answer.whyNoAnswer}"
            log(the_answer)
            log("All answers:")
            if context.interactive:
                utils.pretty_print(all_answers)
        else:
            assert combined_answer.answer is not None, "Expected an answer"
            the_answer = combined_answer.answer
            if failed:
                score = 0.0
            else:
                score = await equality_score(context, answer, the_answer)
            log(the_answer)
    log("Correctness score:", score)
    log("=" * 40)

    return the_answer, score
366
367
async def equality_score(context: "Context", a: str, b: str) -> float:
    """Return how similar text *a* is to text *b*, as a plain Python float.

    Both strings are stripped first. An exact match scores 1.0 and a caseless
    match scores 0.999. Otherwise the score is the cosine similarity of the
    two texts' embeddings from ``context.embedding_model``, or 0.0 if either
    embedding has zero norm.
    """
    a = a.strip()
    b = b.strip()
    if a == b:
        return 1.0
    if a.casefold() == b.casefold():
        return 0.999
    embeddings = await context.embedding_model.get_embeddings([a, b])
    assert embeddings.shape[0] == 2, "Expected two embeddings"
    # Normalize explicitly: a raw dot product is only a cosine similarity
    # when the model happens to return unit vectors.
    denom = float(np.linalg.norm(embeddings[0]) * np.linalg.norm(embeddings[1]))
    if denom == 0.0:
        return 0.0
    # float() so callers get a Python float, not a numpy scalar.
    return float(np.dot(embeddings[0], embeddings[1]) / denom)
378
379
# (knowledge type key in result.knowledge_matches, matching RawSearchResult key),
# in the order the comparisons are reported.
_KNOWLEDGE_MATCH_KEYS = (
    ("entity", "entityMatches"),
    ("action", "actionMatches"),
    ("topic", "topicMatches"),
)


def compare_results(
    results: list[ConversationSearchResult],
    matches_records: list[RawSearchResult],
    message: str,
) -> bool:
    """Compare actual search results with reference records from SRFILE.

    Results are compared pairwise, in order. For each pair, message ordinals
    and the entity/action/topic SemanticRef ordinals are checked. Every
    check runs, even after a failure, so all differences are printed.

    Args:
        results: Actual search results.
        matches_records: Reference results; missing keys are treated as empty.
        message: Warning text printed if the two lists differ in length.

    Returns:
        True if everything matches, False otherwise.
    """
    if len(results) != len(matches_records):
        print(
            f"Warning: {message} "
            f"(Result sizes mismatch, {len(results)} != {len(matches_records)})"
        )
        return False
    all_match = True
    for result, record in zip(results, matches_records):
        if not compare_message_ordinals(
            result.message_matches, record.get("messageMatches", [])
        ):
            all_match = False
        for knowledge_type, record_key in _KNOWLEDGE_MATCH_KEYS:
            actual = (
                result.knowledge_matches[knowledge_type].semantic_ref_matches
                if knowledge_type in result.knowledge_matches
                else []
            )
            if not compare_semantic_ref_ordinals(
                actual, record.get(record_key, []), knowledge_type
            ):
                all_match = False
    return all_match
428
429
# Special case: In the Podcast, these messages are all Kevin saying "Yeah",
# so if the difference is limited to these, we consider it a match.
NOISE_MESSAGES = frozenset({42, 46, 52, 68, 70})


def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Check actual message ordinals *aa* against expected ordinals *b*.

    Order and duplicates are ignored. Ordinals that appear on only one side
    are tolerated as long as all of them are in NOISE_MESSAGES. On a
    mismatch, print a diff and return False.
    """
    actual = [scored.message_ordinal for scored in aa]
    one_sided = set(actual).symmetric_difference(b)
    if not one_sided.issubset(NOISE_MESSAGES):
        print("Message ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
442
443
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Check actual SemanticRef ordinals *aa* against expected ordinals *b*.

    The two are compared as multisets (order ignored, duplicates counted).
    On a mismatch, print a diff headed by *label* and return False.
    """
    actual = [scored.semantic_ref_ordinal for scored in aa]
    if sorted(actual) != sorted(b):
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
453
454
455def compare_and_print_diff(
456 a: object, b: object, message: str
457) -> bool: # True if unequal
458 """Diff two objects whose repr() is a valid Python expression."""
459 if a == b:
460 return False
461 a_repr = builtins.repr(a)
462 b_repr = builtins.repr(b)
463 if a_repr == b_repr:
464 return False
465 # Shorten floats so slight differences in score etc. don't cause false positives.
466 a_repr = re.sub(r"\b\d\.\d\d+", lambda m: f"{float(m.group()):.3f}", a_repr)
467 b_repr = re.sub(r"\b\d\.\d\d+", lambda m: f"{float(m.group()):.3f}", b_repr)
468 if a_repr == b_repr:
469 return False
470 print("Warning:", message)
471 a_formatted = black.format_str(a_repr, mode=black.FileMode())
472 b_formatted = black.format_str(b_repr, mode=black.FileMode())
473 diff = difflib.unified_diff(
474 a_formatted.splitlines(True),
475 b_formatted.splitlines(True),
476 fromfile="expected",
477 tofile="actual",
478 n=2,
479 )
480 for x in diff:
481 print(x, end="")
482 return True
483
484
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C: terminate the partially printed line, then exit with status 1.
        print()
        sys.exit(1)