microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d18049abeb7169f02b910796fbcc722e153dc3fb

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/fineTuning/unsloth/baseExtract.py

255lines · modecode

1from argparse import ArgumentParser
2import sys
# Command-line interface: where to read the dataset, where to write the
# result, and the frequency cut-off used further down to drop common words.
parser = ArgumentParser(description="Extract keywords from dataset using NLTK-RAKE.")
parser.add_argument(
    "--dataset_path",
    type=str,
    default='/data/npr/npr_chunks_no_embedding.json',
    help="Path to the dataset file.",
)
parser.add_argument(
    "--output_file",
    type=str,
    default='reformatted_data.txt',
    help="Path to the output file.",
)
parser.add_argument(
    "--maxMsgPct",
    type=float,
    default=0.005,
    # %(default)s is expanded by argparse, so the documented default can
    # never drift away from the actual one again.
    help=(
        "Maximum message percentage threshold for filtering words (0.0-1.0). "
        "Words appearing in more than this percentage of messages will be "
        "filtered out. Default: %(default)s"
    ),
)
args = parser.parse_args(sys.argv[1:])
dataset_path = args.dataset_path
output_file = args.output_file
maxMsgPct = args.maxMsgPct

# Validate maxMsgPct: it is a fraction of the message count, so anything
# outside [0.0, 1.0] (including NaN, which fails both comparisons) is rejected.
if not 0.0 <= maxMsgPct <= 1.0:
    print(f"Error: maxMsgPct must be between 0.0 and 1.0, got {maxMsgPct}")
    sys.exit(1)
20
21import json
22import re
23import time
24from dataclasses import dataclass, field
25import matplotlib.pyplot as plt
26import tiktoken
27
@dataclass
class Message:
    """One message record from the dataset.

    Attributes:
        content: Text of the message.
        speaker: Speaker value stored with the message.
        section: The Section this message was appended to, or None when the
            message has not been attached to one.
    """
    content: str
    speaker: str
    # Back-reference to the owning Section. A Section lists its messages and
    # each message points back at the Section, so letting this field take part
    # in the generated __eq__ would recurse without bound, and including it in
    # __repr__ would dump the whole section for every message.
    section: 'Section | None' = field(default=None, repr=False, compare=False)
33
@dataclass
class Section:
    """Accumulator for everything gathered under one section title.

    Attributes:
        title: The section title the messages were grouped by.
        messages: Messages appended to this section, in insertion order.
        words: Distinct words seen in the section's messages.
        total_words: Running word count over all messages (duplicates counted).
        total_chars: Running character count over all message contents.
        total_tokens: Running token count over all message contents.
    """
    title: str
    # Mutable containers get a fresh instance per Section via default_factory.
    messages: list[Message] = field(default_factory=list)
    words: set[str] = field(default_factory=set)
    # Counters start at zero and are incremented as messages are processed.
    total_words: int = field(default=0)
    total_chars: int = field(default=0)
    total_tokens: int = field(default=0)
42
@dataclass
class Word:
    """A word together with the messages it was found in.

    Attributes:
        word: The word itself.
        messages: Messages recorded as containing this word.
    """
    word: str
    # default_factory gives every Word its own list instead of a shared one.
    messages: list[Message] = field(default_factory=list)
