microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/ta/tools/vizcmp.py
133lines · modecode
| 1 | # Copyright (c) Microsoft Corporation. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | import argparse |
| 5 | import glob |
| 6 | import os |
| 7 | import re |
| 8 | import statistics |
| 9 | import sys |
| 10 | |
| 11 | from colorama import init as colorama_init, Back, Fore, Style |
| 12 | |
| 13 | |
| 14 | def main(): |
| 15 | parser = argparse.ArgumentParser( |
| 16 | description="Compare evaluation results from multiple files." |
| 17 | ) |
| 18 | parser.add_argument( |
| 19 | "--color", |
| 20 | choices=["auto", "always", "never"], |
| 21 | default="auto", |
| 22 | help="Control color output. Default 'auto' uses colors if stdout is a terminal.", |
| 23 | ) |
| 24 | parser.add_argument( |
| 25 | "files", |
| 26 | nargs="*", |
| 27 | ) |
| 28 | args = parser.parse_args() |
| 29 | |
| 30 | # Initialize colorama according to --color. |
| 31 | match args.color: |
| 32 | case "auto": |
| 33 | colorama_init(strip=not sys.stdout.isatty()) |
| 34 | case "always": |
| 35 | colorama_init(strip=False) |
| 36 | case "never": |
| 37 | colorama_init(strip=True) |
| 38 | case _: |
| 39 | raise ValueError(f"Invalid color option: {args.color}") |
| 40 | |
| 41 | files = args.files or sorted(glob.glob("evals/eval-*.txt")) |
| 42 | table = {} # {file: {counter: score, ...}, ...} |
| 43 | questions = {} # {counter: question, ...} |
| 44 | |
| 45 | # Fill table with scoring data from eval files |
| 46 | for file in files: |
| 47 | with open(file, "r") as f: |
| 48 | lines = f.readlines() |
| 49 | |
| 50 | scores = {} |
| 51 | counter = None |
| 52 | for i, line in enumerate(lines): |
| 53 | if m := re.match(r"^(?:-+|\*+)\s+(\d+)\s+", line): |
| 54 | counter = int(m.group(1)) |
| 55 | elif m := re.match(r"^Score:\s+([\d.]+); Question:\s+(.*)$", line): |
| 56 | score = float(m.group(1)) |
| 57 | scores[counter] = score |
| 58 | question = m.group(2) |
| 59 | if counter not in questions: |
| 60 | questions[counter] = question |
| 61 | elif questions[counter] != question: |
| 62 | print(f"File {file} has a different question for {counter}:") |
| 63 | print(f"< {questions[counter]}") |
| 64 | print(f"> {question}") |
| 65 | |
| 66 | table[file] = scores |
| 67 | |
| 68 | all_files = list(table.keys()) |
| 69 | print_header(all_files) |
| 70 | |
| 71 | good_counters: list[int] = [] # Counters where all columns score >= 0.97 |
| 72 | |
| 73 | # Print data |
| 74 | all_counters = sorted( |
| 75 | {counter for data in table.values() for counter in data.keys()}, |
| 76 | key=lambda x: statistics.mean( |
| 77 | table[file].get(x) for file in all_files if table[file].get(x) is not None |
| 78 | ), |
| 79 | reverse=True, |
| 80 | ) |
| 81 | for counter in all_counters: |
| 82 | all_good = True |
| 83 | print(f"{counter:>3}:", end="") |
| 84 | for file in all_files: |
| 85 | score = table[file].get(counter, None) |
| 86 | if score is None: |
| 87 | output = Fore.YELLOW + " N/A " + Fore.RESET |
| 88 | output = Style.BRIGHT + output + Style.RESET_ALL |
| 89 | else: |
| 90 | output = f"{score:.3f}" |
| 91 | output = f"{output:>6}" |
| 92 | if score >= 0.97: |
| 93 | output = Fore.GREEN + output + Fore.RESET |
| 94 | if score >= 0.999: |
| 95 | output = Style.BRIGHT + output + Style.RESET_ALL |
| 96 | else: |
| 97 | all_good = False |
| 98 | if score >= 0.9: |
| 99 | output = Fore.BLUE + output + Fore.RESET |
| 100 | else: |
| 101 | output = Fore.RED + output + Fore.RESET |
| 102 | if score == 0.0: |
| 103 | output = Style.BRIGHT + output + Style.RESET_ALL |
| 104 | print(output, end="") |
| 105 | print(f" {questions.get(counter)}") |
| 106 | if all_good: |
| 107 | good_counters.append(counter) |
| 108 | |
| 109 | print_footer(all_files) |
| 110 | good_counters.sort() |
| 111 | print(f"--skip-counters={','.join(str(x) for x in good_counters)}") |
| 112 | |
| 113 | |
def print_header(all_files):
    """Print the header row: one right-aligned, 6-character label per file.

    The label is the run of digits (plus any directly following word
    characters) after ``eval-`` in the file's base name, e.g. ``12`` for
    ``eval-12.txt``; files whose name does not match get ``--``.

    Args:
        all_files: Paths of the eval files, in column order.
    """
    # Data rows printed by main() start with f"{counter:>3}:" (4 characters),
    # so indent the header by the same width to keep each label above its
    # score column.
    print(" " * 4, end="")
    for file in all_files:
        base = os.path.basename(file)
        m = re.match(r"eval-(\d+\w*).*\.txt", base)
        label = m.group(1) if m else "--"
        print(f"{label:>6}", end="")
    print()
| 125 | |
| 126 | |
def print_footer(all_files):
    """Print a legend mapping each table column to its file's base name.

    Files are listed last-to-first, one per line.  The line for column ``i``
    carries one ``|`` for each earlier column and then the file name, so the
    bars of successive lines stack vertically under their columns.

    Args:
        all_files: Paths of the eval files, in column order.
    """
    for i, file in reversed(list(enumerate(all_files))):
        # Table rows are a 4-character counter prefix followed by
        # 6-character cells; put each bar (and the start of the name) under
        # the last character of the corresponding cell.
        print(" " * 4 + "     |" * i + " " * 5 + os.path.basename(file))
| 130 | |
| 131 | |
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()
| 134 | |