microsoft/TypeAgent

Public

mirrored fromhttps://github.com/microsoft/TypeAgentAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
cfffaac584782c49e9736bfbcee139d1f341b3e0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/tools/utool.py

956lines · modecode

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4__version__ = "0.2"
5
6### Imports ###
7
8import argparse
9import asyncio
10from collections.abc import Mapping
11from dataclasses import dataclass
12import difflib
13import json
14import re
15import shutil
16import sys
17import typing
18
19from colorama import init as colorama_init, Fore
20import numpy as np
21
22try:
23 import readline
24except ImportError:
25 readline = None
26import typechat
27
28from typeagent.aitools import embeddings
29from typeagent.aitools import utils
30
31from typeagent.knowpro import answers, answer_response_schema
32from typeagent.knowpro import convknowledge
33from typeagent.knowpro.convsettings import ConversationSettings
34from typeagent.knowpro.interfaces import (
35 IConversation,
36 IMessage,
37 ITermToSemanticRefIndex,
38 ScoredMessageOrdinal,
39 ScoredSemanticRefOrdinal,
40 SemanticRef,
41 Tag,
42 Topic,
43)
44from typeagent.knowpro import kplib
45from typeagent.knowpro import query
46from typeagent.knowpro import search, search_query_schema, searchlang
47from typeagent.knowpro import serialization
48
49from typeagent.podcasts import podcast
50
51from typeagent.storage.memory.propindex import build_property_index
52from typeagent.storage.memory.reltermsindex import build_related_terms_index
53from typeagent.storage.sqlite.provider import SqliteStorageProvider
54from typeagent.storage.utils import create_storage_provider
55
56
57### Classes ###
58
59
class QuestionAnswerData(typing.TypedDict, total=True):
    """One entry of the question/answer file (``--qafile``)."""

    # Question text; load_index_file() is called with "question" as the index key.
    question: str
    # Expected answer text; compared with the generated answer in stage 4 diffs.
    answer: str
    # True when the expected outcome is that no answer can be given.
    hasNoAnswer: bool
    # Not read anywhere in this file -- presumably the command that produced
    # the entry (TODO confirm against the test data).
    cmd: str
65
66
class RawSearchResultData(typing.TypedDict, total=True):
    """Expected stage 3 matches for one search result, stored as ordinals."""

    # Ordinals of the messages that are expected to match.
    messageMatches: list[int]
    # Semantic ref ordinals expected under the "entity" knowledge type.
    entityMatches: list[int]
    # Semantic ref ordinals expected under the "topic" knowledge type.
    topicMatches: list[int]
    # Semantic ref ordinals expected under the "action" knowledge type.
    actionMatches: list[int]
72
73
class SearchResultData(typing.TypedDict, total=True):
    """One entry of the search results file (``--srfile``).

    Holds the precomputed outcomes of stages 1, 2 and 3 for one search text.
    """

    # Query text; load_index_file() is called with "searchText" as the index key.
    searchText: str
    # Stage 1 outcome: serialized search_query_schema.SearchQuery
    searchQueryExpr: dict[str, typing.Any]
    # Stage 2 outcome: serialized list[search.SearchQueryExpr]
    compiledQueryExpr: list[dict[str, typing.Any]]
    # Stage 3 outcome: expected matches, one item per search result.
    results: list[RawSearchResultData]
79
80
@dataclass
class ProcessingContext:
    """Everything process_query() needs to handle one question.

    The ``debugN`` fields hold the output mode of pipeline stage N. Their
    Literal annotations are also read by make_arg_parser() to build the
    ``choices`` of the matching command line options, so they must remain
    real (non-string) annotations.
    """

    # Evaluation context wrapping the loaded conversation.
    query_context: query.QueryEvalContext
    # Question/answer records, as a list and indexed by question text.
    ar_list: list[QuestionAnswerData]
    # Search result records, as a list (their index follows below).
    sr_list: list[SearchResultData]
    ar_index: dict[str, QuestionAnswerData]
    # Search result records indexed by search text.
    sr_index: dict[str, SearchResultData]
    debug1: typing.Literal["none", "diff", "full", "skip"]
    debug2: typing.Literal["none", "diff", "full", "skip"]
    debug3: typing.Literal["none", "diff", "full", "nice"]
    debug4: typing.Literal["none", "diff", "full", "nice"]
    # Model used to embed answer texts when scoring answers.
    embedding_model: embeddings.AsyncEmbeddingModel
    query_translator: typechat.TypeChatJsonTranslator[search_query_schema.SearchQuery]
    answer_translator: typechat.TypeChatJsonTranslator[
        answer_response_schema.AnswerResponse
    ]
    lang_search_options: searchlang.LanguageSearchOptions
    answer_context_options: answers.AnswerContextOptions

    def __repr__(self) -> str:
        """Return a compact summary: collection sizes instead of contents."""
        sized_fields = ("ar_list", "sr_list", "ar_index", "sr_index")
        plain_fields = (
            "debug1",
            "debug2",
            "debug3",
            "debug4",
            "lang_search_options",
            "answer_context_options",
        )
        parts = [f"{name}={len(getattr(self, name))}" for name in sized_fields]
        parts += [f"{name}={getattr(self, name)}" for name in plain_fields]
        return f"Context({', '.join(parts)})"
