microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgentAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2aa6fcaa0cea62f72e1da620f1c09c5069393ec0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/tools/utool.py

796lines · modecode

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
3
__version__: str = "0.2"  # Tool version; shown in the interactive-mode banner.
5
### Imports ###
7
import argparse
import asyncio
from collections.abc import Mapping
from dataclasses import dataclass
import difflib
import json
import os
import re
import shutil
import sys
import typing

from colorama import init as colorama_init, Fore
import numpy as np

try:
    import readline
except ImportError:
    readline = None
import typechat

from typeagent.aitools import embeddings
from typeagent.aitools import utils

from typeagent.knowpro.interfaces import (
    IConversation,
    IMessage,
    ITermToSemanticRefIndex,
    ScoredMessageOrdinal,
    ScoredSemanticRefOrdinal,
    SemanticRef,
    Topic,
)
from typeagent.knowpro import answers, answer_response_schema
from typeagent.knowpro import convknowledge
from typeagent.knowpro import importing
from typeagent.knowpro import kplib
from typeagent.knowpro import query
from typeagent.knowpro import search, search_query_schema, searchlang
from typeagent.knowpro import serialization

from typeagent.podcasts import podcast
49
50
### Classes ###
52
53
# One record of the Q/A file (--qafile). "answer" is the expected stage-4
# answer text, and "hasNoAnswer" marks an expected "no answer" outcome.
# "cmd" is carried along but not read by this tool.
QuestionAnswerData = typing.TypedDict(
    "QuestionAnswerData",
    {
        "question": str,
        "answer": str,
        "hasNoAnswer": bool,
        "cmd": str,
    },
)
59
60
# Expected stage-3 ordinals for one search result: message ordinals plus
# SemanticRef ordinals per knowledge type.
# NOTE(review): compare_results reads the entity/topic/action lists with
# .get(..., []), so stored records may omit them; if so, those keys could be
# declared NotRequired — confirm against the data files.
RawSearchResultData = typing.TypedDict(
    "RawSearchResultData",
    {
        "messageMatches": list[int],
        "entityMatches": list[int],
        "topicMatches": list[int],
        "actionMatches": list[int],
    },
)
66
67
# One record of the search-results file (--srfile): precomputed outcomes of
# stages 1-3 for the query text in "searchText".
SearchResultData = typing.TypedDict(
    "SearchResultData",
    {
        "searchText": str,
        # Stage 1: a serialized search_query_schema.SearchQuery.
        "searchQueryExpr": dict[str, typing.Any],
        # Stage 2: a serialized list[search.SearchQueryExpr].
        "compiledQueryExpr": list[dict[str, typing.Any]],
        # Stage 3: expected ordinals, one entry per search result.
        "results": list[RawSearchResultData],
    },
)
73
74
@dataclass
class ProcessingContext:
    """State shared by every query: loaded data, debug levels, models, options.

    The Literal annotations on debug1..debug4 also serve as the choices for
    --debug1..--debug4 in make_arg_parser (read via __annotations__), so they
    must remain real, evaluated Literal types.
    """

    query_context: query.QueryEvalContext
    # Q/A file records, in file order and indexed by question text.
    ar_list: list[QuestionAnswerData]
    # Search-results file records, in file order and indexed by search text.
    sr_list: list[SearchResultData]
    ar_index: dict[str, QuestionAnswerData]
    sr_index: dict[str, SearchResultData]
    # Per-stage debug levels for stages 1-4.
    debug1: typing.Literal["none", "diff", "full", "skip"]
    debug2: typing.Literal["none", "diff", "full", "skip"]
    debug3: typing.Literal["none", "diff", "full", "nice"]
    debug4: typing.Literal["none", "diff", "full", "nice"]
    embedding_model: embeddings.AsyncEmbeddingModel
    query_translator: typechat.TypeChatJsonTranslator[search_query_schema.SearchQuery]
    answer_translator: typechat.TypeChatJsonTranslator[
        answer_response_schema.AnswerResponse
    ]
    lang_search_options: searchlang.LanguageSearchOptions
    answer_context_options: answers.AnswerContextOptions

    def __repr__(self) -> str:
        # Summary only: data sizes and option values. The query context,
        # embedding model and translators are deliberately left out.
        summary = [
            ("ar_list", len(self.ar_list)),
            ("sr_list", len(self.sr_list)),
            ("ar_index", len(self.ar_index)),
            ("sr_index", len(self.sr_index)),
            ("debug1", self.debug1),
            ("debug2", self.debug2),
            ("debug3", self.debug3),
            ("debug4", self.debug4),
            ("lang_search_options", self.lang_search_options),
            ("answer_context_options", self.answer_context_options),
        ]
        body = ", ".join(f"{name}={value}" for name, value in summary)
        return f"Context({body})"
