microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
8ba3ebd84dd1bb6343ebae028996313cabd70764

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/tools/utool.py

957lines · modecode

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4__version__ = "0.2"
5
6### Imports ###
7
8import argparse
9import asyncio
10from collections.abc import Mapping
11from dataclasses import dataclass
12import difflib
13import json
14import re
15import shutil
16import sys
17import typing
18
19from colorama import init as colorama_init, Fore
20import numpy as np
21
22try:
23 import readline
24except ImportError:
25 readline = None
26import typechat
27
28from typeagent.aitools import embeddings
29from typeagent.aitools import utils
30
31from typeagent.knowpro import answers, answer_response_schema
32from typeagent.knowpro import convknowledge
33from typeagent.knowpro.convsettings import ConversationSettings
34from typeagent.knowpro.interfaces import (
35 IConversation,
36 IMessage,
37 ITermToSemanticRefIndex,
38 ScoredMessageOrdinal,
39 ScoredSemanticRefOrdinal,
40 SemanticRef,
41 Tag,
42 Topic,
43)
44from typeagent.knowpro import kplib
45from typeagent.knowpro import query
46from typeagent.knowpro import search, search_query_schema, searchlang
47from typeagent.knowpro import serialization
48
49from typeagent.podcasts import podcast
50
51from typeagent.storage.memory.propindex import build_property_index
52from typeagent.storage.memory.reltermsindex import build_related_terms_index
53from typeagent.storage.sqlite.provider import SqliteStorageProvider
54from typeagent.storage.utils import create_storage_provider
55
56
57### Classes ###
58
59
class QuestionAnswerData(typing.TypedDict):
    """One record of the question/answer reference file.

    main() loads these with load_index_file() and indexes them by "question".
    """

    # The question text; also the key used to index the records.
    question: str
    # The expected answer text (read by process_query() in stage-4 diff mode).
    answer: str
    # process_query() treats the expected outcome as a success when this is False.
    hasNoAnswer: bool
    # NOTE(review): not read anywhere in this file -- presumably the command
    # that produced the record; confirm against the data files.
    cmd: str
65
66
class RawSearchResultData(typing.TypedDict):
    """Expected stage-3 matches for one search result, as plain ordinals.

    compare_results() checks these lists against an actual search result.
    """

    # Ordinals of the messages expected to match.
    messageMatches: list[int]
    # Semantic-ref ordinals expected under the "entity" knowledge type.
    entityMatches: list[int]
    # Semantic-ref ordinals expected under the "topic" knowledge type.
    topicMatches: list[int]
    # Semantic-ref ordinals expected under the "action" knowledge type.
    actionMatches: list[int]
72
73
class SearchResultData(typing.TypedDict):
    """One record of the search-results reference file.

    main() loads these with load_index_file() and indexes them by "searchText".
    """

    # The query text; also the key used to index the records.
    searchText: str
    # Expected stage-1 outcome.
    searchQueryExpr: dict[str, typing.Any]  # Serialized search_query_schema.SearchQuery
    # Expected stage-2 outcome.
    compiledQueryExpr: list[dict[str, typing.Any]]  # list[search.SearchQueryExpr]
    # Expected stage-3 outcome, one entry per search result.
    results: list[RawSearchResultData]
79
80
@dataclass
class ProcessingContext:
    """State shared by every query processed in one run of the tool.

    Built once in main() and passed to process_query() and the loops.
    """

    query_context: query.QueryEvalContext
    # Reference question/answer records and search-result records, each as a
    # list plus an index keyed by the question / search text.
    ar_list: list[QuestionAnswerData]
    sr_list: list[SearchResultData]
    ar_index: dict[str, QuestionAnswerData]
    sr_index: dict[str, SearchResultData]
    # Debug level per pipeline stage. make_arg_parser() reads these Literal
    # annotations to build the choices for --debug1 .. --debug4.
    debug1: typing.Literal["none", "diff", "full", "skip"]
    debug2: typing.Literal["none", "diff", "full", "skip"]
    debug3: typing.Literal["none", "diff", "full", "nice"]
    debug4: typing.Literal["none", "diff", "full", "nice"]
    embedding_model: embeddings.AsyncEmbeddingModel
    query_translator: typechat.TypeChatJsonTranslator[search_query_schema.SearchQuery]
    answer_translator: typechat.TypeChatJsonTranslator[
        answer_response_schema.AnswerResponse
    ]
    lang_search_options: searchlang.LanguageSearchOptions
    answer_context_options: answers.AnswerContextOptions

    def __repr__(self) -> str:
        """Return a compact summary: collection sizes instead of contents."""
        sized_fields = ("ar_list", "sr_list", "ar_index", "sr_index")
        plain_fields = (
            "debug1",
            "debug2",
            "debug3",
            "debug4",
            "lang_search_options",
            "answer_context_options",
        )
        shown = [f"{name}={len(getattr(self, name))}" for name in sized_fields]
        shown.extend(f"{name}={getattr(self, name)}" for name in plain_fields)
        return "Context(" + ", ".join(shown) + ")"