113
114
115### Main logic ###
116
117
async def main():
    """Parse the command line, load all data, then answer questions.

    Depending on the arguments this processes a single question
    (``--question``), runs the batch loop (``--batch``), or starts the
    interactive loop.
    """
    utils.load_dotenv()
    colorama_init(autoreset=True)

    arg_parser = make_arg_parser("TypeAgent Query Tool")
    args = arg_parser.parse_args()
    fill_in_debug_defaults(arg_parser, args)
    verbose = args.verbose

    if args.logfire:
        utils.setup_logfire()

    # The settings object has no storage provider yet; create and attach one.
    settings = ConversationSettings()
    settings.storage_provider = await create_storage_provider(
        settings.message_text_index_settings,
        settings.related_term_index_settings,
        args.sqlite_db,
        podcast.PodcastMessage,
    )
    query_context = await load_podcast_index(
        args.podcast, settings, args.sqlite_db, verbose
    )

    ar_list, ar_index = load_index_file(
        args.qafile, "question", QuestionAnswerData, verbose
    )
    sr_list, sr_index = load_index_file(
        args.srfile, "searchText", SearchResultData, verbose
    )

    model = convknowledge.create_typechat_model()
    query_translator = utils.create_translator(model, search_query_schema.SearchQuery)
    if args.alt_schema:
        if verbose:
            print(f"Substituting alt schema from {args.alt_schema}")
        with open(args.alt_schema) as schema_file:
            query_translator.schema_str = schema_file.read()
    if args.show_schema:
        print(Fore.YELLOW + query_translator.schema_str.rstrip() + Fore.RESET)

    answer_translator = utils.create_translator(
        model, answer_response_schema.AnswerResponse
    )

    compile_options = searchlang.LanguageQueryCompileOptions(
        exact_scope=False, verb_scope=True, term_filter=None, apply_scope=True
    )
    lang_search_options = searchlang.LanguageSearchOptions(
        compile_options=compile_options,
        exact_match=False,
        max_message_matches=25,
    )
    answer_context_options = answers.AnswerContextOptions(
        entities_top_k=50, topics_top_k=50, messages_top_k=None, chunking=None
    )
    context = ProcessingContext(
        query_context=query_context,
        ar_list=ar_list,
        sr_list=sr_list,
        ar_index=ar_index,
        sr_index=sr_index,
        debug1=args.debug1,
        debug2=args.debug2,
        debug3=args.debug3,
        debug4=args.debug4,
        embedding_model=settings.embedding_model,
        query_translator=query_translator,
        answer_translator=answer_translator,
        lang_search_options=lang_search_options,
        answer_context_options=answer_context_options,
    )

    if verbose:
        utils.pretty_print(context, Fore.BLUE, Fore.RESET)

    if args.question is not None:
        if verbose:
            banner = f"Processing single question: {args.question}"
            print(Fore.YELLOW + banner + Fore.RESET)
        await process_query(context, args.question)
    elif args.batch:
        if verbose:
            # The slice end is left blank when no limit was given.
            stop = args.offset + args.limit if args.limit else ""
            banner = f"Running in batch mode [{args.offset}:{stop}]."
            print(Fore.YELLOW + banner + Fore.RESET)
        await batch_loop(context, args.offset, args.limit, args.skip_counters)
    else:
        if verbose:
            print(Fore.YELLOW + "Running in interactive mode." + Fore.RESET)
        await interactive_loop(context)