107
108
### Main logic ###
110
111
def main():
    """Command-line entry point for the TypeAgent query tool.

    Parses arguments, loads the podcast index plus the expected-answer file
    (--qafile) and the precomputed search-results file (--srfile), builds the
    query/answer translators and a ProcessingContext, then runs batch mode
    (--batch) or the interactive loop.
    """
    utils.load_dotenv()
    colorama_init(autoreset=True)

    parser = make_arg_parser("TypeAgent Query Tool")
    args = parser.parse_args()
    fill_in_debug_defaults(parser, args)
    query_context = load_podcast_index(args.podcast)
    ar_list, ar_index = load_index_file(args.qafile, "question", QuestionAnswerData)
    sr_list, sr_index = load_index_file(args.srfile, "searchText", SearchResultData)

    model = convknowledge.create_typechat_model()
    query_translator = utils.create_translator(model, search_query_schema.SearchQuery)
    if args.alt_schema:
        print(f"Substituting alt schema from {args.alt_schema}")
        # Read as UTF-8 explicitly: the default text encoding is
        # locale-dependent (e.g. a legacy code page on Windows).
        with open(args.alt_schema, encoding="utf-8") as f:
            query_translator.schema_str = f.read()
    if args.show_schema:
        print(Fore.YELLOW + query_translator.schema_str.rstrip() + Fore.RESET)

    answer_translator = utils.create_translator(
        model, answer_response_schema.AnswerResponse
    )

    # Keyword arguments: ProcessingContext has 14 fields, and a positional
    # mix-up (e.g. two debug levels swapped) would go unnoticed.
    context = ProcessingContext(
        query_context=query_context,
        ar_list=ar_list,
        sr_list=sr_list,
        ar_index=ar_index,
        sr_index=sr_index,
        debug1=args.debug1,
        debug2=args.debug2,
        debug3=args.debug3,
        debug4=args.debug4,
        embedding_model=embeddings.AsyncEmbeddingModel(),
        query_translator=query_translator,
        answer_translator=answer_translator,
        lang_search_options=searchlang.LanguageSearchOptions(
            compile_options=searchlang.LanguageQueryCompileOptions(
                exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
            ),
            exact_match=False,
            max_message_matches=25,
        ),
        answer_context_options=answers.AnswerContextOptions(
            entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
        ),
    )

    utils.pretty_print(context, Fore.BLUE, Fore.RESET)

    if args.batch:
        print(
            Fore.YELLOW
            + f"Running in batch mode [{args.offset}:{args.offset + args.limit if args.limit else ''}]."
            + Fore.RESET
        )
        asyncio.run(batch_loop(context, args.offset, args.limit))
    else:
        print(Fore.YELLOW + "Running in interactive mode." + Fore.RESET)
        interactive_loop(context)
173
174
async def batch_loop(
    context: ProcessingContext,
    offset: int,
    limit: int,
    *,
    good_threshold: float = 0.97,
) -> None:
    """Run a slice of the Q/A file through process_query and summarize scores.

    Args:
        context: Shared processing state; questions come from context.ar_list.
        offset: Number of leading Q/A pairs to skip.
        limit: Number of Q/A pairs to process; 0 means all remaining pairs.
        good_threshold: Scores at or above this are listed as "Good", the rest
            as "Bad" (single source of truth for both lists).

    Only queries for which process_query returns a score are summarized; the
    summary lists (score, question number) pairs, highest score first.
    """
    if limit == 0:
        limit = len(context.ar_list) - offset
    sublist = context.ar_list[offset : offset + limit]
    all_scores: list[tuple[float, int]] = []
    # Question numbers are 1-based positions in the full Q/A list
    # (the same numbering --start uses).
    for counter, qadata in enumerate(sublist, offset + 1):
        question = qadata["question"]
        print("-" * 20, counter, question, "-" * 20)
        score = await process_query(context, question)
        if score is not None:
            all_scores.append((score, counter))
    if not all_scores:
        return
    print("=" * 50)
    all_scores.sort(reverse=True)
    good_scores = [pair for pair in all_scores if pair[0] >= good_threshold]
    bad_scores = [pair for pair in all_scores if pair[0] < good_threshold]
    for label, pairs in [("Good", good_scores), ("Bad", bad_scores)]:
        print(f"{label} scores ({len(pairs)}):")
        # Ten entries per output line.
        for i in range(0, len(pairs), 10):
            print(
                ", ".join(
                    f"{score:.3f}({counter})" for score, counter in pairs[i : i + 10]
                )
            )
