microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/fineTuning/unsloth/baseExtract.py
324 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | from argparse import ArgumentParser |
| 5 | import sys |
# Command-line interface. Arguments are parsed at module level because the
# rest of this script runs as top-level code and reads these globals.
# The description states what the script does; it uses regular expressions
# and tiktoken, not NLTK-RAKE as the previous text claimed.
parser = ArgumentParser(description="Extract words and proper nouns from a message dataset, grouped by section.")
parser.add_argument("--dataset_path", type=str, default='/data/npr/npr_chunks_no_embedding.json',
                    help="Path to the dataset file.")
parser.add_argument("--output_file", type=str, default='reformatted_data.txt',
                    help="Path to the output file.")
# %(default)s makes argparse print the real default, so the help text cannot
# drift from the value configured here (it previously claimed 0.1).
parser.add_argument("--maxMsgPct", type=float, default=0.005,
                    help="Maximum message percentage threshold for filtering words (0.0-1.0). Words appearing in more than this percentage of messages will be filtered out. Default: %(default)s")
args = parser.parse_args(sys.argv[1:])
dataset_path = args.dataset_path
output_file = args.output_file
maxMsgPct = args.maxMsgPct

# Validate maxMsgPct: it is a fraction of the message count, not a percent value.
if not 0.0 <= maxMsgPct <= 1.0:
    print(f"Error: maxMsgPct must be between 0.0 and 1.0, got {maxMsgPct}")
    sys.exit(1)
| 23 | |
| 24 | import json |
| 25 | import re |
| 26 | import time |
| 27 | from dataclasses import dataclass, field |
| 28 | import matplotlib.pyplot as plt |
| 29 | import tiktoken |
| 30 | |
@dataclass
class Message:
    """One dataset entry: the text that was said and who said it."""
    content: str
    speaker: str
    # Back-reference to the Section this message was grouped under.
    # Defaults to None, so the annotation is optional rather than plain Section.
    section: 'Section | None' = None
| 36 | |
@dataclass
class Section:
    """Accumulator for everything gathered about one section title."""
    # Section title; the script uses it as the key of its sections dict.
    title: str
    # Messages grouped under this title, in the order they were read.
    messages: list[Message] = field(default_factory=list)
    # Distinct lowercased words seen in this section's messages.
    words: set[str] = field(default_factory=set)
    # Distinct proper nouns (one- or two-word) found in this section.
    proper_nouns: set[str] = field(default_factory=set)
    # Running totals summed over the section's messages.
    total_words: int = 0
    total_chars: int = 0
    total_tokens: int = 0
| 46 | |
@dataclass
class Word:
    """Index entry tying one word to the messages that contain it."""
    # The word itself (the script stores it lowercased).
    word: str
    # Messages containing the word; the script appends each message at most once.
    messages: list[Message] = field(default_factory=list)
