openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
openai/validators.py
698 lines · mode: code
| 1 | import os |
| 2 | import sys |
| 3 | import pandas as pd |
| 4 | |
| 5 | from typing import NamedTuple, Optional, Callable, Any |
| 6 | |
| 7 | |
class Remediation(NamedTuple):
    """
    Result produced by a validator (or by read_any_format).

    Only `name` is required; every other field defaults to None, which the
    apply_* helpers in this module treat as "nothing to report / nothing to do".
    """

    # Identifier of the validator; shown in the "ERROR in <name> validator" message.
    name: str
    # Informational text written to stdout by apply_necessary_remediation.
    immediate_msg: Optional[str] = None
    # Description of a mandatory fix; printed with a "[Necessary]" prefix by apply_optional_remediation.
    necessary_msg: Optional[str] = None
    # Callable applied unconditionally by apply_necessary_remediation; receives the dataframe and returns the fixed one.
    necessary_fn: Optional[Callable[[Any], Any]] = None
    # Text of the "[Recommended] ... [Y/n]" question asked by apply_optional_remediation.
    optional_msg: Optional[str] = None
    # Callable applied by apply_optional_remediation unless the user answers "n".
    optional_fn: Optional[Callable[[Any], Any]] = None
    # Fatal problem; apply_necessary_remediation writes it to stderr and exits with status 1.
    error_msg: Optional[str] = None
| 16 | |
| 17 | |
def num_examples_validator(df):
    """
    Report how many prompt-completion pairs the dataframe holds.

    When there are fewer than 100 rows, the message also recommends collecting more examples.
    The returned Remediation only carries an `immediate_msg`; nothing is changed in the data.
    """
    MIN_EXAMPLES = 100
    n_rows = len(df)

    if n_rows >= MIN_EXAMPLES:
        advice = ""
    else:
        advice = ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"

    return Remediation(
        name="num_examples",
        immediate_msg=f"\n- Your file contains {n_rows} prompt-completion pairs{advice}",
    )
| 32 | |
| 33 | |
def necessary_column_validator(df, necessary_column):
    """
    Check that `necessary_column` exists in the dataframe.

    - Column present: an empty Remediation is returned.
    - Column present only with different letter case: a necessary fix renaming it is returned.
    - Column absent: the Remediation carries an `error_msg`.
    """
    if necessary_column in df.columns:
        return Remediation(name="necessary_column")

    lowered_names = [col.lower() for col in df.columns]
    if necessary_column not in lowered_names:
        return Remediation(
            name="necessary_column",
            error_msg=f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry",
        )

    def rename_to_lower_case(frame):
        # Rename (in place) the first column whose lower-cased name matches.
        matches = [col for col in frame.columns if col.lower() == necessary_column]
        frame.rename(columns={matches[0]: necessary_column.lower()}, inplace=True)
        return frame

    return Remediation(
        name="necessary_column",
        immediate_msg=f"\n- The `{necessary_column}` column/key should be lowercase",
        necessary_msg=f"Lower case column name to `{necessary_column}`",
        necessary_fn=rename_to_lower_case,
    )
| 70 | |
| 71 | |
def additional_column_validator(df):
    """
    This validator will remove additional columns from the dataframe.

    If the dataframe has more than two columns, the returned Remediation lists the
    extra columns and carries a necessary fix that keeps only `prompt` and `completion`.
    A warning is added when an extra column's name contains `prompt` or `completion`
    (e.g. a duplicated column that pandas renamed to `prompt.1`), because that column
    is dropped and the plain `prompt` / `completion` column is used instead.
    """
    required_columns = ["prompt", "completion"]
    additional_columns = []
    necessary_msg = None
    immediate_msg = None
    necessary_fn = None
    if len(df.columns) > 2:
        additional_columns = [c for c in df.columns if c not in required_columns]
        warn_message = ""
        # Compare against the required names. Comparing the additional columns with
        # themselves always matches, which made the warning fire for every extra column.
        for required in required_columns:
            lookalikes = [c for c in additional_columns if required in str(c)]
            if lookalikes:
                warn_message += f"\n WARNING: Some of the additional columns/keys contain `{required}` in their name. These will be ignored, and the column/key `{required}` will be used instead. This could also result from a duplicate column/key in the provided file."
        immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
        necessary_msg = f"Remove additional columns/keys: {additional_columns}"

        def necessary_fn(x):
            return x[["prompt", "completion"]]

    return Remediation(
        name="additional_column",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
    )