200
201
def interactive_loop(context: ProcessingContext) -> None:
    """Answer questions from the user until EOF or a quit command.

    If stdin is not a terminal, every non-blank input line is processed as a
    question and the function returns at EOF. Otherwise the user is prompted
    with "TypeAgent> " until EOF or "q"/"quit"/"exit" (case-insensitive).
    When readline is available, history is loaded from and saved to
    ".ui_history" in the current directory, best effort.
    """
    if not sys.stdin.isatty():
        for line in sys.stdin:
            line = line.strip()
            if not line:
                continue
            asyncio.run(process_query(context, line))
        return

    print(f"TypeAgent demo UI {__version__} (type 'q' to exit)")
    if readline:
        try:
            readline.read_history_file(".ui_history")
        except OSError:
            # Missing (FileNotFoundError) or unreadable (e.g. PermissionError)
            # history is not fatal; just start with an empty history.
            pass

    try:
        while True:
            try:
                line = input("TypeAgent> ").strip()
            except EOFError:
                print()
                break
            if not line:
                continue
            if line.lower() in ("exit", "quit", "q"):
                if readline:
                    # Drop the quit command itself from the saved history.
                    readline.remove_history_item(
                        readline.get_current_history_length() - 1
                    )
                break
            prsep()
            asyncio.run(process_query(context, line))

    finally:
        if readline:
            try:
                readline.write_history_file(".ui_history")
            except OSError as err:
                # Best effort: a history-file problem (e.g. a read-only cwd)
                # must not crash the tool or mask an exception from the loop.
                print(Fore.RED + f"Could not save history: {err}" + Fore.RESET)