113
114
115### Main logic ###
116
117
async def main():
    """Run the query tool from the command line.

    Steps, in order:
      1. Load the environment and initialize colorama.
      2. Parse and validate the command line (make_arg_parser(),
         fill_in_debug_defaults()).
      3. Create the storage provider and load the podcast conversation.
      4. Load the question/answer and search-result reference files.
      5. Build the query and answer translators, optionally replacing the
         query schema with the contents of --alt-schema.
      6. Process a single question, run the batch loop, or run the
         interactive loop, depending on the arguments.
    """
    utils.load_dotenv()
    colorama_init(autoreset=True)

    parser = make_arg_parser("TypeAgent Query Tool")
    args = parser.parse_args()
    fill_in_debug_defaults(parser, args)

    if args.logfire:
        utils.setup_logfire()

    settings = ConversationSettings()  # Has no storage provider yet
    settings.storage_provider = await create_storage_provider(
        settings.message_text_index_settings,
        settings.related_term_index_settings,
        args.database,
        podcast.PodcastMessage,
    )
    query_context = await load_podcast_index(
        args.podcast, settings, args.database, args.verbose
    )

    ar_list, ar_index = load_index_file(
        args.qafile, "question", QuestionAnswerData, args.verbose
    )
    sr_list, sr_index = load_index_file(
        args.srfile, "searchText", SearchResultData, args.verbose
    )

    model = convknowledge.create_typechat_model()
    query_translator = utils.create_translator(model, search_query_schema.SearchQuery)
    if args.alt_schema:
        if args.verbose:
            print(f"Substituting alt schema from {args.alt_schema}")
        # Read the schema as UTF-8 explicitly so the result does not depend on
        # the platform's locale encoding.
        with open(args.alt_schema, encoding="utf-8") as f:
            query_translator.schema_str = f.read()
    if args.show_schema:
        print(Fore.YELLOW + query_translator.schema_str.rstrip() + Fore.RESET)

    answer_translator = utils.create_translator(
        model, answer_response_schema.AnswerResponse
    )

    context = ProcessingContext(
        query_context,
        ar_list,
        sr_list,
        ar_index,
        sr_index,
        args.debug1,
        args.debug2,
        args.debug3,
        args.debug4,
        settings.embedding_model,
        query_translator,
        answer_translator,
        searchlang.LanguageSearchOptions(
            compile_options=searchlang.LanguageQueryCompileOptions(
                exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
            ),
            exact_match=False,
            max_message_matches=25,
        ),
        answers.AnswerContextOptions(
            entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
        ),
    )

    if args.verbose:
        utils.pretty_print(context, Fore.BLUE, Fore.RESET)

    # Dispatch: a single question wins over batch mode; interactive is the default.
    if args.question is not None:
        if args.verbose:
            print(
                Fore.YELLOW
                + f"Processing single question: {args.question}"
                + Fore.RESET
            )
        await process_query(context, args.question)
    elif args.batch:
        if args.verbose:
            print(
                Fore.YELLOW
                + f"Running in batch mode [{args.offset}:{args.offset + args.limit if args.limit else ''}]."
                + Fore.RESET
            )
        await batch_loop(context, args.offset, args.limit, args.skip_counters)
    else:
        if args.verbose:
            print(Fore.YELLOW + "Running in interactive mode." + Fore.RESET)
        await interactive_loop(context)
209
210
async def print_conversation_stats(c: IConversation, verbose: bool = True) -> None:
    """Print the sizes of a conversation's collections and secondary indexes.

    Prints nothing when *verbose* is false. For every secondary index that is
    missing, an upper-case "NO ..." line is printed instead of a size.
    """
    if not verbose:
        return

    print(f"{await c.messages.size()} messages loaded.")
    print(f"{await c.semantic_refs.size()} semantic refs loaded.")
    print(f"{await c.semantic_ref_index.size()} sem_ref index entries.")

    secondary = c.secondary_indexes
    if secondary is None:
        print("NO SECONDARY INDEXES")
        return

    property_index = secondary.property_to_semantic_ref_index
    if property_index is None:
        print("NO PROPERTY TO SEMANTIC REF INDEX")
    else:
        entry_count = await property_index.size()
        print(f"{entry_count} property to semantic ref index entries.")

    timestamp_index = secondary.timestamp_index
    if timestamp_index is None:
        print("NO TIMESTAMP INDEX")
    else:
        print(f"{await timestamp_index.size()} timestamp index entries.")

    related_terms = secondary.term_to_related_terms_index
    if related_terms is None:
        print("NO TERM TO RELATED TERMS INDEX")
    else:
        print(f"{await related_terms.aliases.size()} alias entries.")
        fuzzy = related_terms.fuzzy_index
        if fuzzy is None:
            print("NO FUZZY RELATED TERMS INDEX")
        else:
            print(f"{await fuzzy.size()} term entries.")

    threads = secondary.threads
    if threads is None:
        print("NO THREADS INDEX")
    else:
        print(f"{len(threads.threads)} threads index entries.")

    message_index = secondary.message_index
    if message_index is None:
        print("NO MESSAGE INDEX")
    else:
        print(f"{await message_index.size()} message index entries.")
