microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
8ba3ebd84dd1bb6343ebae028996313cabd70764

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/tools/vizcmp.py

133lines · modecode

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4import argparse
5import glob
6import os
7import re
8import statistics
9import sys
10
11from colorama import init as colorama_init, Back, Fore, Style
12
13
def main():
    """Compare per-question scores across eval files and print them as a table.

    Command line:
        --color {auto,always,never}: whether to emit ANSI colors; ``auto``
            (the default) colors only when stdout is a terminal.
        files: eval files to compare; when none are given, every file
            matching ``evals/eval-*.txt`` is used, in sorted order.

    Output (stdout): a header row, one row per question counter ordered by
    descending mean score, a footer legend, and a final ``--skip-counters=``
    line listing the counters for which no file scored below 0.97.
    """
    parser = argparse.ArgumentParser(
        description="Compare evaluation results from multiple files."
    )
    parser.add_argument(
        "--color",
        choices=["auto", "always", "never"],
        default="auto",
        help="Control color output. Default 'auto' uses colors if stdout is a terminal.",
    )
    parser.add_argument(
        "files",
        nargs="*",
    )
    args = parser.parse_args()

    _init_color(args.color)

    files = args.files or sorted(glob.glob("evals/eval-*.txt"))
    table = {}  # {file: {counter: score, ...}, ...}
    questions = {}  # {counter: question, ...}

    # Fill table with scoring data from eval files
    for file in files:
        table[file] = _load_scores(file, questions)

    all_files = list(table)
    print_header(all_files)

    def mean_score(counter):
        # Average only over the files that actually scored this counter.
        return statistics.mean(
            scores[counter] for scores in table.values() if counter in scores
        )

    # Best-scoring questions first. The inner sort makes the order of
    # counters with equal means deterministic (ascending counter), because
    # sorted() is stable even with reverse=True.
    all_counters = sorted(
        sorted({counter for scores in table.values() for counter in scores}),
        key=mean_score,
        reverse=True,
    )

    good_counters: list[int] = []  # Counters where no column scores < 0.97
    for counter in all_counters:
        cells = [_format_score(table[file].get(counter)) for file in all_files]
        row = "".join(text for text, _ in cells)
        print(f"{counter:>3}:{row} {questions.get(counter)}")
        if all(is_good for _, is_good in cells):
            good_counters.append(counter)

    print_footer(all_files)
    good_counters.sort()
    print(f"--skip-counters={','.join(str(x) for x in good_counters)}")


# Width of one score column; print_header() right-aligns its labels to the
# same width.
_COLUMN_WIDTH = 6
_GOOD_SCORE = 0.97  # At or above: green, and eligible for --skip-counters.
_PERFECT_SCORE = 0.999  # At or above: additionally rendered bright.
_FAIR_SCORE = 0.9  # At or above (but below good): blue; below: red.

# A separator line of dashes or asterisks followed by the question counter.
_COUNTER_RE = re.compile(r"^(?:-+|\*+)\s+(\d+)\s+")
# The score line that belongs to the most recent counter line.
_SCORE_RE = re.compile(r"^Score:\s+([\d.]+); Question:\s+(.*)$")


def _init_color(mode):
    """Initialize colorama so that ANSI codes are kept or stripped per *mode*.

    Args:
        mode: "auto" (strip unless stdout is a terminal), "always" (never
            strip) or "never" (always strip).

    Raises:
        ValueError: If *mode* is none of the three values above.
    """
    if mode == "auto":
        strip = not sys.stdout.isatty()
    elif mode == "always":
        strip = False
    elif mode == "never":
        strip = True
    else:
        raise ValueError(f"Invalid color option: {mode}")
    colorama_init(strip=strip)


def _load_scores(file, questions):
    """Parse one eval file into a ``{counter: score}`` dict.

    A line of dashes or asterisks followed by a number sets the current
    counter; a later ``Score: ...; Question: ...`` line records the score for
    that counter.

    Args:
        file: Path of the eval file to read.
        questions: ``{counter: question}`` mapping shared across files. It is
            updated in place with questions seen for the first time; when this
            file's question for a counter differs from the recorded one, both
            versions are printed.

    Returns:
        The scores found in *file*, keyed by integer counter.
    """
    scores = {}
    counter = None
    with open(file, "r") as f:
        for line in f:
            if m := _COUNTER_RE.match(line):
                counter = int(m.group(1))
            elif m := _SCORE_RE.match(line):
                if counter is None:
                    # Without a counter the score cannot be placed in the
                    # table; a None key would also break row formatting and
                    # the final sort of good counters.
                    print(
                        f"File {file}: ignoring score line before any counter",
                        file=sys.stderr,
                    )
                    continue
                scores[counter] = float(m.group(1))
                question = m.group(2)
                recorded = questions.setdefault(counter, question)
                if recorded != question:
                    print(f"File {file} has a different question for {counter}:")
                    print(f"< {recorded}")
                    print(f"> {question}")
    return scores


def _format_score(score):
    """Render one table cell.

    Args:
        score: The score as a float, or None when the file has no score for
            the counter.

    Returns:
        A ``(text, is_good)`` pair. *text* is the colored cell, padded to the
        column width. *is_good* is False only for a score below 0.97; a
        missing score is not counted against the counter.
    """
    if score is None:
        # Pad to the full column width so later columns stay aligned.
        cell = Fore.YELLOW + f"{'N/A':>{_COLUMN_WIDTH}}" + Fore.RESET
        return Style.BRIGHT + cell + Style.RESET_ALL, True

    cell = f"{score:>{_COLUMN_WIDTH}.3f}"
    if score >= _GOOD_SCORE:
        cell = Fore.GREEN + cell + Fore.RESET
        if score >= _PERFECT_SCORE:
            cell = Style.BRIGHT + cell + Style.RESET_ALL
        return cell, True
    if score >= _FAIR_SCORE:
        return Fore.BLUE + cell + Fore.RESET, False
    cell = Fore.RED + cell + Fore.RESET
    if score == 0.0:
        # Highlight complete failures.
        cell = Style.BRIGHT + cell + Style.RESET_ALL
    return cell, False
112
113
def print_header(all_files):
    """Print the table's header row: one right-aligned label per file.

    A file's label is the run id captured from a basename of the form
    ``eval-<digits><word chars>...txt``; files whose basename does not match
    are labelled ``--``.

    Args:
        all_files: Paths of the eval files, in column order.
    """
    # The data rows printed by main() start with f"{counter:>3}:", which is
    # four characters wide; the header must be indented by the same amount
    # for the labels to sit above their 6-wide score columns.
    row_label_width = 4
    pattern = re.compile(r"eval-(\d+\w*).*\.txt")

    labels = []
    for file in all_files:
        m = pattern.match(os.path.basename(file))
        labels.append(m.group(1) if m else "--")

    print(" " * row_label_width + "".join(f"{label:>6}" for label in labels))
125
126
def print_footer(all_files):
    """Print a legend that maps each table column to its file name.

    Files are listed last column first, one per line. Each name starts under
    the right edge of its own column, and a ``|`` is drawn under the right
    edge of every column to its left, so the bars form vertical lines leading
    from the table down to the matching names.

    Args:
        all_files: Paths of the eval files, in column order.
    """
    row_label_width = 4  # Width of main()'s f"{counter:>3}:" row prefix.
    column_width = 6  # Width of each right-aligned score cell.
    gap = " " * (column_width - 1)
    bar = gap + "|"
    indent = " " * row_label_width

    for i, file in reversed(list(enumerate(all_files))):
        print(indent + bar * i + gap + os.path.basename(file))
130
131
# Run the comparison only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
134