209
210
async def print_conversation_stats(c: IConversation, verbose: bool = True) -> None:
    """Print the sizes of the conversation's collections and indexes.

    Does nothing when ``verbose`` is false. A missing index is reported with
    an upper-case "NO ..." line instead of a size.
    """
    if not verbose:
        return

    print(f"{await c.messages.size()} messages loaded.")
    print(f"{await c.semantic_refs.size()} semantic refs loaded.")
    print(f"{await c.semantic_ref_index.size()} sem_ref index entries.")

    secondary = c.secondary_indexes
    if secondary is None:
        print("NO SECONDARY INDEXES")
        return

    property_index = secondary.property_to_semantic_ref_index
    if property_index is not None:
        count = await property_index.size()
        print(f"{count} property to semantic ref index entries.")
    else:
        print("NO PROPERTY TO SEMANTIC REF INDEX")

    timestamp_index = secondary.timestamp_index
    if timestamp_index is not None:
        print(f"{await timestamp_index.size()} timestamp index entries.")
    else:
        print("NO TIMESTAMP INDEX")

    related_terms = secondary.term_to_related_terms_index
    if related_terms is not None:
        print(f"{await related_terms.aliases.size()} alias entries.")
        fuzzy = related_terms.fuzzy_index
        if fuzzy is not None:
            print(f"{await fuzzy.size()} term entries.")
        else:
            print("NO FUZZY RELATED TERMS INDEX")
    else:
        print("NO TERM TO RELATED TERMS INDEX")

    threads = secondary.threads
    if threads is not None:
        print(f"{len(threads.threads)} threads index entries.")
    else:
        print("NO THREADS INDEX")

    message_index = secondary.message_index
    if message_index is not None:
        print(f"{await message_index.size()} message index entries.")
    else:
        print("NO MESSAGE INDEX")
266
267
async def batch_loop(
    context: ProcessingContext, offset: int, limit: int, skip_counters: str
) -> None:
    """Process a slice of the question list and print a score summary.

    Args:
        context: Processing context; questions come from ``context.ar_list``.
        offset: Number of leading questions to skip.
        limit: Number of questions to process; 0 means all remaining ones.
        skip_counters: Comma-separated 1-based question numbers to skip.
            Items that are not all digits are ignored.
    """
    skipped = set()
    if skip_counters:
        skipped = {
            int(token) for token in skip_counters.split(",") if token.strip().isdigit()
        }
    if limit == 0:
        limit = len(context.ar_list) - offset

    scored = []
    # Counters are 1-based positions in the full list, not in the slice.
    for counter, qadata in enumerate(
        context.ar_list[offset : offset + limit], offset + 1
    ):
        if counter in skipped:
            continue
        question = qadata["question"]
        print("-" * 20, counter, question, "-" * 20)
        score = await process_query(context, question)
        if score is not None:
            scored.append((score, counter))

    if not scored:
        return

    print("=" * 50)
    scored.sort(reverse=True)
    threshold = 0.97
    good = [pair for pair in scored if pair[0] >= threshold]
    bad = [pair for pair in scored if pair[0] < threshold]
    for label, pairs in (("Good", good), ("Bad", bad)):
        print(f"{label} scores ({len(pairs)}):")
        # Ten "score(counter)" items per output line.
        for start in range(0, len(pairs), 10):
            row = pairs[start : start + 10]
            print(", ".join(f"{score:.3f}({counter})" for score, counter in row))
300
301
async def interactive_loop(context: ProcessingContext) -> None:
    """Read questions from stdin and process them one by one.

    When stdin is not a terminal, every non-blank input line is processed
    without a prompt. Otherwise a prompt is shown until EOF or until the user
    types 'exit', 'quit' or 'q'; if readline is available, the input history
    is loaded from and saved to ".ui_history".
    """
    if not sys.stdin.isatty():
        for raw_line in sys.stdin:
            text = raw_line.strip()
            if text:
                await process_query(context, text)
        return

    history_file = ".ui_history"
    print(f"TypeAgent demo UI {__version__} (type 'q' to exit)")
    if readline:
        try:
            readline.read_history_file(history_file)
        except FileNotFoundError:
            pass  # Ignore if history file does not exist.

    try:
        while True:
            try:
                text = input("TypeAgent> ").strip()
            except EOFError:
                print()
                break
            if not text:
                continue
            if text.lower() in ("exit", "quit", "q"):
                if readline:
                    # Drop the most recent history entry (the quit command).
                    last = readline.get_current_history_length() - 1
                    readline.remove_history_item(last)
                break
            prsep()
            await process_query(context, text)
    finally:
        if readline:
            readline.write_history_file(history_file)
