microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
ecec6489db2ad85ccd1be99c0860dc3b057af2b1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/fineTuning/unsloth/baseExtract.py

324 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4from argparse import ArgumentParser
5import sys
# Command-line interface. Parsing happens before the heavy imports below so
# that `--help` and argument errors respond quickly.
parser = ArgumentParser(
    description=(
        "Extract per-section words and proper nouns from a message dataset "
        "and report word-frequency statistics."
    )
)
parser.add_argument(
    "--dataset_path",
    type=str,
    default='/data/npr/npr_chunks_no_embedding.json',
    help="Path to the dataset file.",
)
parser.add_argument(
    "--output_file",
    type=str,
    default='reformatted_data.txt',
    help="Path to the output file (written as JSON).",
)
parser.add_argument(
    "--maxMsgPct",
    type=float,
    default=0.005,
    # %(default)s keeps the help text in sync with the real default value.
    help=(
        "Maximum message percentage threshold for filtering words (0.0-1.0). "
        "Words appearing in more than this percentage of messages will be "
        "filtered out. Default: %(default)s"
    ),
)
args = parser.parse_args(sys.argv[1:])
dataset_path = args.dataset_path
output_file = args.output_file
maxMsgPct = args.maxMsgPct

# Validate maxMsgPct (the chained comparison also rejects NaN).
if not 0.0 <= maxMsgPct <= 1.0:
    print(f"Error: maxMsgPct must be between 0.0 and 1.0, got {maxMsgPct}", file=sys.stderr)
    sys.exit(1)
23
import json
import re
import statistics
import time
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import tiktoken
30
@dataclass
class Message:
    """A single message (utterance) from the dataset.

    ``section`` is a back-reference to the owning Section and defaults to None.
    It is excluded from the generated ``__repr__`` and ``__eq__``. Section.messages
    points back at this Message, so including it would print an entire section
    for every message and can recurse when messages are compared.
    """

    content: str
    speaker: str
    # Owning Section, or None when the message has not been attached to one.
    section: 'Section | None' = field(default=None, repr=False, compare=False)
36
@dataclass
class Section:
    """A titled group of messages together with aggregate statistics about them."""

    title: str
    # Messages belonging to this section, in the order they were added.
    messages: list[Message] = field(default_factory=list)
    # Distinct words collected from this section's messages.
    words: set[str] = field(default_factory=set)
    # Distinct proper nouns collected from this section's messages.
    proper_nouns: set[str] = field(default_factory=set)
    # Running totals accumulated over the section's messages.
    total_words: int = field(default=0)
    total_chars: int = field(default=0)
    total_tokens: int = field(default=0)
46
@dataclass
class Word:
    """A distinct word paired with the messages in which it occurs."""

    word: str
    # Messages that contain this word; a fresh list per instance.
    messages: list[Message] = field(default_factory=list)
51
# Straight and typographic quote characters (" ' “ ” ‘ ’) that can mark dialogue.
_QUOTE_CHARS = frozenset('"\'\u201c\u201d\u2018\u2019')
# Pronoun "I" and its contractions (normalized to a straight apostrophe).
_EXCLUDED_WORDS = frozenset({"I", "I'd", "I'll", "I'm", "I've"})
_SENTENCE_SPLIT = re.compile(r'[.!?]+\s+')
_TRAILING_PUNCT = re.compile(r'[^\w\']+$')
_LEADING_PUNCT = re.compile(r'^[^\w\']+')


def _strip_punct(token):
    """Strip punctuation from both ends of *token*, keeping apostrophes."""
    return _LEADING_PUNCT.sub('', _TRAILING_PUNCT.sub('', token))


def _is_excluded(word):
    """Return True for 'I' and its contractions, with either apostrophe style."""
    return word.replace('\u2019', "'") in _EXCLUDED_WORDS


def extract_proper_nouns(text):
    """Extract proper nouns (capitalized words not at sentence start).

    Prefers 2-word proper nouns over 1-word. Excludes the pronoun 'I' and its
    contractions, written with either a straight (') or typographic (’)
    apostrophe.

    Handles quoted dialogue to avoid capturing sentence-initial words within
    quotes. A word is skipped when it begins with a quote character or when
    the preceding word ends with one. Straight and typographic quotes are both
    recognized.

    Args:
        text: Raw message text.

    Returns:
        A set of proper-noun strings with punctuation stripped from both ends
        (apostrophes are kept).
    """
    proper_nouns = set()

    # Simple sentence split: terminal punctuation followed by whitespace.
    for sentence in _SENTENCE_SPLIT.split(text):
        words = sentence.split()  # never yields empty tokens
        i = 1  # start at 1: sentence-initial words are capitalized regardless

        while i < len(words):
            word1 = _strip_punct(words[i])

            # Treat this word as the start of a quoted sentence if it opens with
            # a quote ('said, "We ...') or follows a token ending in one.
            is_quote_start = words[i][0] in _QUOTE_CHARS or words[i - 1][-1] in _QUOTE_CHARS

            if word1 and word1[0].isupper() and not _is_excluded(word1) and not is_quote_start:
                # Prefer a 2-word proper noun when the next word is capitalized too.
                if i + 1 < len(words):
                    word2 = _strip_punct(words[i + 1])
                    if word2 and word2[0].isupper() and not _is_excluded(word2):
                        proper_nouns.add(f"{word1} {word2}")
                        i += 2
                        continue

                # Single-word proper noun.
                proper_nouns.add(word1)
            i += 1

    return proper_nouns