239
240
### Query processing logic ###
242
243
244async def process_query(context: ProcessingContext, query_text: str) -> float | None:
245 record = context.sr_index.get(query_text)
246 debug_context = searchlang.LanguageSearchDebugContext()
247 if context.debug1 == "skip" or context.debug2 == "skip":
248 if not record or (
249 "searchQueryExpr" not in record or "compiledQueryExpr" not in record
250 ):
251 print("Can't skip stages 1 or 2, no precomputed outcomes found.")
252 else:
253 # Skipping stage 2 implies skipping stage 1, and we must supply the
254 # precomputed results for both stages.
255 debug_context.use_search_query = serialization.deserialize_object(
256 search_query_schema.SearchQuery, record["searchQueryExpr"]
257 )
258 print("Skipping stage 1, substituting precomputed search query.")
259 if context.debug2 == "skip":
260 debug_context.use_compiled_search_query_exprs = (
261 serialization.deserialize_object(
262 list[search.SearchQueryExpr],
263 record["compiledQueryExpr"],
264 )
265 )
266 print(
267 "Skipping stage 2, substituting precomputed compiled query expressions."
268 )
269 prsep()
270
271 result = await searchlang.search_conversation_with_language(
272 context.query_context.conversation,
273 context.query_translator,
274 query_text,
275 context.lang_search_options,
276 debug_context=debug_context,
277 )
278 if isinstance(result, typechat.Failure):
279 print("Stages 1-3 failed:")
280 print(Fore.RED + str(result) + Fore.RESET)
281 return
282 search_results = result.value
283
284 actual1 = debug_context.search_query
285 if actual1:
286 if context.debug1 == "full":
287 print("Stage 1 results:")
288 utils.pretty_print(actual1, Fore.GREEN, Fore.RESET)
289 prsep()
290 elif context.debug1 == "diff":
291 if record and "searchQueryExpr" in record:
292 print("Stage 1 diff:")
293 expected1 = serialization.deserialize_object(
294 search_query_schema.SearchQuery, record["searchQueryExpr"]
295 )
296 compare_and_print_diff(expected1, actual1)
297 else:
298 print("Stage 1 diff unavailable")
299 prsep()
300
301 actual2 = debug_context.search_query_expr
302 if actual2:
303 if context.debug2 == "full":
304 print("Stage 2 results:")
305 utils.pretty_print(actual2, Fore.GREEN, Fore.RESET)
306 prsep()
307 elif context.debug2 == "diff":
308 if record and "compiledQueryExpr" in record:
309 print("Stage 2 diff:")
310 expected2 = serialization.deserialize_object(
311 list[search.SearchQueryExpr], record["compiledQueryExpr"]
312 )
313 compare_and_print_diff(expected2, actual2)
314 else:
315 print("Stage 2 diff unavailable")
316 prsep()
317
318 actual3 = search_results
319 if context.debug3 == "full":
320 print("Stage 3 full results:")
321 utils.pretty_print(actual3, Fore.GREEN, Fore.RESET)
322 prsep()
323 elif context.debug3 == "nice":
324 print("Stage 3 nice results:")
325 for sr in search_results:
326 print_result(sr, context.query_context.conversation)
327 prsep()
328 elif context.debug3 == "diff":
329 if record and "results" in record:
330 print("Stage 3 diff:")
331 expected3: list[RawSearchResultData] = record["results"]
332 compare_results(expected3, actual3)
333 else:
334 print("Stage 3 diff unavailable")
335 prsep()
336
337 all_answers, combined_answer = await answers.generate_answers(
338 context.answer_translator,
339 search_results,
340 context.query_context.conversation,
341 query_text,
342 options=context.answer_context_options,
343 )
344
345 if context.debug4 == "full":
346 utils.pretty_print(all_answers)
347 prsep()
348 if context.debug4 in ("full", "nice"):
349 if combined_answer.type == "NoAnswer":
350 print(Fore.RED + f"Failure: {combined_answer.whyNoAnswer}" + Fore.RESET)
351 else:
352 print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET)
353 prsep()
354 elif context.debug4 == "diff":
355 if query_text in context.ar_index:
356 record = context.ar_index[query_text]
357 expected4: tuple[str, bool] = (record["answer"], not record["hasNoAnswer"])
358 print("Stage 4 diff:")
359 match combined_answer.type:
360 case "NoAnswer":
361 actual4 = (combined_answer.whyNoAnswer or "", False)
362 case "Answered":
363 actual4 = (combined_answer.answer or "", True)
364 score = await compare_answers(context, expected4, actual4)
365 print(f"Score: {score:.3f}; Question: {query_text}")
366 return score
367 else:
368 print("Stage 4 diff unavailable; nice answer:")
369 if combined_answer.type == "NoAnswer":
370 print(Fore.RED + f"Failure: {combined_answer.whyNoAnswer}" + Fore.RESET)
371 else:
372 print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET)
373 prsep()
374
375
def prsep():
    """Print a separator line of 50 dashes."""
    separator = 50 * "-"
    print(separator)