339
340
341### Query processing logic ###
342
343
async def process_query(context: ProcessingContext, query_text: str) -> float | None:
    """Run one question through the pipeline, printing per-stage debug output.

    Stages 1-3 are performed by ``searchlang.search_conversation_with_language``
    and stage 4 by ``answers.generate_answers``. What is printed for stage N
    is controlled by ``context.debugN``; "skip" for stage 1 or 2 substitutes
    the precomputed outcome found in ``context.sr_index``.

    Returns:
        The score from compare_answers() when stage 4 is in "diff" mode and an
        expected answer exists for ``query_text``; None in all other cases.
    """
    if not query_text.strip():
        return None  # Ignore blank query (like interactive mode)

    conversation = context.query_context.conversation
    precomputed = context.sr_index.get(query_text)
    debug_context = searchlang.LanguageSearchDebugContext()

    if "skip" in (context.debug1, context.debug2):
        if (
            precomputed
            and "searchQueryExpr" in precomputed
            and "compiledQueryExpr" in precomputed
        ):
            # Skipping stage 2 implies skipping stage 1, and we must supply the
            # precomputed results for both stages.
            debug_context.use_search_query = serialization.deserialize_object(
                search_query_schema.SearchQuery, precomputed["searchQueryExpr"]
            )
            print("Skipping stage 1, substituting precomputed search query.")
            if context.debug2 == "skip":
                debug_context.use_compiled_search_query_exprs = (
                    serialization.deserialize_object(
                        list[search.SearchQueryExpr],
                        precomputed["compiledQueryExpr"],
                    )
                )
                print(
                    "Skipping stage 2, substituting precomputed compiled query expressions."
                )
        else:
            print("Can't skip stages 1 or 2, no precomputed outcomes found.")
        prsep()

    result = await searchlang.search_conversation_with_language(
        conversation,
        context.query_translator,
        query_text,
        context.lang_search_options,
        debug_context=debug_context,
    )
    if isinstance(result, typechat.Failure):
        print("Stages 1-3 failed:")
        print(Fore.RED + str(result) + Fore.RESET)
        return None
    search_results = result.value

    # Stages 1 and 2 are reported identically; only the inputs differ.
    early_stages = (
        (
            1,
            context.debug1,
            debug_context.search_query,
            "searchQueryExpr",
            search_query_schema.SearchQuery,
        ),
        (
            2,
            context.debug2,
            debug_context.search_query_expr,
            "compiledQueryExpr",
            list[search.SearchQueryExpr],
        ),
    )
    for stage, mode, actual, record_key, expected_type in early_stages:
        if not actual:
            continue
        if mode == "full":
            print(f"Stage {stage} results:")
            utils.pretty_print(actual, Fore.GREEN, Fore.RESET)
            prsep()
        elif mode == "diff":
            if precomputed and record_key in precomputed:
                print(f"Stage {stage} diff:")
                expected = serialization.deserialize_object(
                    expected_type, precomputed[record_key]
                )
                compare_and_print_diff(expected, actual)
            else:
                print(f"Stage {stage} diff unavailable")
            prsep()

    if context.debug3 == "full":
        print("Stage 3 full results:")
        utils.pretty_print(search_results, Fore.GREEN, Fore.RESET)
        prsep()
    elif context.debug3 == "nice":
        print("Stage 3 nice results:")
        for search_result in search_results:
            await print_result(search_result, conversation)
        prsep()
    elif context.debug3 == "diff":
        if precomputed and "results" in precomputed:
            print("Stage 3 diff:")
            expected3: list[RawSearchResultData] = precomputed["results"]
            compare_results(expected3, search_results)
        else:
            print("Stage 3 diff unavailable")
        prsep()

    all_answers, combined_answer = await answers.generate_answers(
        context.answer_translator,
        search_results,
        conversation,
        query_text,
        options=context.answer_context_options,
    )

    def print_nice_answer() -> None:
        # Failure reason in red, or the answer text in green.
        if combined_answer.type == "NoAnswer":
            print(Fore.RED + f"Failure: {combined_answer.whyNoAnswer}" + Fore.RESET)
        else:
            print(Fore.GREEN + f"{combined_answer.answer}" + Fore.RESET)

    if context.debug4 == "full":
        utils.pretty_print(all_answers)
        prsep()
    if context.debug4 in ("full", "nice"):
        print_nice_answer()
        prsep()
    elif context.debug4 == "diff":
        if query_text in context.ar_index:
            qa_record = context.ar_index[query_text]
            expected4: tuple[str, bool] = (
                qa_record["answer"],
                not qa_record["hasNoAnswer"],
            )
            print("Stage 4 diff:")
            # Answers are compared as (text, succeeded) pairs.
            if combined_answer.type == "NoAnswer":
                actual4 = (combined_answer.whyNoAnswer or "", False)
            elif combined_answer.type == "Answered":
                actual4 = (combined_answer.answer or "", True)
            score = await compare_answers(context, expected4, actual4)
            if actual4[0].startswith("TypeChat failure:"):
                print(Fore.YELLOW + "No answer received" + Fore.RESET)
            else:
                print(f"Score: {score:.3f}; Question: {query_text}")
            return score
        print("Stage 4 diff unavailable; nice answer:")
        print_nice_answer()
        prsep()
    return None
479
480
def prsep():
    """Print a separator line of 50 dashes."""
    width = 50
    print("-" * width)