47
def read_and_count_messages(file_path):
    """Load the JSON dataset and report how many top-level items it holds.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        A ``(data, count)`` tuple where ``data`` is the parsed JSON value and
        ``count`` is ``len(data)``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8: without it open() uses the platform's
    # default encoding, which can mis-decode non-ASCII text in the JSON.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    count = len(data)
    print(f"Total messages in dataset: {count}")

    return data, count
57
# Reference point for the total-run-time report printed at the end.
start_time = time.time()

# Load the whole dataset up front; message_count is the number of items read.
messages, message_count = read_and_count_messages(file_path=dataset_path)

# The encoder is created once here and reused for every message.
print("Initializing tokenizer...")
encoder = tiktoken.encoding_for_model("gpt-4")
67
# Single pass over the dataset: group messages by section title, accumulate
# per-section counters, and record which messages each word appears in.
print("Processing messages and extracting words...")
word_pattern = re.compile(r"\b[\w']+\b")
sections = {}             # section title -> Section
all_messages = []         # every Message, in dataset order
word_dict = {}            # lower-cased word -> Word
total_words = 0           # kept words over all messages, duplicates counted
all_unique_words = set()  # distinct kept words over all messages

for message in messages:
    section_title = message['section_title']
    section = sections.get(section_title)
    if section is None:
        section = sections[section_title] = Section(title=section_title)

    content = message['content']
    speaker = message['speaker']
    msg = Message(content=content, speaker=speaker, section=section)
    section.messages.append(msg)
    all_messages.append(msg)

    # Extract words from this message
    words = word_pattern.findall(content.lower())
    # Filter out numbers less than 100. isdecimal() is used instead of
    # isdigit() because isdigit() is also true for characters such as
    # superscript digits, which int() cannot parse and would raise ValueError.
    words = [w for w in words if not (w.isdecimal() and int(w) < 100)]
    word_count = len(words)
    char_count = len(content)
    token_count = len(encoder.encode(content))

    total_words += word_count
    section.total_words += word_count
    section.total_chars += char_count
    section.total_tokens += token_count
    section.words.update(words)
    all_unique_words.update(words)

    # Track which messages contain each word. Iterating the de-duplicated
    # set means a message is listed at most once per word.
    unique_message_words = set(words)
    for word in unique_message_words:
        entry = word_dict.get(word)
        if entry is None:
            entry = word_dict[word] = Word(word=word)
        entry.messages.append(msg)

print("Processing complete.")
110
# Filter words based on maxMsgPct: a word is dropped when it appears in MORE
# than maxMsgPct of all messages.
print(f"\nFiltering words with message percentage > {maxMsgPct * 100:.1f}%...")
total_message_count = len(all_messages)
max_messages_threshold = maxMsgPct * total_message_count

# Filter overall word_dict. The comparison is inclusive so that a word sitting
# exactly on the threshold is kept, matching the "more than" wording of the
# option and of the message above (e.g. maxMsgPct=1.0 now keeps every word).
words_before_filter = len(word_dict)
filtered_word_dict = {
    word: word_obj
    for word, word_obj in word_dict.items()
    if len(word_obj.messages) <= max_messages_threshold
}
words_filtered = words_before_filter - len(filtered_word_dict)
print(f" Filtered {words_filtered} words from overall word list ({words_before_filter} -> {len(filtered_word_dict)})")

# Filter each section's word set in place, keeping only surviving words.
for section in sections.values():
    section.words.intersection_update(filtered_word_dict)

# Update word_dict to filtered version
word_dict = filtered_word_dict
all_unique_words = set(word_dict.keys())

print("Filtering complete.")
134
# Print statistics. Every ratio falls back to 0 when its denominator is zero,
# so an empty dataset produces a report instead of a ZeroDivisionError.
import statistics

section_count = len(sections)
total_messages = sum(len(section.messages) for section in sections.values())
avg_messages_per_section = total_messages / section_count if section_count > 0 else 0
avg_words_per_section = (
    sum(len(section.words) for section in sections.values()) / section_count
    if section_count > 0 else 0
)

# Word density ratio per section: distinct (post-filter) words divided by the
# section's total word count. Sections without any words are skipped.
word_density_ratios = [
    len(section.words) / section.total_words
    for section in sections.values()
    if section.total_words > 0
]
avg_word_density = sum(word_density_ratios) / len(word_density_ratios) if word_density_ratios else 0

# Same ratio over the whole dataset; guarded like the others because
# total_words is 0 when no message yielded any word.
overall_word_density = len(all_unique_words) / total_words if total_words > 0 else 0

total_message_references = sum(len(word.messages) for word in word_dict.values())
avg_messages_per_word = total_message_references / len(word_dict) if word_dict else 0

# Calculate token statistics
token_counts = [section.total_tokens for section in sections.values()]
median_tokens_per_section = statistics.median(token_counts) if token_counts else 0
mean_tokens_per_section = statistics.mean(token_counts) if token_counts else 0

print(f"Total sections: {section_count}")
print(f"Average messages per section: {avg_messages_per_section:.2f}")
print(f"Total words: {total_words}")
print(f"Unique words across all messages: {len(all_unique_words)}")
print(f"Average unique words per section: {avg_words_per_section:.2f}")
print(f"Average word density ratio (unique/total per section): {avg_word_density * 100:.3f}%")
print(f"Overall word density ratio (unique/total across all messages): {overall_word_density * 100:.3f}%")
print(f"Average messages per unique word: {avg_messages_per_word:.2f}")
print(f"Mean tokens per section: {mean_tokens_per_section:.2f}")
print(f"Median tokens per section: {median_tokens_per_section:.2f}")
169
# Analyze message count distribution: for each word, the number of messages
# it appears in.
import statistics

message_counts = [len(word.messages) for word in word_dict.values()]

# Get top 5 words by message count
top_words = sorted(word_dict.values(), key=lambda w: len(w.messages), reverse=True)[:5]
print("\nTop 5 words by message count:")
for i, word_obj in enumerate(top_words, 1):
    print(f" {i}. '{word_obj.word}': {len(word_obj.messages)} messages")

# Calculate statistics (all fall back to 0 when no words survived filtering)
median_messages = statistics.median(message_counts) if message_counts else 0
max_messages = max(message_counts) if message_counts else 0
min_messages = min(message_counts) if message_counts else 0

print("\nMessage count statistics per word:")
print(f" Min: {min_messages}")
print(f" Median: {median_messages:.2f}")
print(f" Max: {max_messages}")

# Report how many words appear in at most 4 and at most 16 messages, both as
# a count and as a share of all words.
for limit in (4, 16):
    words_at_or_below = sum(1 for count in message_counts if count <= limit)
    percentage_at_or_below = (words_at_or_below / len(message_counts) * 100) if message_counts else 0
    print(f" Words appearing in ≤{limit} messages: {words_at_or_below} ({percentage_at_or_below:.2f}%)")
200
def _save_count_histogram(counts, title, path):
    """Save a 50-bin histogram of *counts* with a log-scaled y-axis to *path*."""
    fig = plt.figure(figsize=(12, 6))
    plt.hist(counts, bins=50, edgecolor='black', alpha=0.7)
    plt.xlabel('Number of Messages per Word')
    plt.ylabel('Number of Words')
    plt.title(title)
    plt.yscale('log')  # Use log scale for y-axis since distribution is likely skewed
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    # Close the figure once it is on disk; pyplot otherwise keeps every
    # figure it created alive until the process exits.
    plt.close(fig)


# Plot distribution of messages per word
print("\nCreating distribution plot...")
message_counts = [len(word.messages) for word in word_dict.values()]

_save_count_histogram(
    message_counts,
    'Distribution of Message Counts per Unique Word',
    'word_message_distribution.png',
)
print("Distribution plot saved to word_message_distribution.png")

# Plot distribution for words with 50 or fewer messages
message_counts_50_or_less = [count for count in message_counts if count <= 50]
_save_count_histogram(
    message_counts_50_or_less,
    'Distribution of Message Counts per Unique Word (≤50 messages)',
    'word_message_distribution_50_or_less.png',
)
print("Zoomed distribution plot saved to word_message_distribution_50_or_less.png")
228
# Serialize one JSON object per section: its title, its messages (content and
# speaker only), its surviving words in sorted order, and its counters.
print(f"\nWriting data to {output_file}...")
output_data = [
    {
        "title": section.title,
        "messages": [
            {"content": msg.content, "speaker": msg.speaker}
            for msg in section.messages
        ],
        "words": sorted(section.words),
        "word_count": section.total_words,
        "char_count": section.total_chars,
        "token_count": section.total_tokens,
    }
    for section in sections.values()
]

with open(output_file, 'w') as out:
    json.dump(output_data, out, indent=2)
print(f"Data written to {output_file}")

# Report wall-clock time since start_time was taken.
elapsed_time = time.time() - start_time
print(f"\nTotal processing time: {elapsed_time:.2f} seconds")