378
379
### CLI processing ###
381
382
def make_arg_parser(description: str) -> argparse.ArgumentParser:
    """Build the command-line parser for this tool.

    There are three kinds of options:
    - input files: --podcast, --qafile, --srfile;
    - batch mode: --batch, --offset, --limit, --start;
    - debugging: --debug, --debug1..--debug4, --alt-schema, --show-schema.

    The choices for --debug1..--debug4 come from the Literal annotations of
    ProcessingContext.debug1..debug4. Help text is wrapped to the terminal
    width, limited via utils.cap(144, ...).
    """
    width = utils.cap(144, shutil.get_terminal_size().columns)
    help_position = 35 if width >= 100 else 28

    def make_formatter(*args, **kwargs) -> argparse.HelpFormatter:
        return argparse.HelpFormatter(
            *args, **kwargs, max_help_position=help_position, width=width
        )

    parser = argparse.ArgumentParser(
        description=description, formatter_class=make_formatter
    )

    file_options = [
        (
            "--podcast",
            "testdata/Episode_53_AdrianTchaikovsky_index",
            "Path to the podcast index files (excluding the '_index.json' suffix)",
        ),
        (
            "--qafile",
            "testdata/Episode_53_Answer_results.json",
            "Path to the Answer_results.json file "
            "(a list of questions and answers to test the full pipeline)",
        ),
        (
            "--srfile",
            "testdata/Episode_53_Search_results.json",
            "Path to the Search_results.json file "
            "(a list of intermediate results from stages 1, 2 and 3)",
        ),
    ]
    for flag, default_path, help_text in file_options:
        parser.add_argument(flag, type=str, default=default_path, help=help_text)

    batch_group = parser.add_argument_group("Batch mode options")
    batch_group.add_argument(
        "--batch",
        action="store_true",
        help="Run in batch mode, suppressing interactive prompts.",
    )
    int_options = [
        ("--offset", "Number of initial Q/A pairs to skip (default none)"),
        ("--limit", "Number of Q/A pairs to process (default all)"),
        ("--start", "Do just this question (similar to --offset START-1 --limit 1)"),
    ]
    for flag, help_text in int_options:
        batch_group.add_argument(flag, type=int, default=0, help=help_text)

    debug_group = parser.add_argument_group("Debug options")
    debug_group.add_argument(
        "--debug",
        type=str,
        default=None,
        choices=["none", "diff", "full"],
        help="Default debug level: 'none' for no debug output, 'diff' for diff output, "
        "'full' for full debug output.",
    )
    stage_help = {
        "debug1": "Debug level override for stage 1: like --debug; or 'skip' to skip stage 1.",
        "debug2": "Debug level override for stage 2: like --debug; or 'skip' to skip stages 1-2.",
        "debug3": "Debug level override for stage 3: like --debug; or 'nice' to print answer only.",
        "debug4": "Debug level override for stage 4: like --debug; or 'nice' to print answer only.",
    }
    for field_name, help_text in stage_help.items():
        debug_group.add_argument(
            f"--{field_name}",
            type=str,
            default=None,
            choices=typing.get_args(ProcessingContext.__annotations__[field_name]),
            help=help_text,
        )
    debug_group.add_argument(
        "--alt-schema",
        type=str,
        default=None,
        help="Path to alternate schema file for query translator (modifies stage 1).",
    )
    debug_group.add_argument(
        "--show-schema",
        action="store_true",
        help="Show the TypeScript schema computed by typechat.",
    )

    return parser
492
493
def fill_in_debug_defaults(
    parser: argparse.ArgumentParser, args: argparse.Namespace
) -> None:
    """Validate batch-mode options and resolve the per-stage debug levels.

    Mutates `args` in place:
    - In batch mode, --start N becomes --offset N-1, with --limit defaulting
      to 1. The default debug level is "diff" for every stage.
    - In interactive mode, the defaults are "none" for stages 1-3 and "nice"
      for stage 4.
    - An explicit --debug replaces those defaults, and an explicit --debugN
      overrides --debug for stage N.
    - --debug2 skip forces --debug1 skip.

    Calls parser.exit(2, ...) in three cases: --start/--offset/--limit used
    without --batch; --start combined with --offset; or a negative
    --start/--offset/--limit.
    """
    if not args.batch:
        if args.start or args.offset or args.limit:
            parser.exit(2, "Error: --start, --offset and --limit require --batch\n")
    else:
        # Negative values would silently produce odd slices in batch_loop
        # (e.g. --start -1 would become --offset -2).
        if args.start < 0 or args.offset < 0 or args.limit < 0:
            parser.exit(
                2, "Error: --start, --offset and --limit must not be negative\n"
            )
        if args.start:
            if args.offset != 0:
                parser.exit(2, "Error: --start and --offset can't be both set\n")
            args.offset = args.start - 1
            if args.limit == 0:
                args.limit = 1
        args.debug = args.debug or "diff"

    args.debug1 = args.debug1 or args.debug or "none"
    args.debug2 = args.debug2 or args.debug or "none"
    args.debug3 = args.debug3 or args.debug or "none"
    args.debug4 = args.debug4 or args.debug or "nice"
    if args.debug2 == "skip":
        args.debug1 = "skip"  # Skipping stage 2 implies skipping stage 1.