266
267
async def batch_loop(
    context: ProcessingContext, offset: int, limit: int, skip_counters: str
) -> None:
    """Process a slice of the reference questions and print a score summary.

    Args:
        context: Shared processing state; questions come from context.ar_list.
        offset: Index of the first question to process.
        limit: Number of questions to process; 0 means all from *offset* on.
        skip_counters: Comma-separated 1-based question numbers to skip.
            Entries that are not plain digits are ignored.

    Questions for which process_query() returns no score are left out of the
    summary. Scores of at least 0.97 are listed as "Good", the rest as "Bad".
    """
    skipped: list[int] = []
    if skip_counters:
        skipped = [
            int(token) for token in skip_counters.split(",") if token.strip().isdigit()
        ]

    count = limit if limit != 0 else len(context.ar_list) - offset
    selected = context.ar_list[offset : offset + count]

    scored = []
    # Counters are 1-based positions in the full question list.
    for counter, qadata in enumerate(selected, offset + 1):
        if counter in skipped:
            continue
        question = qadata["question"]
        print("-" * 20, counter, question, "-" * 20)
        score = await process_query(context, question)
        if score is not None:
            scored.append((score, counter))

    if not scored:
        return

    print("=" * 50)
    scored.sort(reverse=True)
    summary = [
        ("Good", [pair for pair in scored if pair[0] >= 0.97]),
        ("Bad", [pair for pair in scored if pair[0] < 0.97]),
    ]
    for label, pairs in summary:
        print(f"{label} scores ({len(pairs)}):")
        # Ten "score(counter)" items per output line.
        for start in range(0, len(pairs), 10):
            row = pairs[start : start + 10]
            print(", ".join(f"{score:.3f}({counter})" for score, counter in row))
300
301
async def interactive_loop(context: ProcessingContext) -> None:
    """Read queries from stdin and process them one at a time.

    When stdin is not a terminal, every non-blank line is processed and the
    function returns at end of input. Otherwise a "TypeAgent> " prompt is
    shown until EOF or one of "exit", "quit", "q" is entered; when readline
    is available, history is loaded from and saved to ".ui_history".
    """
    if not sys.stdin.isatty():
        for raw_line in sys.stdin:
            piped_query = raw_line.strip()
            if piped_query:
                await process_query(context, piped_query)
        return

    history_file = ".ui_history"
    print(f"TypeAgent demo UI {__version__} (type 'q' to exit)")
    if readline:
        try:
            readline.read_history_file(history_file)
        except FileNotFoundError:
            pass  # Ignore if history file does not exist.

    try:
        while True:
            try:
                line = input("TypeAgent> ").strip()
            except EOFError:
                print()
                break
            if not line:
                continue
            if line.lower() in ("exit", "quit", "q"):
                if readline:
                    # Drop the most recent history entry (the exit command).
                    last_entry = readline.get_current_history_length() - 1
                    readline.remove_history_item(last_entry)
                break
            prsep()
            await process_query(context, line)
    finally:
        # Save history however the loop ended.
        if readline:
            readline.write_history_file(history_file)
339
340
341### Query processing logic ###
342
343
# Run one query through the pipeline and print per-stage debug output.
#
# Stages 1-3 are performed by searchlang.search_conversation_with_language()
# (stage 1 yields the search query, stage 2 the compiled query expressions,
# stage 3 the search results); stage 4 is answers.generate_answers().
# What is printed for each stage is selected by context.debug1 .. debug4.
#
# Returns None for a blank query or when stages 1-3 fail. In stage-4 "diff"
# mode, when query_text has an entry in context.ar_index, the score computed
# by compare_answers() is returned (see the note below about its nesting).
#
# NOTE(review): this copy of the file has lost its indentation and has line
# numbers fused onto each line, so the exact nesting of the prsep() calls and
# of `return score` cannot be read off the text. Confirm against the upstream
# file before restructuring this function.
344async def process_query(context: ProcessingContext, query_text: str) -> float | None:
345 if not query_text.strip():
346 return # Ignore blank query (like interactive mode)
 # Precomputed stage 1-3 outcomes for this exact query text, if any.
347 record = context.sr_index.get(query_text)
348 debug_context = searchlang.LanguageSearchDebugContext()
 # "skip" substitutes precomputed outcomes for stage 1 (and stage 2), which
 # requires a record holding both precomputed expressions.