483
484
485### CLI processing ###
486
487
def make_arg_parser(description: str) -> argparse.ArgumentParser:
    """Build the command line parser (general, batch mode and debug options).

    The choices of ``--debug1`` .. ``--debug4`` are taken from the Literal
    annotations of the matching ProcessingContext fields.
    """
    line_width = utils.cap(144, shutil.get_terminal_size().columns)
    help_position = 35 if line_width >= 100 else 28

    def make_formatter(*args, **kwargs):
        # argparse calls this wherever it would instantiate the formatter class.
        return argparse.HelpFormatter(
            *args, **kwargs, max_help_position=help_position, width=line_width
        )

    parser = argparse.ArgumentParser(
        description=description, formatter_class=make_formatter
    )

    # --- General options ---
    parser.add_argument(
        "--podcast",
        type=str,
        default="testdata/Episode_53_AdrianTchaikovsky_index",
        help="Path to the podcast index files (excluding the '_index.json' suffix)",
    )
    explain_qa = "a list of questions and answers to test the full pipeline"
    parser.add_argument(
        "--qafile",
        type=str,
        default="testdata/Episode_53_Answer_results.json",
        help=f"Path to the Answer_results.json file ({explain_qa})",
    )
    explain_sr = "a list of intermediate results from stages 1, 2 and 3"
    parser.add_argument(
        "--srfile",
        type=str,
        default="testdata/Episode_53_Search_results.json",
        help=f"Path to the Search_results.json file ({explain_sr})",
    )
    parser.add_argument(
        "--skip-counters",
        type=str,
        default="",
        help="List of comma-separated questions to skip",
    )
    parser.add_argument(
        "--sqlite-db",
        type=str,
        default=None,
        help="Path to the SQLite database file (default: no SQLite)",
    )
    parser.add_argument(
        "--question",
        "--query",
        type=str,
        default=None,
        help="Process a single question and exit (equivalent to echo 'question' | utool.py)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Show verbose startup information and timing logs",
    )

    # --- Batch mode options ---
    batch = parser.add_argument_group("Batch mode options")
    batch.add_argument(
        "--batch",
        action="store_true",
        help="Run in batch mode, suppressing interactive prompts.",
    )
    batch.add_argument(
        "--offset",
        type=int,
        default=0,
        help="Number of initial Q/A pairs to skip (default none)",
    )
    batch.add_argument(
        "--limit",
        type=int,
        default=0,
        help="Number of Q/A pairs to process (default all)",
    )
    batch.add_argument(
        "--start",
        type=int,
        default=0,
        help="Do just this question (similar to --offset START-1 --limit 1)",
    )

    # --- Debug options ---
    debug = parser.add_argument_group("Debug options")
    debug.add_argument(
        "--debug",
        type=str,
        default=None,
        choices=["none", "diff", "full"],
        help="Default debug level: 'none' for no debug output, 'diff' for diff output, "
        "'full' for full debug output.",
    )
    stage_extras = (
        (1, "'skip' to skip stage 1"),
        (2, "'skip' to skip stages 1-2"),
        (3, "'nice' to print answer only"),
        (4, "'nice' to print answer only"),
    )
    for stage, extra in stage_extras:
        field = f"debug{stage}"
        debug.add_argument(
            f"--{field}",
            type=str,
            default=None,
            choices=typing.get_args(ProcessingContext.__annotations__[field]),
            help=f"Debug level override for stage {stage}: like --debug; or {extra}.",
        )
    debug.add_argument(
        "--alt-schema",
        type=str,
        default=None,
        help="Path to alternate schema file for query translator (modifies stage 1).",
    )
    debug.add_argument(
        "--show-schema",
        action="store_true",
        help="Show the TypeScript schema computed by typechat.",
    )
    debug.add_argument(
        "--logfire",
        action="store_true",
        help="Upload log events to Pydantic's Logfire server",
    )

    return parser
627
628
def fill_in_debug_defaults(
    parser: argparse.ArgumentParser, args: argparse.Namespace
) -> None:
    """Validate option combinations and fill in the per-stage debug levels.

    Exits through ``parser.exit(2, ...)`` on conflicting options. In batch
    mode ``--start N`` is translated to ``offset = N - 1`` (and ``limit = 1``
    unless a limit was given). ``args`` is modified in place.
    """
    # In batch mode, defaults are diff, diff, diff, diff.
    # In interactive mode they are none, none, none, nice.
    if args.batch:
        if args.question is not None:
            parser.exit(2, "Error: --question cannot be combined with --batch\n")
        if args.start:
            if args.offset != 0:
                parser.exit(2, "Error: --start and --offset can't be both set\n")
            args.offset = args.start - 1
            if args.limit == 0:
                args.limit = 1
        args.debug = args.debug or "diff"
    elif args.start or args.offset or args.limit:
        parser.exit(2, "Error: --start, --offset and --limit require --batch\n")

    # Precedence per stage: explicit --debugN, then --debug, then the fallback.
    fallbacks = (
        ("debug1", "none"),
        ("debug2", "none"),
        ("debug3", "none"),
        ("debug4", "nice"),
    )
    for name, fallback in fallbacks:
        setattr(args, name, getattr(args, name) or args.debug or fallback)
    if args.debug2 == "skip":
        args.debug1 = "skip"  # Skipping stage 2 implies skipping stage 1.