517
518
### Data loading ###
520
521
def load_podcast_index(podcast_file: str) -> query.QueryEvalContext:
    """Load a podcast index and wrap it in a query evaluation context.

    Args:
        podcast_file: Path to the podcast index files, as given by --podcast.

    Raises:
        RuntimeError: if Podcast.read_from_file returns None.
    """
    settings = importing.ConversationSettings()
    with utils.timelog(f"load podcast from {podcast_file!r}"):
        conversation = podcast.Podcast.read_from_file(podcast_file, settings)
    # An explicit check rather than `assert`, so the failure is still caught
    # when Python runs with -O (which strips asserts).
    if conversation is None:
        raise RuntimeError(f"Failed to load podcast from {podcast_file!r}")
    return query.QueryEvalContext(conversation)
528
529
530def load_index_file[T: Mapping[str, typing.Any]](
531 file: str, selector: str, cls: type[T]
532) -> tuple[list[T], dict[str, T]]:
533 # If this crashes, the file is malformed -- go figure it out.
534 try:
535 with open(file) as f:
536 lst: list[T] = json.load(f)
537 except FileNotFoundError as err:
538 print(Fore.RED + str(err) + Fore.RESET)
539 lst = []
540 index = {item[selector]: item for item in lst}
541 if len(index) != len(lst):
542 print(f"{len(lst) - len(index)} duplicate items found in {file!r}. ")
543 return lst, index
544
545
### Debug output ###
547
548
549def print_result[TMessage: IMessage, TIndex: ITermToSemanticRefIndex](
550 result: search.ConversationSearchResult,
551 conversation: IConversation[TMessage, TIndex],
552) -> None:
553 print(
554 f"Raw query: {result.raw_query_text};",
555 f"{len(result.message_matches)} message matches,",
556 f"{len(result.knowledge_matches)} knowledge matches",
557 )
558 if result.message_matches:
559 print("Message matches:")
560 for scored_ord in sorted(
561 result.message_matches, key=lambda x: x.score, reverse=True
562 ):
563 score = scored_ord.score
564 msg_ord = scored_ord.message_ordinal
565 msg = conversation.messages[msg_ord]
566 assert msg.metadata is not None # For type checkers
567 text = " ".join(msg.text_chunks).strip()
568 print(
569 f"({score:5.1f}) M={msg_ord:d}: "
570 f"{msg.metadata.source!s:>15.15s}: "
571 f"{repr(text)[1:-1]:<150.150s} "
572 )
573 if result.knowledge_matches:
574 print(f"Knowledge matches ({', '.join(sorted(result.knowledge_matches))}):")
575 for key, value in sorted(result.knowledge_matches.items()):
576 print(f"Type {key} -- {value.term_matches}:")
577 for scored_sem_ref_ord in value.semantic_ref_matches:
578 score = scored_sem_ref_ord.score
579 sem_ref_ord = scored_sem_ref_ord.semantic_ref_ordinal
580 if conversation.semantic_refs is None:
581 print(f" Ord: {sem_ref_ord} (score {score})")
582 else:
583 sem_ref = conversation.semantic_refs[sem_ref_ord]
584 msg_ord = sem_ref.range.start.message_ordinal
585 chunk_ord = sem_ref.range.start.chunk_ordinal
586 msg = conversation.messages[msg_ord]
587 print(
588 f"({score:5.1f}) M={msg_ord}: "
589 f"S={summarize_knowledge(sem_ref)}"
590 )
591
592
593def summarize_knowledge(sem_ref: SemanticRef) -> str:
594 """Summarize the knowledge in a SemanticRef."""
595 knowledge = sem_ref.knowledge
596 if knowledge is None:
597 return f"{sem_ref.semantic_ref_ordinal}: <No knowledge>"
598 match sem_ref.knowledge_type:
599 case "entity":
600 entity = knowledge
601 assert isinstance(entity, kplib.ConcreteEntity)
602 res = [f"{entity.name} [{', '.join(entity.type)}]"]
603 if entity.facets:
604 for facet in entity.facets:
605 value = facet.value
606 if isinstance(value, kplib.Quantity):
607 value = f"{value.amount} {value.units}"
608 elif isinstance(value, float) and value.is_integer():
609 value = int(value)
610 res.append(f"<{facet.name}:{value}>")
611 return f"{sem_ref.semantic_ref_ordinal}: {' '.join(res)}"
612 case "action":
613 action = knowledge
614 assert isinstance(action, kplib.Action)
615 res = []
616 res.append("/".join(repr(verb) for verb in action.verbs))
617 if action.verb_tense:
618 res.append(f"[{action.verb_tense}]")
619 if action.subject_entity_name != "none":
620 res.append(f"subj={action.subject_entity_name!r}")
621 if action.object_entity_name != "none":
622 res.append(f"obj={action.object_entity_name!r}")
623 if action.indirect_object_entity_name != "none":
624 res.append(f"ind_obj={action.indirect_object_entity_name}")
625 if action.params:
626 for param in action.params:
627 if isinstance(param, kplib.ActionParam):
628 res.append(f"<{param.name}:{param.value}>")
629 else:
630 res.append(f"<{param}>")
631 if action.subject_entity_facet is not None:
632 res.append(f"subj_facet={action.subject_entity_facet}")
633 return f"{sem_ref.semantic_ref_ordinal}: {' '.join(res)}"
634 case "topic":
635 topic = knowledge
636 assert isinstance(topic, Topic)
637 return f"{sem_ref.semantic_ref_ordinal}: {topic.text!r}]"
638 case "tag":
639 tag = knowledge
640 assert isinstance(tag, str)
641 return f"{sem_ref.semantic_ref_ordinal}: #{tag!r}"
642 case _:
643 return f"{sem_ref.semantic_ref_ordinal}: {sem_ref.knowledge!r}"
644
645
def compare_results(
    matches_records: list[RawSearchResultData],
    results: list[search.ConversationSearchResult],
) -> bool:
    """Compare actual stage-3 search results with the expected records.

    A length mismatch is reported and returns False immediately. Otherwise
    the two lists are compared pairwise. For each pair, the message ordinals
    and the entity, action and topic SemanticRef ordinals are all checked,
    and each check prints its own mismatch. Returns True only if every check
    passed.
    """
    if len(results) != len(matches_records):
        print(f"(Result sizes mismatch, {len(results)} != {len(matches_records)})")
        return False

    knowledge_kinds = (
        ("entity", "entityMatches"),
        ("action", "actionMatches"),
        ("topic", "topicMatches"),
    )
    everything_matches = True
    for result, record in zip(results, matches_records):
        if not compare_message_ordinals(
            result.message_matches, record["messageMatches"]
        ):
            everything_matches = False
        for kind, record_key in knowledge_kinds:
            if kind in result.knowledge_matches:
                actual = result.knowledge_matches[kind].semantic_ref_matches
            else:
                actual = []
            if not compare_semantic_ref_ordinals(
                actual, record.get(record_key, []), kind
            ):
                everything_matches = False
    return everything_matches