349 if context.debug1 == "skip" or context.debug2 == "skip":
350 if not record or (
351 "searchQueryExpr" not in record or "compiledQueryExpr" not in record
352 ):
353 print("Can't skip stages 1 or 2, no precomputed outcomes found.")
354 else:
355 # Skipping stage 2 implies skipping stage 1, and we must supply the
356 # precomputed results for both stages.
357 debug_context.use_search_query = serialization.deserialize_object(
358 search_query_schema.SearchQuery, record["searchQueryExpr"]
359 )
360 print("Skipping stage 1, substituting precomputed search query.")
361 if context.debug2 == "skip":
362 debug_context.use_compiled_search_query_exprs = (
363 serialization.deserialize_object(
364 list[search.SearchQueryExpr],
365 record["compiledQueryExpr"],
366 )
367 )
368 print(
369 "Skipping stage 2, substituting precomputed compiled query expressions."
370 )
371 prsep()
372
 # Stages 1-3. The debug context passed in is read back below for the
 # stage 1 and stage 2 outcomes.
373 result = await searchlang.search_conversation_with_language(
374 context.query_context.conversation,
375 context.query_translator,
376 query_text,
377 context.lang_search_options,
378 debug_context=debug_context,
379 )
380 if isinstance(result, typechat.Failure):
381 print("Stages 1-3 failed:")
382 print(Fore.RED + str(result) + Fore.RESET)
383 return
384 search_results = result.value
385
 # Stage 1 report: full dump, or diff against the record's "searchQueryExpr".
386 actual1 = debug_context.search_query
387 if actual1:
388 if context.debug1 == "full":
389 print("Stage 1 results:")
390 utils.pretty_print(actual1, Fore.GREEN, Fore.RESET)
391 prsep()
392 elif context.debug1 == "diff":
393 if record and "searchQueryExpr" in record:
394 print("Stage 1 diff:")
395 expected1 = serialization.deserialize_object(
396 search_query_schema.SearchQuery, record["searchQueryExpr"]
397 )
398 compare_and_print_diff(expected1, actual1)
399 else:
400 print("Stage 1 diff unavailable")
401 prsep()
402
 # Stage 2 report: full dump, or diff against the record's "compiledQueryExpr".
403 actual2 = debug_context.search_query_expr
404 if actual2:
405 if context.debug2 == "full":
406 print("Stage 2 results:")
407 utils.pretty_print(actual2, Fore.GREEN, Fore.RESET)
408 prsep()
409 elif context.debug2 == "diff":
410 if record and "compiledQueryExpr" in record:
411 print("Stage 2 diff:")
412 expected2 = serialization.deserialize_object(
413 list[search.SearchQueryExpr], record["compiledQueryExpr"]
414 )
415 compare_and_print_diff(expected2, actual2)
416 else:
417 print("Stage 2 diff unavailable")
418 prsep()
419
 # Stage 3 report: full dump, print_result() per result ("nice"), or
 # comparison against the record's "results".
420 actual3 = search_results
421 if context.debug3 == "full":
422 print("Stage 3 full results:")
423 utils.pretty_print(actual3, Fore.GREEN, Fore.RESET)
424 prsep()
425 elif context.debug3 == "nice":
426 print("Stage 3 nice results:")
427 for sr in search_results:
428 await print_result(sr, context.query_context.conversation)
429 prsep()
430 elif context.debug3 == "diff":
431 if record and "results" in record:
432 print("Stage 3 diff:")
433 expected3: list[RawSearchResultData] = record["results"]
434 compare_results(expected3, actual3)
435 else:
436 print("Stage 3 diff unavailable")
437 prsep()
438
 # Stage 4: generate answers from the search results.
439 all_answers, combined_answer = await answers.generate_answers(
440 context.answer_translator,
441 search_results,
442 context.query_context.conversation,
443 query_text,
444 options=context.answer_context_options,
445 )
446
447 if context.debug4 == "full":
448 utils.pretty_print(all_answers)
449 prsep()
 # "full" and "nice" both print the combined answer, or the reason there is none.
450 if context.debug4 in ("full", "nice"):
451 if combined_answer.type == "NoAnswer":
452 print(Fore.RED + f"Failure: {combined_answer.whyNoAnswer}" + Fore.RESET)
453 else:
454 print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET)
455 prsep()
456 elif context.debug4 == "diff":
457 if query_text in context.ar_index:
 # From here on `record` is the question/answer record, no longer the
 # search-result record looked up at the top.
458 record = context.ar_index[query_text]
459 expected4: tuple[str, bool] = (record["answer"], not record["hasNoAnswer"])
460 print("Stage 4 diff:")
 # Build the (text, success) pair for the actual answer. No other value of
 # combined_answer.type is handled here.
461 match combined_answer.type:
462 case "NoAnswer":
463 actual4 = (combined_answer.whyNoAnswer or "", False)
464 case "Answered":
465 actual4 = (combined_answer.answer or "", True)
466 score = await compare_answers(context, expected4, actual4)
467 if actual4[0].startswith("TypeChat failure:"):
468 print(Fore.YELLOW + "No answer received" + Fore.RESET)
469 else:
470 print(f"Score: {score:.3f}; Question: {query_text}")
471 return score
472 else:
473 print("Stage 4 diff unavailable; nice answer:")
474 if combined_answer.type == "NoAnswer":
475 print(Fore.RED + f"Failure: {combined_answer.whyNoAnswer}" + Fore.RESET)
476 else:
477 print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET)
478 prsep()
479
480
def prsep():
    """Print a separator line of 50 dashes."""
    width = 50
    print("-" * width)