| 51 | |
def extract_proper_nouns(text: str) -> set[str]:
    """Extract proper nouns (capitalized words not at sentence start).

    The text is split into sentences on runs of '.', '!' or '?' followed by
    whitespace. Within each sentence the first token is skipped, and every
    later capitalized token is a candidate. Two adjacent candidates are
    returned as one two-word proper noun, which is preferred over two
    single words. The pronoun "I" and its contractions are never returned.

    A capitalized word that directly follows an opening quotation mark is
    treated as the start of a quoted sentence and skipped. A capitalized
    word that follows a *closing* quote (e.g. ``"Go home," Maria said``) is
    kept.

    Args:
        text: Free-form message text.

    Returns:
        The set of one- and two-word proper nouns found, with surrounding
        punctuation removed.
    """
    # Straight and curly marks that can open a quotation.
    opening_quotes = '"\'\u201c\u2018'
    # Words to exclude: pronoun "I" and its contractions
    excluded_words = {"I", "I'd", "I'll", "I'm", "I've"}
    # Apostrophes are kept so contractions and possessives stay intact.
    leading_punct = re.compile(r"^[^\w']+")
    trailing_punct = re.compile(r"[^\w']+$")

    def candidate(tokens, idx):
        """Return the cleaned token at idx if it qualifies as a proper-noun word, else None."""
        if idx >= len(tokens):
            return None
        token = trailing_punct.sub('', tokens[idx])
        prefix = leading_punct.match(token)
        core = token[prefix.end():] if prefix else token
        if not core or not core[0].isupper() or core in excluded_words:
            return None
        # The opening quote is attached to the word it introduces
        # (said, "Hello ...), so inspect this token's own leading punctuation.
        if prefix and any(ch in opening_quotes for ch in prefix.group()):
            return None
        # Also cover a quote mark that stands alone as the previous token.
        if all(ch in opening_quotes for ch in tokens[idx - 1]):
            return None
        return core

    proper_nouns = set()
    for sentence in re.split(r'[.!?]+\s+', text):
        tokens = sentence.split()
        i = 1  # Start from index 1 to skip sentence-initial capitalization
        while i < len(tokens):
            first = candidate(tokens, i)
            if first is None:
                i += 1
                continue
            second = candidate(tokens, i + 1)
            if second is not None:
                # Two-word proper noun: consume both tokens.
                proper_nouns.add(f"{first} {second}")
                i += 2
            else:
                proper_nouns.add(first)
                i += 1

    return proper_nouns
| 97 | |
def read_and_count_messages(file_path):
    """Load the JSON dataset and report how many top-level entries it has.

    Args:
        file_path: Path to a JSON file.

    Returns:
        A ``(data, count)`` tuple: the parsed JSON value and ``len(data)``.
        The count is also printed.
    """
    # Decode explicitly as UTF-8 so non-ASCII text is read the same way on
    # every platform instead of depending on the locale's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    count = len(data)
    print(f"Total messages in dataset: {count}")

    return data, count
| 107 | |
# Start timing
start_time = time.time()

# Read the dataset
messages, message_count = read_and_count_messages(dataset_path)

# Initialize tiktoken encoder
print("Initializing tokenizer...")
encoder = tiktoken.encoding_for_model("gpt-4")

# Single pass over the dataset. It groups messages by section title,
# accumulates per-section totals, and builds an index from each word to the
# messages that contain it.
print("Processing messages and extracting words...")
word_pattern = re.compile(r"\b[\w']+\b")
sections = {}          # section title -> Section
all_messages = []      # every Message, in dataset order
word_dict = {}         # lowercased word -> Word
total_words = 0        # words counted across all messages (after exclusions)
all_unique_words = set()

for message in messages:
    section_title = message['section_title']
    if section_title not in sections:
        sections[section_title] = Section(title=section_title)
    current_section = sections[section_title]

    content = message['content']
    speaker = message['speaker']
    msg = Message(content=content, speaker=speaker, section=current_section)
    current_section.messages.append(msg)
    all_messages.append(msg)

    # Extract proper nouns
    proper_nouns = extract_proper_nouns(content)
    current_section.proper_nouns.update(proper_nouns)

    # Every component word of a proper noun (lowercased) is excluded from
    # this message's ordinary word list.
    words_to_exclude = set()
    for proper_noun in proper_nouns:
        words_to_exclude.update(part.lower() for part in proper_noun.split())

    # Extract words, dropping proper-noun words and numbers below 100.
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits, which int() cannot parse.
    words = [
        w for w in word_pattern.findall(content.lower())
        if w not in words_to_exclude and not (w.isdecimal() and int(w) < 100)
    ]
    word_count = len(words)
    char_count = len(content)
    token_count = len(encoder.encode(content))
    total_words += word_count
    current_section.total_words += word_count
    current_section.total_chars += char_count
    current_section.total_tokens += token_count
    current_section.words.update(words)
    all_unique_words.update(words)

    # Track which messages contain each word. Iterating the de-duplicated
    # set records a message at most once per word.
    unique_message_words = set(words)
    for word in unique_message_words:
        if word not in word_dict:
            word_dict[word] = Word(word=word)
        word_dict[word].messages.append(msg)