655
656
657### Data loading ###
658
659
async def load_podcast_index(
    podcast_file_prefix: str,
    settings: ConversationSettings,
    dbname: str | None,
    verbose: bool = True,
) -> query.QueryEvalContext:
    """Return a query context for the podcast, reusing stored data if present.

    If the storage provider's message collection is non-empty, the podcast is
    created from the existing storage; otherwise it is read from the index
    files at ``podcast_file_prefix``. Conversation statistics are printed
    when ``verbose`` is true.
    """
    storage = await settings.get_storage_provider()
    messages = await storage.get_message_collection()
    already_populated = await messages.size() > 0
    if already_populated:  # Sqlite provider with existing non-empty database
        with utils.timelog(f"Reusing podcast db {dbname}"):
            conversation = await podcast.Podcast.create(settings)
    else:
        with utils.timelog(f"Loading podcast from {podcast_file_prefix!r}"):
            conversation = await podcast.Podcast.read_from_file(
                podcast_file_prefix, settings, dbname
            )
        # Commit the freshly loaded data when it went into a SQLite database.
        if isinstance(storage, SqliteStorageProvider):
            storage.db.commit()

    await print_conversation_stats(conversation, verbose)

    return query.QueryEvalContext(conversation)
682
683
684def load_index_file[T: Mapping[str, typing.Any]](
685 file: str, selector: str, cls: type[T], verbose: bool = True
686) -> tuple[list[T], dict[str, T]]:
687 # If this crashes, the file is malformed -- go figure it out.
688 try:
689 with open(file) as f:
690 lst: list[T] = json.load(f)
691 except FileNotFoundError as err:
692 print(Fore.RED + str(err) + Fore.RESET)
693 lst = []
694 index = {item[selector]: item for item in lst}
695 if len(index) != len(lst) and verbose:
696 print(f"{len(lst) - len(index)} duplicate items found in {file!r}. ")
697 return lst, index
698
699
700### Debug output ###
701
702
703async def print_result[TMessage: IMessage, TIndex: ITermToSemanticRefIndex](
704 result: search.ConversationSearchResult,
705 conversation: IConversation[TMessage, TIndex],
706) -> None:
707 print(
708 f"Raw query: {result.raw_query_text};",
709 f"{len(result.message_matches)} message matches,",
710 f"{len(result.knowledge_matches)} knowledge matches",
711 )
712 if result.message_matches:
713 print("Message matches:")
714 for scored_ord in sorted(
715 result.message_matches, key=lambda x: x.score, reverse=True
716 ):
717 score = scored_ord.score
718 msg_ord = scored_ord.message_ordinal
719 msg = await conversation.messages.get_item(msg_ord)
720 assert msg.metadata is not None # For type checkers
721 text = " ".join(msg.text_chunks).strip()
722 print(
723 f"({score:5.1f}) M={msg_ord:d}: "
724 f"{msg.metadata.source!s:>15.15s}: "
725 f"{repr(text)[1:-1]:<150.150s} "
726 )
727 if result.knowledge_matches:
728 print(f"Knowledge matches ({', '.join(sorted(result.knowledge_matches))}):")
729 for key, value in sorted(result.knowledge_matches.items()):
730 print(f"Type {key} -- {value.term_matches}:")
731 for scored_sem_ref_ord in value.semantic_ref_matches:
732 score = scored_sem_ref_ord.score
733 sem_ref_ord = scored_sem_ref_ord.semantic_ref_ordinal
734 if conversation.semantic_refs is None:
735 print(f" Ord: {sem_ref_ord} (score {score})")
736 else:
737 sem_ref = await conversation.semantic_refs.get_item(sem_ref_ord)
738 msg_ord = sem_ref.range.start.message_ordinal
739 chunk_ord = sem_ref.range.start.chunk_ordinal
740 msg = await conversation.messages.get_item(msg_ord)
741 print(
742 f"({score:5.1f}) M={msg_ord}: "
743 f"S={summarize_knowledge(sem_ref)}"
744 )
745
746
def summarize_knowledge(sem_ref: SemanticRef) -> str:
    """Summarize the knowledge in a SemanticRef.

    The result always starts with the semantic ref ordinal and a colon; the
    rest depends on the kind of knowledge (entity, action, topic, tag, or
    anything else, which is shown via repr()).
    """
    ordinal = sem_ref.semantic_ref_ordinal
    knowledge = sem_ref.knowledge
    if knowledge is None:
        return f"{ordinal}: <No knowledge>"
    if isinstance(knowledge, kplib.ConcreteEntity):
        return f"{ordinal}: {' '.join(_entity_parts(knowledge))}"
    if isinstance(knowledge, kplib.Action):
        return f"{ordinal}: {' '.join(_action_parts(knowledge))}"
    if isinstance(knowledge, Topic):
        return f"{ordinal}: {knowledge.text!r}"
    if isinstance(knowledge, Tag):
        return f"{ordinal}: #{knowledge.text!r}"
    return f"{ordinal}: {knowledge!r}"