483
484
485### CLI processing ###
486
487
def make_arg_parser(description: str) -> argparse.ArgumentParser:
    """Build the command-line parser for the tool.

    Args:
        description: Text shown at the top of the --help output.

    Returns:
        A parser with the general options plus the "Batch mode options" and
        "Debug options" groups. The choices for --debug1 .. --debug4 are taken
        from the Literal annotations on ProcessingContext.
    """
    # Help width is derived from the terminal width via utils.cap().
    line_width = utils.cap(144, shutil.get_terminal_size().columns)
    help_column = 35 if line_width >= 100 else 28

    def build_formatter(*args, **kwargs) -> argparse.HelpFormatter:
        return argparse.HelpFormatter(
            *args, **kwargs, max_help_position=help_column, width=line_width
        )

    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=build_formatter,
    )

    # --- General options ---
    explain_qa = "a list of questions and answers to test the full pipeline"
    explain_sr = "a list of intermediate results from stages 1, 2 and 3"
    parser.add_argument(
        "--podcast",
        type=str,
        default="testdata/Episode_53_AdrianTchaikovsky_index",
        help="Path to the podcast index files (excluding the '_index.json' suffix)",
    )
    parser.add_argument(
        "--qafile",
        type=str,
        default="testdata/Episode_53_Answer_results.json",
        help=f"Path to the Answer_results.json file ({explain_qa})",
    )
    parser.add_argument(
        "--srfile",
        type=str,
        default="testdata/Episode_53_Search_results.json",
        help=f"Path to the Search_results.json file ({explain_sr})",
    )
    parser.add_argument(
        "--skip-counters",
        type=str,
        default="",
        help="List of comma-separated questions to skip",
    )
    parser.add_argument(
        "-d",
        "--database",
        type=str,
        default=None,
        help="Path to the SQLite database file (default: in-memory)",
    )
    parser.add_argument(
        "--question",
        "--query",
        type=str,
        default=None,
        help="Process a single question and exit (equivalent to echo 'question' | utool.py)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show verbose startup information and timing logs",
    )

    # --- Batch mode options ---
    batch = parser.add_argument_group("Batch mode options")
    batch.add_argument(
        "--batch",
        action="store_true",
        help="Run in batch mode, suppressing interactive prompts.",
    )
    batch.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Number of initial Q/A pairs to skip (default none)",
    )
    batch.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Number of Q/A pairs to process (default all)",
    )
    batch.add_argument(
        "--start",
        type=int,
        default=0,
        help="Do just this question (similar to --offset START-1 --limit 1)",
    )

    # --- Debug options ---
    debug = parser.add_argument_group("Debug options")
    debug.add_argument(
        "--debug",
        type=str,
        default=None,
        choices=["none", "diff", "full"],
        help="Default debug level: 'none' for no debug output, 'diff' for diff output, "
        "'full' for full debug output.",
    )
    stage_overrides = (
        (
            "debug1",
            "Debug level override for stage 1: like --debug; or 'skip' to skip stage 1.",
        ),
        (
            "debug2",
            "Debug level override for stage 2: like --debug; or 'skip' to skip stages 1-2.",
        ),
        (
            "debug3",
            "Debug level override for stage 3: like --debug; or 'nice' to print answer only.",
        ),
        (
            "debug4",
            "Debug level override for stage 4: like --debug; or 'nice' to print answer only.",
        ),
    )
    for key, help_text in stage_overrides:
        debug.add_argument(
            f"--{key}",
            type=str,
            default=None,
            choices=typing.get_args(ProcessingContext.__annotations__[key]),
            help=help_text,
        )
    debug.add_argument(
        "--alt-schema",
        type=str,
        default=None,
        help="Path to alternate schema file for query translator (modifies stage 1).",
    )
    debug.add_argument(
        "--show-schema",
        action="store_true",
        help="Show the TypeScript schema computed by typechat.",
    )
    debug.add_argument(
        "--logfire",
        action="store_true",
        help="Upload log events to Pydantic's Logfire server",
    )

    return parser
