microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/fineTuning/unsloth/baseExtract.py
255 lines · mode: code
| 1 | from argparse import ArgumentParser |
| 2 | import sys |
# Command-line interface: where to read the dataset, where to write the
# reformatted result, and how aggressively to drop very common words.
# NOTE(review): the description mentions NLTK-RAKE, but the words in this
# script are extracted with a regular expression -- confirm the wording.
parser = ArgumentParser(description="Extract keywords from dataset using NLTK-RAKE.")
parser.add_argument("--dataset_path", type=str, default='/data/npr/npr_chunks_no_embedding.json',
                    help="Path to the dataset file.")
# Destination of the JSON produced at the end of the script
parser.add_argument("--output_file", type=str, default='reformatted_data.txt',
                    help="Path to the output file.")
# %(default)s makes argparse print the real default, so the help text cannot
# drift away from the value above (it previously claimed 0.1).
parser.add_argument("--maxMsgPct", type=float, default=0.005,
                    help="Maximum message percentage threshold for filtering words (0.0-1.0). Words appearing in more than this percentage of messages will be filtered out. Default: %(default)s")
args = parser.parse_args(sys.argv[1:])
dataset_path = args.dataset_path
output_file = args.output_file
maxMsgPct = args.maxMsgPct

# Validate maxMsgPct: it is a fraction of the message count, so it must lie in
# [0, 1].  The chained comparison is also false for NaN, which is rejected too.
if not 0.0 <= maxMsgPct <= 1.0:
    print(f"Error: maxMsgPct must be between 0.0 and 1.0, got {maxMsgPct}")
    sys.exit(1)
| 20 | |
| 21 | import json |
| 22 | import re |
| 23 | import time |
| 24 | from dataclasses import dataclass, field |
| 25 | import matplotlib.pyplot as plt |
| 26 | import tiktoken |
| 27 | |
@dataclass
class Message:
    """One message of the dataset, with a back-reference to its section."""

    # Raw text of the message.
    content: str
    # Value of the record's 'speaker' field.
    speaker: str
    # Owning Section; None when no section has been attached.  The annotation
    # is a string so it is never evaluated at runtime (Section is declared
    # later in the file).
    section: 'Section | None' = None
| 33 | |
@dataclass
class Section:
    """Messages that share one section title, plus running size totals."""

    # Title of the section.
    title: str
    # Messages that belong to this section, in the order they were added.
    messages: list[Message] = field(default_factory=list)
    # Distinct words collected from the section's messages.
    words: set[str] = field(default_factory=set)
    # Accumulated per-message word counts.
    total_words: int = 0
    # Accumulated per-message character counts.
    total_chars: int = 0
    # Accumulated per-message token counts.
    total_tokens: int = 0
| 42 | |
@dataclass
class Word:
    """A word together with the messages in which it was found."""

    # The word itself.
    word: str
    # Messages recorded as containing this word.
    messages: list[Message] = field(default_factory=list)