def _entity_parts(entity: kplib.ConcreteEntity) -> list[str]:
    """Return 'name [types]' followed by one '<facet:value>' item per facet."""
    parts = [f"{entity.name} [{', '.join(entity.type)}]"]
    for facet in entity.facets or ():
        value = facet.value
        if isinstance(value, kplib.Quantity):
            value = f"{value.amount} {value.units}"
        elif isinstance(value, float) and value.is_integer():
            value = int(value)  # Show 3.0 as 3.
        parts.append(f"<{facet.name}:{value}>")
    return parts


def _action_parts(action: kplib.Action) -> list[str]:
    """Return the verbs followed by tense, participants, params and facet."""
    parts = ["/".join(map(repr, action.verbs))]
    if action.verb_tense:
        parts.append(f"[{action.verb_tense}]")
    # The string "none" marks an absent participant.
    if action.subject_entity_name != "none":
        parts.append(f"subj={action.subject_entity_name!r}")
    if action.object_entity_name != "none":
        parts.append(f"obj={action.object_entity_name!r}")
    if action.indirect_object_entity_name != "none":
        parts.append(f"ind_obj={action.indirect_object_entity_name}")
    for param in action.params or ():
        if isinstance(param, kplib.ActionParam):
            parts.append(f"<{param.name}:{param.value}>")
        else:
            parts.append(f"<{param}>")
    if action.subject_entity_facet is not None:
        parts.append(f"subj_facet={action.subject_entity_facet}")
    return parts
794
795
def compare_results(
    matches_records: list[RawSearchResultData],
    results: list[search.ConversationSearchResult],
) -> bool:
    """Compare actual search results with the expected records, pairwise.

    Every mismatch is printed by the comparison helpers. All comparisons are
    always carried out, so one call reports every difference.

    Returns:
        True if all message and knowledge matches agree; False otherwise,
        including when the two lists differ in length.
    """
    if len(results) != len(matches_records):
        print(f"(Result sizes mismatch, {len(results)} != {len(matches_records)})")
        return False

    # (knowledge type in the actual result, key in the expected record)
    knowledge_kinds = (
        ("entity", "entityMatches"),
        ("action", "actionMatches"),
        ("topic", "topicMatches"),
    )
    all_match = True
    for result, record in zip(results, matches_records):
        messages_ok = compare_message_ordinals(
            result.message_matches, record["messageMatches"]
        )
        all_match = all_match and messages_ok
        for kind, record_key in knowledge_kinds:
            if kind in result.knowledge_matches:
                actual = result.knowledge_matches[kind].semantic_ref_matches
            else:
                actual = []
            kind_ok = compare_semantic_ref_ordinals(
                actual, record.get(record_key, []), kind
            )
            all_match = all_match and kind_ok
    return all_match
840
841
# Special case: In the Podcast, these messages are all Kevin saying "Yeah",
# so if the difference is limited to these, we consider it a match.
NOISE_MESSAGES = frozenset((42, 46, 52, 68, 70))


def compare_message_ordinals(aa: list[ScoredMessageOrdinal], b: list[int]) -> bool:
    """Check actual scored message matches against expected message ordinals.

    Order and scores are ignored, and ordinals listed in NOISE_MESSAGES may be
    missing or extra on either side. On a mismatch a diff is printed.

    Returns:
        True if the ordinals match (up to noise messages), False otherwise.
    """
    actual = [scored.message_ordinal for scored in aa]
    differing = set(actual).symmetric_difference(b)
    if differing.issubset(NOISE_MESSAGES):
        return True
    print("Message ordinals do not match:")
    utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
    return False
854
855
def compare_semantic_ref_ordinals(
    aa: list[ScoredSemanticRefOrdinal], b: list[int], label: str
) -> bool:
    """Check actual scored semantic ref matches against expected ordinals.

    Order and scores are ignored, but duplicates count. On a mismatch a diff
    is printed under a heading that starts with the capitalized ``label``.

    Returns:
        True if both sides contain the same ordinals, False otherwise.
    """
    actual = [scored.semantic_ref_ordinal for scored in aa]
    same = sorted(actual) == sorted(b)
    if not same:
        print(f"{label.capitalize()} SemanticRef ordinals do not match:")
        utils.list_diff(" Expected:", b, " Actual:", actual, max_items=20)
    return same