| 101 | |
| 102 | |
def non_empty_completion_validator(df):
    """
    Make sure every row has a completion.

    Rows whose completion is an empty string or null are reported by position, and a
    necessary fix that drops them is attached to the returned Remediation.
    """
    completions = df["completion"]
    has_blank = completions.apply(lambda value: value == "").any()
    if not (has_blank or completions.isnull().any()):
        return Remediation(name="empty_completion")

    blank_or_null = (completions == "") | (completions.isnull())
    # reset_index() gives positional row numbers regardless of the dataframe's own index
    offending_rows = df.reset_index().index[blank_or_null].tolist()

    def drop_empty(frame):
        non_blank = frame[frame["completion"] != ""]
        return non_blank.dropna(subset=["completion"])

    return Remediation(
        name="empty_completion",
        immediate_msg=f"\n- `completion` column/key should not contain empty strings. These are rows: {offending_rows}",
        necessary_msg=f"Remove {len(offending_rows)} rows with empty completions",
        necessary_fn=drop_empty,
    )
| 129 | |
| 130 | |
def duplicated_rows_validator(df):
    """
    Detect repeated prompt-completion pairs and offer (optionally) to drop them.

    The first occurrence of each pair is not counted as a duplicate.
    """
    repeat_mask = df.duplicated(subset=["prompt", "completion"])
    # positional row numbers of the repeated rows
    repeat_rows = df.reset_index().index[repeat_mask].tolist()

    if len(repeat_rows) == 0:
        return Remediation(name="duplicated_rows")

    def drop_repeats(frame):
        return frame.drop_duplicates(subset=["prompt", "completion"])

    n_repeats = len(repeat_rows)
    return Remediation(
        name="duplicated_rows",
        immediate_msg=f"\n- There are {n_repeats} duplicated prompt-completion pairs. These are rows: {repeat_rows}",
        optional_msg=f"Remove {n_repeats} duplicate rows",
        optional_fn=drop_repeats,
    )