628
629
def fill_in_debug_defaults(
    parser: argparse.ArgumentParser, args: argparse.Namespace
) -> None:
    """Validate mode-related options and resolve the per-stage debug levels.

    Exits through parser.exit() with status 2 when --question is combined
    with --batch, when --start/--offset/--limit are used without --batch, or
    when --start and --offset are both set. Otherwise *args* is updated in
    place: --start becomes an offset (and a limit of 1 if none was given),
    and debug1 .. debug4 are filled in.
    """
    # In batch mode, defaults are diff, diff, diff, diff.
    # In interactive mode they are none, none, none, nice.
    if args.question is not None and args.batch:
        parser.exit(2, "Error: --question cannot be combined with --batch\n")

    if args.batch:
        if args.start:
            if args.offset != 0:
                parser.exit(2, "Error: --start and --offset can't be both set\n")
            args.offset = args.start - 1
            if args.limit == 0:
                args.limit = 1
        args.debug = args.debug or "diff"
    elif args.start or args.offset or args.limit:
        parser.exit(2, "Error: --start, --offset and --limit require --batch\n")

    # Precedence per stage: explicit --debugN, then --debug, then the fallback.
    fallbacks = (
        ("debug1", "none"),
        ("debug2", "none"),
        ("debug3", "none"),
        ("debug4", "nice"),
    )
    for stage, fallback in fallbacks:
        setattr(args, stage, getattr(args, stage) or args.debug or fallback)

    if args.debug2 == "skip":
        args.debug1 = "skip"  # Skipping stage 2 implies skipping stage 1.
656
657
658### Data loading ###
659
660
# Load the podcast conversation and wrap it in a query.QueryEvalContext.
#
# If the storage provider's message collection is already non-empty, the
# conversation is obtained with podcast.Podcast.create(settings) (logged as
# "Reusing database"); otherwise it is read from podcast_file_prefix with
# podcast.Podcast.read_from_file(). dbname appears in the log label and is
# passed to read_from_file(). verbose is forwarded to
# print_conversation_stats().
#
# NOTE(review): this copy of the file has lost its indentation, so it cannot
# be told from the text whether the SqliteStorageProvider commit below belongs
# to the `else` (freshly loaded) branch only or runs on both paths. Confirm
# against the upstream file before restructuring.
661async def load_podcast_index(
662 podcast_file_prefix: str,
663 settings: ConversationSettings,
664 dbname: str | None,
665 verbose: bool = True,
666) -> query.QueryEvalContext:
667 provider = await settings.get_storage_provider()
668 msgs = await provider.get_message_collection()
669 if await msgs.size() > 0: # Sqlite provider with existing non-empty database
670 with utils.timelog(f"Reusing database {dbname!r}"):
671 conversation = await podcast.Podcast.create(settings)
672 else:
673 with utils.timelog(f"Loading podcast from {podcast_file_prefix!r}"):
674 conversation = await podcast.Podcast.read_from_file(
675 podcast_file_prefix, settings, dbname
676 )
 # Only the SQLite-backed provider gets an explicit commit on its connection.