| 47 | |
def read_and_count_messages(file_path):
    """Load the JSON dataset at *file_path* and report how many entries it has.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        A ``(data, count)`` tuple: the parsed JSON value and ``len(data)``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, which would garble non-ASCII text on some systems.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    count = len(data)
    print(f"Total messages in dataset: {count}")

    return data, count
| 57 | |
# Start timing
start_time = time.time()

# Read the dataset
messages, message_count = read_and_count_messages(dataset_path)

# Initialize tiktoken encoder
print("Initializing tokenizer...")
encoder = tiktoken.encoding_for_model("gpt-4")

# Create dictionary of sections and process messages
print("Processing messages and extracting words...")
# A "word" is a run of word characters and apostrophes between word boundaries
word_pattern = re.compile(r"\b[\w']+\b")
sections = {}  # section title -> Section
all_messages = []  # every Message, in dataset order
word_dict = {}  # word -> Word (the messages that contain it)
total_words = 0  # word occurrences over the whole dataset
all_unique_words = set()

for message in messages:
    # Fetch (or lazily create) the section this message belongs to
    section_title = message['section_title']
    section = sections.get(section_title)
    if section is None:
        section = sections[section_title] = Section(title=section_title)

    content = message['content']
    speaker = message['speaker']
    msg = Message(content=content, speaker=speaker, section=section)
    section.messages.append(msg)
    all_messages.append(msg)

    # Extract words from this message (case-insensitively)
    words = word_pattern.findall(content.lower())
    # Filter out numbers less than 100.  isdecimal() is used rather than
    # isdigit(): isdigit() is also true for characters such as superscript
    # digits, which int() rejects with ValueError.
    words = [w for w in words if not (w.isdecimal() and int(w) < 100)]

    # Accumulate per-section and global size statistics
    word_count = len(words)
    char_count = len(content)
    token_count = len(encoder.encode(content))
    total_words += word_count
    section.total_words += word_count
    section.total_chars += char_count
    section.total_tokens += token_count
    section.words.update(words)
    all_unique_words.update(words)

    # Track which messages contain each word; iterating a set records a
    # message at most once per word even when the word repeats in it
    unique_message_words = set(words)
    for word in unique_message_words:
        if word not in word_dict:
            word_dict[word] = Word(word=word)
        word_dict[word].messages.append(msg)

print("Processing complete.")
| 110 | |
# Filter words based on maxMsgPct
print(f"\nFiltering words with message percentage > {maxMsgPct * 100:.1f}%...")
total_message_count = len(all_messages)
# Largest number of messages a word may appear in and still be kept
max_messages_threshold = maxMsgPct * total_message_count

# Filter overall word_dict.  Only words appearing in MORE than the threshold
# are dropped, as the message above and the --maxMsgPct help both state, so a
# word sitting exactly on the threshold is kept (hence <=, not <).
words_before_filter = len(word_dict)
filtered_word_dict = {word: word_obj for word, word_obj in word_dict.items()
                      if len(word_obj.messages) <= max_messages_threshold}
words_filtered = words_before_filter - len(filtered_word_dict)
print(f" Filtered {words_filtered} words from overall word list ({words_before_filter} -> {len(filtered_word_dict)})")

# Filter each section's word set in place so it only keeps surviving words
for section in sections.values():
    words_to_remove = {word for word in section.words
                       if word not in filtered_word_dict}
    section.words -= words_to_remove

# Update word_dict to filtered version
word_dict = filtered_word_dict
all_unique_words = set(word_dict)

print("Filtering complete.")
| 134 | |
# Print statistics
section_count = len(sections)
total_messages = sum(len(section.messages) for section in sections.values())
avg_messages_per_section = total_messages / section_count if section_count > 0 else 0
avg_words_per_section = sum(len(section.words) for section in sections.values()) / section_count if section_count > 0 else 0

# Calculate word density ratio (unique words / total words) for each section;
# sections without any words are skipped to avoid dividing by zero
word_density_ratios = []
for section in sections.values():
    if section.total_words > 0:
        ratio = len(section.words) / section.total_words
        word_density_ratios.append(ratio)

avg_word_density = sum(word_density_ratios) / len(word_density_ratios) if word_density_ratios else 0

# Same guard as the other averages: an empty dataset (or one without any
# words) must not abort the report with ZeroDivisionError
overall_word_density = len(all_unique_words) / total_words if total_words > 0 else 0

total_message_references = sum(len(word.messages) for word in word_dict.values())
avg_messages_per_word = total_message_references / len(word_dict) if word_dict else 0

# Calculate token statistics
import statistics
token_counts = [section.total_tokens for section in sections.values()]
median_tokens_per_section = statistics.median(token_counts) if token_counts else 0
mean_tokens_per_section = statistics.mean(token_counts) if token_counts else 0

print(f"Total sections: {section_count}")
print(f"Average messages per section: {avg_messages_per_section:.2f}")
print(f"Total words: {total_words}")
print(f"Unique words across all messages: {len(all_unique_words)}")
print(f"Average unique words per section: {avg_words_per_section:.2f}")
print(f"Average word density ratio (unique/total per section): {avg_word_density * 100:.3f}%")
print(f"Overall word density ratio (unique/total across all messages): {overall_word_density * 100:.3f}%")
print(f"Average messages per unique word: {avg_messages_per_word:.2f}")
print(f"Mean tokens per section: {mean_tokens_per_section:.2f}")
print(f"Median tokens per section: {median_tokens_per_section:.2f}")
| 169 | |
# Analyze message count distribution: for every word, the number of messages
# it appears in
message_counts = [len(word.messages) for word in word_dict.values()]

# Get top 5 words by message count
top_words = sorted(word_dict.values(), key=lambda w: len(w.messages), reverse=True)[:5]
print("\nTop 5 words by message count:")
for i, word_obj in enumerate(top_words, 1):
    print(f" {i}. '{word_obj.word}': {len(word_obj.messages)} messages")

# Calculate statistics (all fall back to 0 when no word is left)
import statistics
median_messages = statistics.median(message_counts) if message_counts else 0
max_messages = max(message_counts) if message_counts else 0
min_messages = min(message_counts) if message_counts else 0

print("\nMessage count statistics per word:")
print(f" Min: {min_messages}")
print(f" Median: {median_messages:.2f}")
print(f" Max: {max_messages}")

# Count words appearing in 4 or fewer, then 16 or fewer, messages
for limit in (4, 16):
    words_at_or_below = sum(1 for count in message_counts if count <= limit)
    percentage_at_or_below = (words_at_or_below / len(message_counts) * 100) if message_counts else 0
    print(f" Words appearing in ≤{limit} messages: {words_at_or_below} ({percentage_at_or_below:.2f}%)")
| 200 | |
def _save_histogram(counts, title, path):
    """Save a 50-bin histogram of *counts* (log-scaled y-axis) to *path*."""
    fig = plt.figure(figsize=(12, 6))
    plt.hist(counts, bins=50, edgecolor='black', alpha=0.7)
    plt.xlabel('Number of Messages per Word')
    plt.ylabel('Number of Words')
    plt.title(title)
    plt.yscale('log')  # Use log scale for y-axis since distribution is likely skewed
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    # pyplot keeps every figure alive until it is explicitly closed
    plt.close(fig)


# Plot distribution of messages per word
print("\nCreating distribution plot...")
# Recomputed here so the plotting step does not depend on the statistics step
message_counts = [len(word.messages) for word in word_dict.values()]

_save_histogram(message_counts,
                'Distribution of Message Counts per Unique Word',
                'word_message_distribution.png')
print("Distribution plot saved to word_message_distribution.png")

# Plot distribution for words with 50 or fewer messages
message_counts_50_or_less = [count for count in message_counts if count <= 50]
_save_histogram(message_counts_50_or_less,
                'Distribution of Message Counts per Unique Word (≤50 messages)',
                'word_message_distribution_50_or_less.png')
print("Zoomed distribution plot saved to word_message_distribution_50_or_less.png")
| 228 | |
def _section_record(sec):
    """Return the JSON-serialisable dict written to the output file for *sec*."""
    message_records = [
        {"content": entry.content, "speaker": entry.speaker}
        for entry in sec.messages
    ]
    return {
        "title": sec.title,
        "messages": message_records,
        # Sorted so the word list is written in a deterministic order
        "words": sorted(sec.words),
        "word_count": sec.total_words,
        "char_count": sec.total_chars,
        "token_count": sec.total_tokens,
    }


# Serialise every section, in insertion order, to the output file
print(f"\nWriting data to {output_file}...")
output_data = [_section_record(sec) for sec in sections.values()]

with open(output_file, 'w') as out:
    json.dump(output_data, out, indent=2)
print(f"Data written to {output_file}")

# Report the wall-clock time elapsed since start_time was taken
elapsed_time = time.time() - start_time
print(f"\nTotal processing time: {elapsed_time:.2f} seconds")