97
def read_and_count_messages(file_path):
    """Read the JSON file at *file_path* and return its contents and their count.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so non-ASCII text loads identically on every OS.

    Args:
        file_path: Path to the JSON dataset.

    Returns:
        A ``(data, count)`` tuple: the parsed top-level JSON value and its
        ``len()``. NOTE(review): the caller expects a list of message dicts;
        that shape is not validated here.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    count = len(data)
    print(f"Total messages in dataset: {count}")

    return data, count
107
# Start timing
start_time = time.time()

# Read the dataset
messages, message_count = read_and_count_messages(dataset_path)

# Initialize tiktoken encoder
print("Initializing tokenizer...")
encoder = tiktoken.encoding_for_model("gpt-4")

# Create one Section per distinct section title and index every message.
print("Processing messages and extracting words...")
word_pattern = re.compile(r"\b[\w']+\b")
sections = {}            # section title -> Section
all_messages = []        # every Message, in dataset order
word_dict = {}           # lowercase word -> Word (messages containing it)
total_words = 0          # kept word occurrences across all messages
all_unique_words = set()

for message in messages:
    section_title = message['section_title']
    section = sections.get(section_title)
    if section is None:
        section = sections[section_title] = Section(title=section_title)

    content = message['content']
    msg = Message(content=content, speaker=message['speaker'], section=section)
    section.messages.append(msg)
    all_messages.append(msg)

    # Extract proper nouns
    proper_nouns = extract_proper_nouns(content)
    section.proper_nouns.update(proper_nouns)

    # Lowercased component words of this message's proper nouns; they are left
    # out of the ordinary word statistics (they are tracked in proper_nouns).
    words_to_exclude = {part.lower() for noun in proper_nouns for part in noun.split()}

    # Filter out numbers less than 100 and proper-noun words. isdecimal() is
    # used instead of isdigit() because isdigit() is also True for characters
    # such as '²' that int() rejects with ValueError.
    words = [
        w for w in word_pattern.findall(content.lower())
        if not (w.isdecimal() and int(w) < 100) and w not in words_to_exclude
    ]
    word_count = len(words)
    total_words += word_count
    section.total_words += word_count
    section.total_chars += len(content)
    section.total_tokens += len(encoder.encode(content))
    section.words.update(words)
    all_unique_words.update(words)

    # Track which messages contain each word (each message at most once per word).
    for word in set(words):
        if word not in word_dict:
            word_dict[word] = Word(word=word)
        word_dict[word].messages.append(msg)

print("Processing complete.")
169
# Filter words based on maxMsgPct. Per the --maxMsgPct contract, a word is
# removed only if it appears in MORE than maxMsgPct of all messages, so a word
# sitting exactly at the threshold is kept (hence `<=`).
print(f"\nFiltering words with message percentage > {maxMsgPct * 100:.1f}%...")
total_message_count = len(all_messages)
max_messages_threshold = maxMsgPct * total_message_count

# Filter overall word_dict
words_before_filter = len(word_dict)
filtered_word_dict = {word: word_obj for word, word_obj in word_dict.items()
                      if len(word_obj.messages) <= max_messages_threshold}
words_filtered = words_before_filter - len(filtered_word_dict)
print(f" Filtered {words_filtered} words from overall word list ({words_before_filter} -> {len(filtered_word_dict)})")

# Restrict each section's word set to the words that survived the filter.
for section in sections.values():
    section.words.intersection_update(filtered_word_dict)

# Update word_dict to filtered version
word_dict = filtered_word_dict
all_unique_words = set(word_dict)

print("Filtering complete.")
193
# Print statistics.
# NOTE: section.words and all_unique_words have already been frequency-filtered,
# while total_words and section.total_words count occurrences from the
# processing pass, so the density ratios below mix filtered and unfiltered counts.
section_count = len(sections)
total_messages = sum(len(section.messages) for section in sections.values())
avg_messages_per_section = total_messages / section_count if section_count > 0 else 0
avg_words_per_section = (
    sum(len(section.words) for section in sections.values()) / section_count
    if section_count > 0 else 0
)

# Word density ratio (distinct words / word occurrences) for each non-empty section.
word_density_ratios = [
    len(section.words) / section.total_words
    for section in sections.values()
    if section.total_words > 0
]
avg_word_density = sum(word_density_ratios) / len(word_density_ratios) if word_density_ratios else 0
# Guarded: an empty dataset (no countable words) would otherwise raise ZeroDivisionError.
overall_word_density = len(all_unique_words) / total_words if total_words > 0 else 0

total_message_references = sum(len(word.messages) for word in word_dict.values())
avg_messages_per_word = total_message_references / len(word_dict) if word_dict else 0

# Token statistics per section
token_counts = [section.total_tokens for section in sections.values()]
median_tokens_per_section = statistics.median(token_counts) if token_counts else 0
mean_tokens_per_section = statistics.mean(token_counts) if token_counts else 0

# Proper noun statistics
all_unique_proper_nouns = set().union(*(section.proper_nouns for section in sections.values()))
avg_proper_nouns_per_section = (
    sum(len(section.proper_nouns) for section in sections.values()) / section_count
    if section_count > 0 else 0
)

print(f"Total sections: {section_count}")
print(f"Average messages per section: {avg_messages_per_section:.2f}")
print(f"Total words: {total_words}")
print(f"Unique words across all messages: {len(all_unique_words)}")
print(f"Average unique words per section: {avg_words_per_section:.2f}")
print(f"Total unique proper nouns across all messages: {len(all_unique_proper_nouns)}")
print(f"Average unique proper nouns per section: {avg_proper_nouns_per_section:.2f}")
print(f"Average word density ratio (unique/total per section): {avg_word_density * 100:.3f}%")
print(f"Overall word density ratio (unique/total across all messages): {overall_word_density * 100:.3f}%")
print(f"Average messages per unique word: {avg_messages_per_word:.2f}")
print(f"Mean tokens per section: {mean_tokens_per_section:.2f}")
print(f"Median tokens per section: {median_tokens_per_section:.2f}")
237
# Analyze message count distribution: how many messages each surviving word appears in.
message_counts = [len(word.messages) for word in word_dict.values()]

# Top 5 words by message count (sorted() stays stable with reverse=True, so ties
# keep insertion order).
top_words = sorted(word_dict.values(), key=lambda w: len(w.messages), reverse=True)[:5]
print("\nTop 5 words by message count:")
for rank, word_obj in enumerate(top_words, 1):
    print(f" {rank}. '{word_obj.word}': {len(word_obj.messages)} messages")

# Calculate statistics
median_messages = statistics.median(message_counts) if message_counts else 0
max_messages = max(message_counts, default=0)
min_messages = min(message_counts, default=0)

print("\nMessage count statistics per word:")
print(f" Min: {min_messages}")
print(f" Median: {median_messages:.2f}")
print(f" Max: {max_messages}")

# Share of words that appear in only a few messages.
for limit in (4, 16):
    words_within_limit = sum(1 for count in message_counts if count <= limit)
    percentage_within_limit = (words_within_limit / len(message_counts) * 100) if message_counts else 0
    print(f" Words appearing in ≤{limit} messages: {words_within_limit} ({percentage_within_limit:.2f}%)")
268
def _save_message_count_histogram(counts, title, path):
    """Save a 50-bin, log-scaled histogram of *counts* to *path*, then close the figure."""
    fig = plt.figure(figsize=(12, 6))
    plt.hist(counts, bins=50, edgecolor='black', alpha=0.7)
    plt.xlabel('Number of Messages per Word')
    plt.ylabel('Number of Words')
    plt.title(title)
    plt.yscale('log')  # Use log scale for y-axis since distribution is likely skewed
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    # Release the figure; the script never shows it interactively.
    plt.close(fig)


# Plot distribution of messages per word (message_counts was computed from word_dict above).
print("\nCreating distribution plot...")
_save_message_count_histogram(
    message_counts,
    'Distribution of Message Counts per Unique Word',
    'word_message_distribution.png',
)
print("Distribution plot saved to word_message_distribution.png")

# Plot distribution for words with 50 or fewer messages
message_counts_50_or_less = [count for count in message_counts if count <= 50]
_save_message_count_histogram(
    message_counts_50_or_less,
    'Distribution of Message Counts per Unique Word (≤50 messages)',
    'word_message_distribution_50_or_less.png',
)
print("Zoomed distribution plot saved to word_message_distribution_50_or_less.png")
296
# Write one JSON record per section to output_file.
print(f"\nWriting data to {output_file}...")
output_data = [
    {
        "title": section.title,
        "messages": [
            {"content": msg.content, "speaker": msg.speaker}
            for msg in section.messages
        ],
        # sorted() already returns a new list; no list() copy needed.
        "words": sorted(section.words),
        "proper_nouns": sorted(section.proper_nouns),
        "word_count": section.total_words,
        "char_count": section.total_chars,
        "token_count": section.total_tokens,
    }
    for section in sections.values()
]

with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(output_data, f, indent=2)
print(f"Data written to {output_file}")
321
# Report wall-clock time measured from start_time to now.
finish_time = time.time()
elapsed_time = finish_time - start_time
print(f"\nTotal processing time: {elapsed_time:.2f} seconds")