865
866
def compare_and_print_diff(a: object, b: object) -> bool:  # True if equal
    """Diff two objects whose repr() is a valid Python expression.

    The objects count as equal when they compare equal, when their repr()
    strings are identical, or when the repr() strings are identical after
    rounding every float with two or more decimals to three decimals.
    Otherwise a unified diff of the formatted repr() strings is printed,
    with ``a`` as the "expected" side and ``b`` as the "actual" side.

    Returns:
        True if the objects are considered equal, False if a diff was printed.
    """
    if a == b:
        return True
    a_repr = repr(a)
    b_repr = repr(b)
    if a_repr == b_repr:
        return True

    def shorten_floats(text: str) -> str:
        # \d+ (not \d) before the point, so that values >= 10 are rounded too.
        return re.sub(
            r"\b\d+\.\d\d+", lambda m: f"{float(m.group()):.3f}", text
        )

    # Shorten floats so slight differences in score etc. don't cause false positives.
    a_repr = shorten_floats(a_repr)
    b_repr = shorten_floats(b_repr)
    if a_repr == b_repr:
        return True
    a_formatted = utils.format_code(a_repr)
    b_formatted = utils.format_code(b_repr)
    print_diff(a_formatted, b_formatted, n=2)
    return False
884
885
async def compare_answers(
    context: ProcessingContext, expected: tuple[str, bool], actual: tuple[str, bool]
) -> float:
    """Score an actual answer against the expected one and print the outcome.

    Both arguments are ``(text, success)`` pairs.

    Returns:
        0.000 if an answer was expected but none was given, 0.001 if an answer
        was given but none was expected, 1.001 if both failed, 1.000 if the
        texts are identical, and otherwise the result of equality_score().
    """
    expected_text, expected_success = expected
    actual_text, actual_success = actual

    if expected_success != actual_success:
        print(
            f"Expected success: {Fore.RED}{expected_success}{Fore.RESET}; "
            f"actual: {Fore.GREEN}{actual_success}{Fore.RESET}"
        )
        score = 0.000 if expected_success else 0.001  # 0.001 == Answer not expected
    elif not actual_success:
        print(Fore.GREEN + "Both failed" + Fore.RESET)
        score = 1.001
    elif expected_text == actual_text:
        print(Fore.GREEN + "Both equal" + Fore.RESET)
        score = 1.000
    else:
        score = await equality_score(context, expected_text, actual_text)

    # Use a large diff context unless either text is longer than 100 lines.
    longest = max(len(expected_text.splitlines()), len(actual_text.splitlines()))
    context_lines = 100 if longest <= 100 else 2
    if score == 1.0:
        print(actual_text)
    else:
        print_diff(expected_text, actual_text, n=context_lines)

    return score
920
921
def print_diff(a: str, b: str, n: int) -> None:
    """Print a colored unified diff of two texts.

    Args:
        a: Text shown as "expected"; lines starting with '-' are printed red.
        b: Text shown as "actual"; lines starting with '+' are printed green.
        n: Number of context lines around each change.
    """
    diff_lines = difflib.unified_diff(
        a.splitlines(),
        b.splitlines(),
        fromfile="expected",
        tofile="actual",
        n=n,
    )
    for line in diff_lines:
        text = line.rstrip("\n")
        if line.startswith("-"):
            print(Fore.RED + text + Fore.RESET)
        elif line.startswith("+"):
            print(Fore.GREEN + text + Fore.RESET)
        else:
            print(text)
937
938
async def equality_score(context: ProcessingContext, a: str, b: str) -> float:
    """Return a similarity score for two answer texts.

    Returns:
        1.0 if the texts are identical, 0.999 if they differ only in case,
        and otherwise the dot product of their embeddings as computed by
        ``context.embedding_model`` (a plain Python float). The dot product
        equals the cosine similarity only if the model returns unit-length
        vectors -- not checked here.
    """
    if a == b:
        return 1.0
    if a.lower() == b.lower():
        return 0.999
    # Named "vectors" so that the imported "embeddings" module is not shadowed.
    vectors = await context.embedding_model.get_embeddings([a, b])
    assert vectors.shape[0] == 2, "Expected two embeddings"
    # np.dot() returns a NumPy scalar; convert to honor the declared return type.
    return float(np.dot(vectors[0], vectors[1]))
947
948
949### Run main ###
950
if __name__ == "__main__":
    # On Ctrl-C or a closed output pipe: finish the current output line and
    # exit with status 1.
    try:
        asyncio.run(main())
    except (BrokenPipeError, KeyboardInterrupt):
        print()
        raise SystemExit(1)
957