677 if isinstance(provider, SqliteStorageProvider):
678 provider.db.commit()
679
680 await print_conversation_stats(conversation, verbose)
681
682 return query.QueryEvalContext(conversation)
683
684
685def load_index_file[T: Mapping[str, typing.Any]](
686 file: str, selector: str, cls: type[T], verbose: bool = True
687) -> tuple[list[T], dict[str, T]]:
688 # If this crashes, the file is malformed -- go figure it out.
689 try:
690 with open(file) as f:
691 lst: list[T] = json.load(f)
692 except FileNotFoundError as err:
693 print(Fore.RED + str(err) + Fore.RESET)
694 lst = []
695 index = {item[selector]: item for item in lst}
696 if len(index) != len(lst) and verbose:
697 print(f"{len(lst) - len(index)} duplicate items found in {file!r}. ")
698 return lst, index
699
700
701### Debug output ###
702
703
704async def print_result[TMessage: IMessage, TIndex: ITermToSemanticRefIndex](
705 result: search.ConversationSearchResult,
706 conversation: IConversation[TMessage, TIndex],
707) -> None:
708 print(
709 f"Raw query: {result.raw_query_text};",
710 f"{len(result.message_matches)} message matches,",
711 f"{len(result.knowledge_matches)} knowledge matches",
712 )
713 if result.message_matches:
714 print("Message matches:")
715 for scored_ord in sorted(
716 result.message_matches, key=lambda x: x.score, reverse=True
717 ):
718 score = scored_ord.score
719 msg_ord = scored_ord.message_ordinal
720 msg = await conversation.messages.get_item(msg_ord)
721 assert msg.metadata is not None # For type checkers
722 text = " ".join(msg.text_chunks).strip()
723 print(
724 f"({score:5.1f}) M={msg_ord:d}: "
725 f"{msg.metadata.source!s:>15.15s}: "
726 f"{repr(text)[1:-1]:<150.150s} "
727 )
728 if result.knowledge_matches:
729 print(f"Knowledge matches ({', '.join(sorted(result.knowledge_matches))}):")
730 for key, value in sorted(result.knowledge_matches.items()):
731 print(f"Type {key} -- {value.term_matches}:")
732 for scored_sem_ref_ord in value.semantic_ref_matches:
733 score = scored_sem_ref_ord.score
734 sem_ref_ord = scored_sem_ref_ord.semantic_ref_ordinal
735 if conversation.semantic_refs is None:
736 print(f" Ord: {sem_ref_ord} (score {score})")
737 else:
738 sem_ref = await conversation.semantic_refs.get_item(sem_ref_ord)
739 msg_ord = sem_ref.range.start.message_ordinal
740 chunk_ord = sem_ref.range.start.chunk_ordinal
741 msg = await conversation.messages.get_item(msg_ord)
742 print(
743 f"({score:5.1f}) M={msg_ord}: "
744 f"S={summarize_knowledge(sem_ref)}"
745 )
746
747
748def summarize_knowledge(sem_ref: SemanticRef) -> str:
749 """Summarize the knowledge in a SemanticRef."""
750 knowledge = sem_ref.knowledge
751 if knowledge is None:
752 return f"{sem_ref.semantic_ref_ordinal}: <No knowledge>"
753
754 if isinstance(knowledge, kplib.ConcreteEntity):
755 entity = knowledge
756 res = [f"{entity.name} [{', '.join(entity.type)}]"]
757 if entity.facets:
758 for facet in entity.facets:
759 value = facet.value
760 if isinstance(value, kplib.Quantity):
761 value = f"{value.amount} {value.units}"
762 elif isinstance(value, float) and value.is_integer():
763 value = int(value)
764 res.append(f"<{facet.name}:{value}>")
765 return f"{sem_ref.semantic_ref_ordinal}: {' '.join(res)}"
766 elif isinstance(knowledge, kplib.Action):
767 action = knowledge
768 res = []
769 res.append("/".join(repr(verb) for verb in action.verbs))
770 if action.verb_tense:
771 res.append(f"[{action.verb_tense}]")
772 if action.subject_entity_name != "none":
773 res.append(f"subj={action.subject_entity_name!r}")
774 if action.object_entity_name != "none":
775 res.append(f"obj={action.object_entity_name!r}")
776 if action.indirect_object_entity_name != "none":
777 res.append(f"ind_obj={action.indirect_object_entity_name}")
778 if action.params:
779 for param in action.params:
780 if isinstance(param, kplib.ActionParam):
781 res.append(f"<{param.name}:{param.value}>")
782 else:
783 res.append(f"<{param}>")
784 if action.subject_entity_facet is not None:
785 res.append(f"subj_facet={action.subject_entity_facet}")
786 return f"{sem_ref.semantic_ref_ordinal}: {' '.join(res)}"
787 elif isinstance(knowledge, Topic):
788 topic = knowledge
789 return f"{sem_ref.semantic_ref_ordinal}: {topic.text!r}"
790 elif isinstance(knowledge, Tag):
791 tag = knowledge
792 return f"{sem_ref.semantic_ref_ordinal}: #{tag.text!r}"
793 else:
794 return f"{sem_ref.semantic_ref_ordinal}: {sem_ref.knowledge!r}"
795
796
def compare_results(
    matches_records: list[RawSearchResultData],
    results: list[search.ConversationSearchResult],
) -> bool:
    """Compare actual search results with the expected records, pairwise.

    Differences are printed by the comparison helpers. Every pair and every
    category is checked even after a mismatch, so all differences get
    reported.

    Returns:
        True only if the two lists have the same length and every pair agrees
        on message, entity, action and topic matches.
    """
    if len(results) != len(matches_records):
        print(f"(Result sizes mismatch, {len(results)} != {len(matches_records)})")
        return False

    # (knowledge type in the result, key in the expected record)
    knowledge_kinds = (
        ("entity", "entityMatches"),
        ("action", "actionMatches"),
        ("topic", "topicMatches"),
    )
    all_equal = True
    for result, record in zip(results, matches_records):
        if not compare_message_ordinals(
            result.message_matches, record["messageMatches"]
        ):
            all_equal = False
        for kind, record_key in knowledge_kinds:
            if kind in result.knowledge_matches:
                found = result.knowledge_matches[kind].semantic_ref_matches
            else:
                found = []
            if not compare_semantic_ref_ordinals(
                found, record.get(record_key, []), kind
            ):
                all_equal = False
    return all_equal
841
842
# Special case: In the Podcast, these messages are all Kevin saying "Yeah",
# so if the difference is limited to these, we consider it a match.
NOISE_MESSAGES = frozenset({42, 46, 52, 68, 70})


def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Check actual message matches against the expected message ordinals.

    Order and duplicates are ignored, and ordinals in NOISE_MESSAGES may be
    present on one side only. On a mismatch both lists are printed.

    Args:
        aa: Actual scored message matches.
        b: Expected message ordinals.
    """
    a = [scored.message_ordinal for scored in aa]
    only_on_one_side = set(a).symmetric_difference(b)
    if only_on_one_side.issubset(NOISE_MESSAGES):
        return True
    print("Message ordinals do not match:")
    utils.list_diff(" Expected:", b, " Actual:", a, max_items=20)
    return False
855
856
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Check actual semantic-ref matches against the expected ordinals.

    Order is ignored. On a mismatch a message naming *label* is printed,
    followed by both lists.

    Args:
        aa: Actual scored semantic-ref matches.
        b: Expected semantic-ref ordinals.
        label: Knowledge type name used in the mismatch message.
    """
    a = [scored.semantic_ref_ordinal for scored in aa]
    if sorted(a) != sorted(b):
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", a, max_items=20)
        return False
    return True