690
691
# Special case: in the podcast, these messages are all Kevin saying "Yeah".
# compare_message_ordinals therefore still reports a match when the expected
# and actual message ordinals differ only in these.
NOISE_MESSAGES = frozenset((42, 46, 52, 68, 70))
695
696
def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Check actual message matches `aa` against expected ordinals `b`.

    Order and duplicates are ignored, and ordinals in NOISE_MESSAGES may
    appear on only one side. On a real mismatch, prints both lists and
    returns False.
    """
    actual = [scored.message_ordinal for scored in aa]
    unmatched = set(actual).symmetric_difference(b)
    if not unmatched.issubset(NOISE_MESSAGES):
        print("Message ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
704
705
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Check actual SemanticRef matches `aa` against expected ordinals `b`.

    The comparison ignores order but not multiplicity. On mismatch, prints a
    header naming `label` (capitalized) plus both lists, and returns False.
    """
    actual = [scored.semantic_ref_ordinal for scored in aa]
    if sorted(actual) != sorted(b):
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
        return False
    return True
715
716
# Decimal numbers inside a repr, e.g. "0.12345" or "12.5". Digits are
# matched with \d+ on both sides of the point, so scores >= 10 and numbers
# with a single decimal place are normalized too.
_DECIMAL_NUMBER_RE = re.compile(r"\b\d+\.\d+")