print("Processing complete.")
| 169 | |
# Filter words based on maxMsgPct: drop words that appear in more than
# maxMsgPct of all messages.
print(f"\nFiltering words with message percentage > {maxMsgPct * 100:.1f}%...")
total_message_count = len(all_messages)
max_messages_threshold = maxMsgPct * total_message_count

# Filter overall word_dict. A word sitting exactly on the threshold is kept,
# since only words appearing in MORE than the threshold are meant to go;
# with maxMsgPct == 1.0 nothing is removed.
words_before_filter = len(word_dict)
filtered_word_dict = {
    word: word_obj
    for word, word_obj in word_dict.items()
    if len(word_obj.messages) <= max_messages_threshold
}
words_filtered = words_before_filter - len(filtered_word_dict)
print(f" Filtered {words_filtered} words from overall word list ({words_before_filter} -> {len(filtered_word_dict)})")

# Filter each section's word set in place, keeping only surviving words.
for section in sections.values():
    section.words.intersection_update(filtered_word_dict)

# Update word_dict to filtered version
word_dict = filtered_word_dict
all_unique_words = set(word_dict.keys())

print("Filtering complete.")
| 193 | |
# Print statistics. Every ratio below falls back to 0 when its denominator
# is empty, so an empty dataset produces a report instead of a crash.
section_count = len(sections)
total_messages = sum(len(section.messages) for section in sections.values())
avg_messages_per_section = total_messages / section_count if section_count > 0 else 0
avg_words_per_section = sum(len(section.words) for section in sections.values()) / section_count if section_count > 0 else 0

# Word density per section: distinct words kept / words counted.
# Sections with no counted words are left out of the average.
word_density_ratios = [
    len(section.words) / section.total_words
    for section in sections.values()
    if section.total_words > 0
]
avg_word_density = sum(word_density_ratios) / len(word_density_ratios) if word_density_ratios else 0

# Same ratio over the whole dataset; guarded like the per-section ratios.
overall_word_density = len(all_unique_words) / total_words if total_words > 0 else 0

total_message_references = sum(len(word.messages) for word in word_dict.values())
avg_messages_per_word = total_message_references / len(word_dict) if word_dict else 0

# Calculate token statistics
import statistics
token_counts = [section.total_tokens for section in sections.values()]
median_tokens_per_section = statistics.median(token_counts) if token_counts else 0
mean_tokens_per_section = statistics.mean(token_counts) if token_counts else 0

# Calculate proper noun statistics
all_unique_proper_nouns = set()
for section in sections.values():
    all_unique_proper_nouns.update(section.proper_nouns)

avg_proper_nouns_per_section = sum(len(section.proper_nouns) for section in sections.values()) / section_count if section_count > 0 else 0

print(f"Total sections: {section_count}")
print(f"Average messages per section: {avg_messages_per_section:.2f}")
print(f"Total words: {total_words}")
print(f"Unique words across all messages: {len(all_unique_words)}")
print(f"Average unique words per section: {avg_words_per_section:.2f}")
print(f"Total unique proper nouns across all messages: {len(all_unique_proper_nouns)}")
print(f"Average unique proper nouns per section: {avg_proper_nouns_per_section:.2f}")
print(f"Average word density ratio (unique/total per section): {avg_word_density * 100:.3f}%")
print(f"Overall word density ratio (unique/total across all messages): {overall_word_density * 100:.3f}%")
print(f"Average messages per unique word: {avg_messages_per_word:.2f}")
print(f"Mean tokens per section: {mean_tokens_per_section:.2f}")
print(f"Median tokens per section: {median_tokens_per_section:.2f}")
| 237 | |
# Report how the indexed words are spread across messages.
import statistics