| 154 | |
| 155 | |
def common_prompt_suffix_validator(df):
    """
    This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.

    Outcomes:
    - open-ended generation: an empty Remediation is returned.
    - all prompts identical: the Remediation carries an `error_msg`.
    - a common suffix exists: it is reported (with warnings if it is long or repeated inside a prompt).
    - no common suffix: an optional fix appending a separator that occurs in no prompt is offered.
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None

    # Find a suffix which is not contained within the prompt otherwise
    suggested_suffix = "\n\n### =>\n\n"
    suffix_options = [
        " ->",
        "\n\n###\n\n",
        "\n\n===\n\n",
        "\n\n---\n\n",
        "\n\n===>\n\n",
        "\n\n--->\n\n",
    ]
    for suffix_option in suffix_options:
        if suffix_option == " ->":
            # the single-line separator is only considered when no prompt contains a newline
            if df.prompt.str.contains("\n", regex=False).any():
                continue
        if df.prompt.str.contains(suffix_option, regex=False).any():
            continue
        suggested_suffix = suffix_option
        break
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    ft_type = infer_task_type(df)
    if ft_type == "open-ended generation":
        return Remediation(name="common_suffix")

    def add_suffix(x, suffix):
        x["prompt"] += suffix
        return x

    common_suffix = get_common_xfix(df.prompt, xfix="suffix")
    if (df.prompt == common_suffix).all():
        error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
        return Remediation(name="common_suffix", error_msg=error_msg)

    if common_suffix != "":
        common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
        immediate_msg = (
            f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
        )
        if len(common_suffix) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
        # look for the suffix anywhere before the final occurrence
        if (
            df.prompt.str[: -len(common_suffix)]
            .str.contains(common_suffix, regex=False)
            .any()
        ):
            immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"

    else:
        immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See `Fine Tuning How to Guide` for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"

    if common_suffix == "":
        optional_msg = (
            f"Add a suffix separator `{display_suggested_suffix}` to all prompts"
        )

        def optional_fn(x):
            return add_suffix(x, suggested_suffix)

    # This validator inspects the prompt, so it must not reuse the completion
    # validator's name (previously "common_completion_suffix", a copy-paste slip).
    return Remediation(
        name="common_prompt_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
        error_msg=error_msg,
    )
| 230 | |
| 231 | |
def common_prompt_prefix_validator(df):
    """
    Report a prefix shared by all prompts and, if it is long, offer to strip it.

    A prefix longer than MAX_PREFIX_LEN characters triggers the optional removal;
    a shorter one is only mentioned. Nothing is reported when there is no shared prefix
    or when all prompts are identical.
    """
    MAX_PREFIX_LEN = 12

    shared_start = get_common_xfix(df.prompt, xfix="prefix")
    # identical prompts are already handled by common_suffix_validator
    if shared_start == "" or (df.prompt == shared_start).all():
        return Remediation(name="common_prefix")

    report = f"\n- All prompts start with prefix `{shared_start}`"
    if len(shared_start) <= MAX_PREFIX_LEN:
        return Remediation(name="common_prompt_prefix", immediate_msg=report)

    report += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"

    def strip_shared_start(frame):
        frame["prompt"] = frame["prompt"].str[len(shared_start) :]
        return frame

    return Remediation(
        name="common_prompt_prefix",
        immediate_msg=report,
        optional_msg=f"Remove prefix `{shared_start}` from all prompts",
        optional_fn=strip_shared_start,
    )
| 269 | |
| 270 | |
def common_completion_prefix_validator(df):
    """
    Offer to strip a prefix shared by all completions when it is at least MAX_PREFIX_LEN characters long.

    If the shared prefix begins with a space, a single leading space is kept after removal.
    """
    MAX_PREFIX_LEN = 5

    shared_start = get_common_xfix(df.completion, xfix="prefix")
    if len(shared_start) < MAX_PREFIX_LEN:
        return Remediation(name="common_prefix")

    # identical completions are already handled by common_suffix_validator
    if (df.completion == shared_start).all():
        return Remediation(name="common_prefix")

    # shared_start is non-empty here, so indexing its first character is safe
    starts_with_space = shared_start[0] == " "

    def strip_shared_start(frame):
        frame["completion"] = frame["completion"].str[len(shared_start) :]
        if starts_with_space:
            # keep the single whitespace as prefix
            frame["completion"] = " " + frame["completion"]
        return frame

    return Remediation(
        name="common_completion_prefix",
        immediate_msg=f"\n- All completions start with prefix `{shared_start}`. Most of the time you should only add the output data into the completion, without any prefix",
        optional_msg=f"Remove prefix `{shared_start}` from all completions",
        optional_fn=strip_shared_start,
    )
| 305 | |
| 306 | |
def common_completion_suffix_validator(df):
    """
    For conditional generation, check whether all completions share a common ending.

    - open-ended generation / classification: an empty Remediation is returned.
    - all completions identical: the Remediation carries an `error_msg`.
    - a shared ending exists: it is reported (with warnings if long or repeated).
    - no shared ending: an optional fix appending an ending not found in any completion is offered.
    """
    task_type = infer_task_type(df)
    if task_type in ("open-ended generation", "classification"):
        return Remediation(name="common_suffix")

    shared_ending = get_common_xfix(df.completion, xfix="suffix")
    if (df.completion == shared_ending).all():
        return Remediation(
            name="common_suffix",
            error_msg=f"All completions are identical: `{shared_ending}`\nEnsure completions are different, otherwise the model will just repeat `{shared_ending}`",
        )

    # Pick the first candidate that never occurs inside a completion; fall back to " [END]"
    candidates = ["\n", ".", " END", "***", "+++", "&&&", "$$$", "@@@", "%%%"]
    unused_candidates = (
        option
        for option in candidates
        if not df.completion.str.contains(option, regex=False).any()
    )
    proposed_ending = next(unused_candidates, " [END]")
    shown_ending = proposed_ending.replace("\n", "\\n")

    if shared_ending == "":

        def append_ending(frame):
            frame["completion"] += proposed_ending
            return frame

        return Remediation(
            name="common_completion_suffix",
            immediate_msg="\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See `Fine Tuning How to Guide` for more detail and examples.",
            optional_msg=f"Add a suffix ending `{shown_ending}` to all completions",
            optional_fn=append_ending,
        )

    escaped_ending = shared_ending.replace("\n", "\\n")
    report = f"\n- All completions end with suffix `{escaped_ending}`"
    if len(shared_ending) > 10:
        report += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{shown_ending}`"

    # text of each completion before its final ending
    leading_text = df.completion.str[: -len(shared_ending)]
    if leading_text.str.contains(shared_ending, regex=False).any():
        report += f"\n WARNING: Some of your completions contain the suffix `{shared_ending}` more than once. We suggest that you review your completions and add a unique ending"

    return Remediation(name="common_completion_suffix", immediate_msg=report)
| 381 | |
| 382 | |
def completions_space_start_validator(df):
    """
    This validator will suggest to add a space at the start of the completion if it doesn't already exist. This helps with tokenization.

    The optional fix prepends a single space to every completion that does not already
    start with one. Empty-string completions are handled without raising.
    """

    def add_space_start(x):
        # startswith() is safe on empty strings, unlike indexing with [0]
        x["completion"] = x["completion"].apply(
            lambda s: s if s.startswith(" ") else " " + s
        )
        return x

    optional_msg = None
    optional_fn = None
    immediate_msg = None

    # Slicing with [:1] never raises, so an empty first completion no longer causes an IndexError.
    first_chars = df.completion.str[:1]
    if first_chars.nunique() != 1 or first_chars.values[0] != " ":
        immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See `Fine Tuning How to Guide` for more details"
        optional_msg = "Add a whitespace character to the beginning of the completion"
        optional_fn = add_space_start
    return Remediation(
        name="completion_space_start",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
| 408 | |
| 409 | |
def lower_case_validator(df, column):
    """
    Offer to lowercase `column` when more than a third of its letters are uppercase.

    Returns a Remediation with an optional fix in that case, and None otherwise.
    """

    def count_letters(is_wanted_case):
        # total number of alphabetic characters of the given case across the whole column
        per_row = df[column].apply(
            lambda text: sum(1 for ch in text if ch.isalpha() and is_wanted_case(ch))
        )
        return per_row.sum()

    def to_lower(frame):
        frame[column] = frame[column].str.lower()
        return frame

    n_upper = count_letters(str.isupper)
    n_lower = count_letters(str.islower)

    # upper / (upper + lower) > 1/3  <=>  2 * upper > lower
    if n_upper * 2 > n_lower:
        return Remediation(
            name="lower_case",
            immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See `Fine Tuning How to Guide` for more details",
            optional_msg=f"Lowercase all your data in column/key `{column}`",
            optional_fn=to_lower,
        )
    return None
| 437 | |
| 438 | |
def read_any_format(fname):
    """
    This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
    - for .xlsx it will read the first sheet
    - for .txt it will assume completions and split on newline

    Returns a `(df, remediation)` tuple. `df` is None when the file does not exist or
    has an unsupported ending; in that case `remediation.error_msg` explains why.
    All values are read as strings and missing values are replaced by "".
    """
    necessary_msg = None
    immediate_msg = None
    error_msg = None
    df = None

    if os.path.isfile(fname):
        fname_lower = fname.lower()
        for ending, separator in [(".csv", ","), (".tsv", "\t")]:
            if fname_lower.endswith(ending):
                immediate_msg = f"\n- Based on your file extension, your file is formatted as a {ending[1:].upper()} file"
                necessary_msg = (
                    f"Your format `{ending[1:].upper()}` will be converted to `JSONL`"
                )
                df = pd.read_csv(fname, sep=separator, dtype=str)
        if fname_lower.endswith(".xlsx"):
            immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
            necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
            # Open the workbook once and make sure the handle is closed afterwards.
            with pd.ExcelFile(fname) as xls:
                if len(xls.sheet_names) > 1:
                    immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
                df = pd.read_excel(xls, dtype=str)
        if fname_lower.endswith(".txt"):
            immediate_msg = "\n- Based on your file extension, you provided a text file"
            necessary_msg = "Your format `TXT` will be converted to `JSONL`"
            with open(fname, "r") as f:
                content = f.read()
                # every line becomes a completion with an empty prompt
                df = pd.DataFrame(
                    [["", line] for line in content.split("\n")],
                    columns=["prompt", "completion"],
                    dtype=str,
                )
        if fname_lower.endswith(("jsonl", "json")):
            try:
                df = pd.read_json(fname, lines=True, dtype=str)
            except (ValueError, TypeError):
                # not line-delimited: fall back to a regular JSON document
                df = pd.read_json(fname, dtype=str)
                immediate_msg = "\n- Your file appears to be in a .JSON format. Your file will be converted to JSONL format"
                necessary_msg = "Your format `JSON` will be converted to `JSONL`"

        if df is None:
            error_msg = (
                "Your file is not saved as a .CSV, .TSV, .XLSX, .TXT or .JSONL file."
            )
            # splitext looks only at the last path component's final dot, so
            # "./dir/file.tar.gz" reports ".gz" and "./dir/file" reports no ending.
            extension = os.path.splitext(fname)[1]
            if extension:
                error_msg += f" Your file `{fname}` appears to end with `{extension}`"
            else:
                error_msg += f" Your file `{fname}` does not appear to have a file ending. Please ensure your filename ends with one of the supported file endings."
        else:
            df.fillna("", inplace=True)
    else:
        error_msg = f"File {fname} does not exist."

    remediation = Remediation(
        name="read_any_format",
        necessary_msg=necessary_msg,
        immediate_msg=immediate_msg,
        error_msg=error_msg,
    )
    return df, remediation
| 507 | |
| 508 | |
def format_inferrer_validator(df):
    """
    This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
    It will also suggest to use ada, --no_packing and explain train/validation split benefits.

    For any other task type the returned Remediation carries no message.
    """
    ft_type = infer_task_type(df)
    immediate_msg = None
    if ft_type == "classification":
        immediate_msg = f"\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`. You should also set the `--no_packing` parameter when fine-tuning\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
    # Use this validator's own name (previously "num_examples", copied from another validator).
    return Remediation(name="format_inferrer", immediate_msg=immediate_msg)
| 519 | |
| 520 | |
def apply_necessary_remediation(df, remediation):
    """
    Apply the mandatory part of a remediation and return the resulting dataframe.

    - If the remediation has an `error_msg`, it is written to stderr and the process exits with status 1.
    - Otherwise its `immediate_msg` (if any) is written to stdout and its `necessary_fn` (if any) is applied to `df`.
    """
    failure = remediation.error_msg
    if failure is not None:
        sys.stderr.write(
            f"\n\nERROR in {remediation.name} validator: {failure}\n\nAborting..."
        )
        sys.exit(1)

    notice = remediation.immediate_msg
    if notice is not None:
        sys.stdout.write(notice)

    fix = remediation.necessary_fn
    return df if fix is None else fix(df)
| 535 | |
| 536 | |
def apply_optional_remediation(df, remediation):
    """
    Ask the user whether to apply an optional remediation and return the resulting dataframe.

    The optional fix is applied unless the answer (case-insensitive) is "n".
    A remediation's `necessary_msg`, if present, is echoed with a "[Necessary]" prefix.
    """
    question = remediation.optional_msg
    if question is not None:
        answer = input(f"- [Recommended] {question} [Y/n]: ")
        if answer.lower() != "n":
            df = remediation.optional_fn(df)

    note = remediation.necessary_msg
    if note is not None:
        sys.stdout.write(f"- [Necessary] {note}\n")
    return df
| 547 | |
| 548 | |
def _get_outfnames(fname, split):
    """Return the output file name(s) for `fname`, numbered so that none of them already exists."""
    # splitext only strips the final extension, so "./data/set.v2.csv" -> "./data/set.v2"
    base = os.path.splitext(fname)[0] + "_prepared"
    suffixes = ["_train", "_valid"] if split else [""]
    i = 0
    while True:
        index_suffix = f" ({i})" if i > 0 else ""
        # train and valid share the same number so the pair stays recognisable
        candidates = [base + suffix + index_suffix + ".jsonl" for suffix in suffixes]
        if not any(os.path.isfile(candidate) for candidate in candidates):
            return candidates
        i += 1


def write_out_file(df, fname, any_remediations):
    """
    This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
    For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.

    A new file is written when remediations were applied or a train/valid split was
    requested; otherwise only the fine-tuning command for the original file is printed.
    Existing files are never overwritten: a ` (i)` number is added to the new names.
    """
    ft_format = infer_task_type(df)
    common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
    common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")

    split = False
    if ft_format == "classification":
        answer = input(
            "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
        )
        # case-insensitive, like every other prompt in this module
        split = answer.lower() != "n"

    packing_param = " --no_packing" if ft_format == "classification" else ""
    common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
    common_completion_suffix_new_line_handled = common_completion_suffix.replace(
        "\n", "\\n"
    )
    optional_ending_string = (
        f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.'
        if len(common_completion_suffix_new_line_handled) > 0
        else ""
    )

    # A requested split needs new files even when no remediation changed the data.
    if not any_remediations and not split:
        sys.stdout.write(
            f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{packing_param}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
        )

    elif (
        input(
            "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
        ).lower()
        != "n"
    ):
        outfnames = _get_outfnames(fname, split)
        if split:
            # keep at most MAX_VALID_EXAMPLES rows (and at most 20%) for validation
            MAX_VALID_EXAMPLES = 1000
            n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
            df_train = df.sample(n=n_train, random_state=42)
            df_valid = df.drop(df_train.index)
            outputs = [df_train, df_valid]
        else:
            outputs = [df]

        for df_out, out_fname in zip(outputs, outfnames):
            df_out[["prompt", "completion"]].to_json(
                out_fname, lines=True, orient="records"
            )

        # Add -v VALID_FILE if we split the file into train / valid
        files_string = ("s" if split else "") + " to `" + ("` and `".join(outfnames))
        valid_string = f' -v "{outfnames[1]}"' if split else ""
        separator_reminder = (
            ""
            if len(common_prompt_suffix_new_line_handled) == 0
            else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
        )
        sys.stdout.write(
            f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\n{separator_reminder}{optional_ending_string}\n'
        )
    else:
        sys.stdout.write("Aborting... did not write the file\n")
| 642 | |
| 643 | |
def infer_task_type(df):
    """
    Guess the fine-tuning task from the data.

    Returns one of:
    - "open-ended generation": every prompt is empty
    - "classification": fewer distinct completions than a third of the rows
    - "conditional generation": anything else
    """
    MIN_AVG_ROWS_PER_CLASS = 3  # min_average instances of each class

    total_prompt_chars = sum(df.prompt.str.len())
    if total_prompt_chars == 0:
        return "open-ended generation"

    n_distinct_completions = len(df.completion.unique())
    looks_like_classes = n_distinct_completions < len(df) / MIN_AVG_ROWS_PER_CLASS
    return "classification" if looks_like_classes else "conditional generation"
| 656 | |
| 657 | |
def get_common_xfix(series, xfix="suffix"):
    """
    Return the longest string that every value in `series` ends with (xfix="suffix")
    or starts with (any other value of `xfix`). Returns "" when there is none.
    """
    from_end = xfix == "suffix"
    shared = ""
    while True:
        width = len(shared) + 1
        # last `width` or first `width` characters of every value
        pieces = series.str[-width:] if from_end else series.str[:width]

        if pieces.nunique() != 1:
            # widening by one character made the values diverge
            return shared

        widened = pieces.values[0]
        if widened == shared:
            # slicing wider yields nothing new: the values are exhausted
            return shared

        # still common across all rows - try one more character
        shared = widened
| 680 | |
| 681 | |
def get_validators():
    """
    Return the validators in the order they are meant to run.

    Each entry is a callable taking the dataframe as its only argument.
    """
    required_column_checks = [
        lambda x: necessary_column_validator(x, "prompt"),
        lambda x: necessary_column_validator(x, "completion"),
    ]
    content_checks = [
        additional_column_validator,
        non_empty_completion_validator,
        format_inferrer_validator,
        duplicated_rows_validator,
    ]
    casing_checks = [
        lambda x: lower_case_validator(x, "prompt"),
        lambda x: lower_case_validator(x, "completion"),
    ]
    affix_checks = [
        common_prompt_suffix_validator,
        common_prompt_prefix_validator,
        common_completion_prefix_validator,
        common_completion_suffix_validator,
    ]
    return (
        [num_examples_validator]
        + required_column_checks
        + content_checks
        + casing_checks
        + affix_checks
        + [completions_space_start_validator]
    )