def _round_decimals(text: str) -> str:
    """Rewrite every decimal number in `text` with exactly 3 decimal places."""
    return _DECIMAL_NUMBER_RE.sub(lambda m: f"{float(m.group()):.3f}", text)


def compare_and_print_diff(a: object, b: object) -> bool:  # True if equal
    """Diff two objects whose repr() is a valid Python expression.

    Returns True if the objects are equal, or if their reprs are equal once
    every decimal number is rounded to 3 places. Otherwise prints a unified
    diff of the formatted reprs (a = expected, b = actual) and returns False.
    """
    if a == b:
        return True
    a_repr = repr(a)
    b_repr = repr(b)
    if a_repr == b_repr:
        return True
    # Round numbers so tiny differences in scores etc. aren't reported as
    # mismatches.
    a_repr = _round_decimals(a_repr)
    b_repr = _round_decimals(b_repr)
    if a_repr == b_repr:
        return True
    a_formatted = utils.format_code(a_repr)
    b_formatted = utils.format_code(b_repr)
    print_diff(a_formatted, b_formatted, n=2)
    return False
734
735
async def compare_answers(
    context: ProcessingContext, expected: tuple[str, bool], actual: tuple[str, bool]
) -> float:
    """Score an actual stage-4 answer against the expected one.

    Each answer is a (text, success) pair; success is False for a "no answer"
    outcome. The score is:
    - 0.0 if exactly one side succeeded;
    - 1.001 if both failed (the texts are not compared);
    - 1.0 if both texts are identical;
    - otherwise, after printing a diff, equality_score() of the two texts.
    """
    expected_text, expected_success = expected
    actual_text, actual_success = actual

    if expected_success != actual_success:
        print(f"Expected success: {expected_success}; actual: {actual_success}")
        return 0.000
    if not actual_success:
        print(Fore.GREEN + "Both failed" + Fore.RESET)
        return 1.001
    if expected_text == actual_text:
        print(Fore.GREEN + "Both equal" + Fore.RESET)
        return 1.000

    # When both answers are short (<= 100 lines), the 100 context lines show
    # them in full; otherwise keep only 2 lines of context around changes.
    longest = max(len(expected_text.splitlines()), len(actual_text.splitlines()))
    context_lines = 100 if longest <= 100 else 2
    print_diff(expected_text, actual_text, n=context_lines)
    return await equality_score(context, expected_text, actual_text)
760
761
762def print_diff(a: str, b: str, n: int) -> None:
763 diff = difflib.unified_diff(
764 a.splitlines(),
765 b.splitlines(),
766 fromfile="expected",
767 tofile="actual",
768 n=n,
769 )
770 for x in diff:
771 if x.startswith("-"):
772 print(Fore.RED + x.rstrip("\n") + Fore.RESET)
773 elif x.startswith("+"):
774 print(Fore.GREEN + x.rstrip("\n") + Fore.RESET)
775 else:
776 print(x.rstrip("\n"))
777
778
async def equality_score(context: ProcessingContext, a: str, b: str) -> float:
    """Score how similar two answer texts are.

    Returns 1.0 for identical texts and 0.999 for texts equal apart from
    case. Otherwise returns the dot product of the two texts' embeddings from
    context.embedding_model, as a plain Python float.
    """
    if a == b:
        return 1.0
    if a.lower() == b.lower():
        return 0.999
    # Named `vectors` so the module-level `embeddings` import isn't shadowed.
    vectors = await context.embedding_model.get_embeddings([a, b])
    assert vectors.shape[0] == 2, "Expected two embeddings"
    # NOTE(review): the dot product equals cosine similarity only if the
    # embeddings are unit-normalized; presumably the model guarantees that,
    # but confirm.
    return float(np.dot(vectors[0], vectors[1]))
787
788
### Run main ###
790
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print()  # End the interrupted output line cleanly.
        sys.exit(1)
    except BrokenPipeError:
        # The reader of our stdout went away (e.g. piped into `head`).
        # Printing now would just raise BrokenPipeError again, so point stdout
        # at devnull; that also keeps the interpreter's final flush quiet.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)
797