866
867
def compare_and_print_diff(a: object, b: object) -> bool:  # True if equal
    """Diff two objects whose repr() is a valid Python expression.

    The objects count as equal if they compare equal, if their repr() strings
    are identical, or if the strings become identical once floats are
    shortened to three decimals. Otherwise a unified diff of the formatted
    repr() strings is printed and False is returned.
    """
    if a == b:
        return True

    def shorten_floats(text: str) -> str:
        # Shorten floats so slight differences in score etc. don't cause false positives.
        return re.sub(r"\b\d\.\d\d+", lambda m: f"{float(m.group()):.3f}", text)

    left, right = repr(a), repr(b)
    if left == right:
        return True

    left, right = shorten_floats(left), shorten_floats(right)
    if left == right:
        return True

    left_formatted = utils.format_code(left)
    right_formatted = utils.format_code(right)
    print_diff(left_formatted, right_formatted, n=2)
    return False
885
886
# Compare an expected (text, success) answer with the actual one and score it.
#
# Scores assigned directly in this function:
#   0.000 - an answer was expected but the actual outcome is a failure
#   0.001 - no answer was expected but one was produced
#   1.001 - both the expected and the actual outcome are failures
#   1.000 - both succeeded and the texts are identical
# In the remaining case (both succeeded, texts differ) the score comes from
# equality_score().
#
# NOTE(review): this copy of the file has lost its indentation, so it cannot
# be told from the text whether the trailing "choose n / print text or diff"
# section belongs only to the final `else` branch or runs for every case.
# Confirm against the upstream file before restructuring.
887async def compare_answers(
888 context: ProcessingContext, expected: tuple[str, bool], actual: tuple[str, bool]
889) -> float:
890 expected_text, expected_success = expected
891 actual_text, actual_success = actual
892
893 if expected_success != actual_success:
894 print(
895 f"Expected success: {Fore.RED}{expected_success}{Fore.RESET}; "
896 f"actual: {Fore.GREEN}{actual_success}{Fore.RESET}"
897 )
898 score = 0.000 if expected_success else 0.001 # 0.001 == Answer not expected
899
900 elif not actual_success:
901 print(Fore.GREEN + f"Both failed" + Fore.RESET)
902 score = 1.001
903
904 elif expected_text == actual_text:
905 print(Fore.GREEN + f"Both equal" + Fore.RESET)
906 score = 1.000
907
908 else:
909 score = await equality_score(context, expected_text, actual_text)
910
 # n is the number of context lines passed to print_diff(): 100 when both
 # texts have at most 100 lines, otherwise 2.
911 if len(expected_text.splitlines()) <= 100 and len(actual_text.splitlines()) <= 100:
912 n = 100
913 else:
914 n = 2
 # An exact score of 1.0 prints the answer itself; anything else prints a diff.
915 if score == 1.0:
916 print(actual_text)
917 else:
918 print_diff(expected_text, actual_text, n=n)
919
920 return score
921
922
def print_diff(a: str, b: str, n: int) -> None:
    """Print a colored unified diff between two multi-line strings.

    Args:
        a: The expected text (labelled "expected" in the diff header).
        b: The actual text (labelled "actual" in the diff header).
        n: Number of context lines around each change.

    Lines starting with "-" are printed in red and lines starting with "+"
    in green (this includes the two header lines); all others are uncolored.
    Nothing is printed when the texts have the same lines.
    """
    delta = difflib.unified_diff(
        a.splitlines(),
        b.splitlines(),
        fromfile="expected",
        tofile="actual",
        n=n,
    )
    for entry in delta:
        text = entry.rstrip("\n")
        if entry.startswith("-"):
            text = Fore.RED + text + Fore.RESET
        elif entry.startswith("+"):
            text = Fore.GREEN + text + Fore.RESET
        print(text)
938
939
async def equality_score(context: ProcessingContext, a: str, b: str) -> float:
    """Score how closely two answer texts agree.

    Returns 1.0 for identical strings and 0.999 for strings that differ only
    in case. Otherwise both strings are embedded with
    context.embedding_model and the dot product of the two embeddings is
    returned.
    """
    if a == b:
        return 1.0
    if a.lower() == b.lower():
        return 0.999
    vectors = await context.embedding_model.get_embeddings([a, b])
    assert vectors.shape[0] == 2, "Expected two embeddings"
    # NOTE(review): presumably the embeddings are unit-normalized, making this
    # a cosine similarity -- confirm against the embedding model.
    first, second = vectors[0], vectors[1]
    return np.dot(first, second)
948
949
950### Run main ###
951
if __name__ == "__main__":
    # Script entry point: run the async main() to completion.
    try:
        asyncio.run(main())
    except (BrokenPipeError, KeyboardInterrupt):
        # On Ctrl-C or a closed output pipe, end the current output line
        # and exit with a failure status.
        print()
        sys.exit(1)
958