def _share_at_most(counts, limit):
    """Return how many counts are <= limit, and that number as a percentage of all counts."""
    hits = sum(1 for value in counts if value <= limit)
    pct = (hits / len(counts) * 100) if counts else 0
    return hits, pct


# One entry per word: the number of messages that contain it.
message_counts = [len(entry.messages) for entry in word_dict.values()]
message_counts_sorted = sorted(message_counts, reverse=True)

# The five words found in the most messages.
top_words = sorted(word_dict.values(), key=lambda w: len(w.messages), reverse=True)[:5]
print("\nTop 5 words by message count:")
for rank, word_obj in enumerate(top_words, 1):
    print(f" {rank}. '{word_obj.word}': {len(word_obj.messages)} messages")

# Min / median / max, all reported as 0 when there are no words.
if message_counts:
    median_messages = statistics.median(message_counts)
    max_messages = max(message_counts)
    min_messages = min(message_counts)
else:
    median_messages = 0
    max_messages = 0
    min_messages = 0

print("\nMessage count statistics per word:")
print(f" Min: {min_messages}")
print(f" Median: {median_messages:.2f}")
print(f" Max: {max_messages}")

# Share of words appearing in 4 or fewer messages
words_4_or_fewer, percentage_4_or_fewer = _share_at_most(message_counts, 4)
print(f" Words appearing in ≤4 messages: {words_4_or_fewer} ({percentage_4_or_fewer:.2f}%)")

# Share of words appearing in 16 or fewer messages
words_16_or_fewer, percentage_16_or_fewer = _share_at_most(message_counts, 16)
print(f" Words appearing in ≤16 messages: {words_16_or_fewer} ({percentage_16_or_fewer:.2f}%)")
| 268 | |
def _save_histogram(counts, title, path):
    """Save a 50-bin histogram of counts with a log-scaled y-axis to path, then close the figure."""
    fig = plt.figure(figsize=(12, 6))
    plt.hist(counts, bins=50, edgecolor='black', alpha=0.7)
    plt.xlabel('Number of Messages per Word')
    plt.ylabel('Number of Words')
    plt.title(title)
    plt.yscale('log')  # Use log scale for y-axis since distribution is likely skewed
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    # pyplot keeps every figure alive until it is closed; release it once saved.
    plt.close(fig)


# Plot distribution of messages per word
print("\nCreating distribution plot...")
# Recomputed here so this step depends only on word_dict.
message_counts = [len(word.messages) for word in word_dict.values()]

_save_histogram(message_counts,
                'Distribution of Message Counts per Unique Word',
                'word_message_distribution.png')
print("Distribution plot saved to word_message_distribution.png")

# Plot distribution for words with 50 or fewer messages
message_counts_50_or_less = [count for count in message_counts if count <= 50]
_save_histogram(message_counts_50_or_less,
                'Distribution of Message Counts per Unique Word (≤50 messages)',
                'word_message_distribution_50_or_less.png')
print("Zoomed distribution plot saved to word_message_distribution_50_or_less.png")
| 296 | |
# Serialize one JSON object per section, in the order sections were first seen.
print(f"\nWriting data to {output_file}...")
output_data = [
    {
        "title": section.title,
        "messages": [
            {"content": msg.content, "speaker": msg.speaker}
            for msg in section.messages
        ],
        # Sets are emitted as sorted lists so the output is deterministic.
        "words": sorted(section.words),
        "proper_nouns": sorted(section.proper_nouns),
        "word_count": section.total_words,
        "char_count": section.total_chars,
        "token_count": section.total_tokens,
    }
    for section in sections.values()
]

with open(output_file, 'w') as out_fh:
    json.dump(output_data, out_fh, indent=2)
print(f"Data written to {output_file}")

# Report wall-clock time since start_time was recorded.
elapsed_time = time.time() - start_time
print(f"\nTotal processing time: {elapsed_time:.2f} seconds")