openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
examples/finetuning/answers_with_ft.py
150 lines · mode: code
| 1 | import argparse |
| 2 | |
| 3 | import openai |
| 4 | |
| 5 | |
def create_context(
    question, search_file_id, max_len=1800, search_model="ada", max_rerank=10
):
    """
    Build the context for a question from the best-matching search file entries.

    Search results are consumed in the order they are returned; their texts
    are collected until the running length would exceed ``max_len``, and the
    collected texts are joined with a ``###`` separator.

    :param question: The question
    :param search_file_id: The file id of the search file
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of reranking
    :return: The context
    """
    engine = openai.Engine(search_model)
    search_response = engine.search(
        search_model=search_model,
        query=question,
        max_rerank=max_rerank,
        file=search_file_id,
        return_metadata=True,
    )

    separator = "\n\n###\n\n"
    selected_texts = []
    used_len = 0
    for hit in search_response["data"]:
        # Each hit's "metadata" is parsed as its integer length; the extra 4
        # presumably accounts for the separator between entries.
        used_len = used_len + int(hit["metadata"]) + 4
        if used_len > max_len:
            # This hit does not fit in the budget; stop collecting.
            break
        selected_texts.append(hit["text"])
    return separator.join(selected_texts)
| 33 | |
| 34 | |
def answer_question(
    search_file_id="<SEARCH_FILE_ID>",
    fine_tuned_qa_model="<FT_QA_MODEL_ID>",
    question="Which country won the European Football championship in 2021?",
    max_len=1800,
    search_model="ada",
    max_rerank=10,
    debug=False,
    stop_sequence=["\n", "."],
    max_tokens=100,
):
    """
    Answer a question with a fine-tuned model, using context retrieved from the search file.

    The context is built first via ``create_context``; the completion request
    is then issued with that context and the question embedded in the prompt.
    Any exception raised while preparing or sending the completion request is
    printed and an empty string is returned instead.

    :param search_file_id: The file id of the search file
    :param fine_tuned_qa_model: The fine tuned QA model
    :param question: The question
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of reranking
    :param debug: Whether to output debug information
    :param stop_sequence: The stop sequence for Q&A model
    :param max_tokens: The maximum number of tokens to return
    :return: The answer, or an empty string if the completion request failed
    """
    context = create_context(
        question,
        search_file_id,
        max_len=max_len,
        search_model=search_model,
        max_rerank=max_rerank,
    )

    if debug:
        print("Context:\n" + context)
        print("\n\n")

    try:
        # Fine-tuned models must be passed as `model`; every other model is
        # passed as `engine`. A name counts as fine-tuned when the part after
        # its first ":" starts with "ft".
        is_fine_tuned = False
        if ":" in fine_tuned_qa_model:
            is_fine_tuned = fine_tuned_qa_model.split(":")[1].startswith("ft")
        target_key = "model" if is_fine_tuned else "engine"

        prompt = f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:"
        completion = openai.Completion.create(
            prompt=prompt,
            temperature=0,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop_sequence,
            **{target_key: fine_tuned_qa_model},
        )
        first_choice = completion["choices"][0]
        return first_choice["text"]
    except Exception as err:
        # Best-effort: report the failure and fall back to an empty answer.
        print(err)
        return ""
| 91 | |
| 92 | |
if __name__ == "__main__":
    # Command-line entry point: parse the options, answer the question, print it.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Rudimentary functionality of the answers endpoint with a fine-tuned Q&A model.",
    )

    # Required inputs.
    cli.add_argument(
        "--search_file_id", type=str, required=True, help="Search file id"
    )
    cli.add_argument(
        "--fine_tuned_qa_model", type=str, required=True, help="Fine-tuned QA model id"
    )
    cli.add_argument(
        "--question", type=str, required=True, help="Question to answer"
    )

    # Optional tuning knobs; defaults mirror those of answer_question().
    cli.add_argument(
        "--max_len",
        type=int,
        default=1800,
        help="Maximum length of the returned context (in tokens)",
    )
    cli.add_argument(
        "--search_model", type=str, default="ada", help="Search model to use"
    )
    cli.add_argument(
        "--max_rerank",
        type=int,
        default=10,
        help="Maximum number of reranking for the search",
    )
    cli.add_argument(
        "--debug", action="store_true", help="Print debug information (context used)"
    )
    cli.add_argument(
        "--stop_sequence",
        type=str,
        nargs="+",
        default=["\n", "."],
        help="Stop sequences for the Q&A model",
    )
    cli.add_argument(
        "--max_tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to return",
    )

    options = cli.parse_args()
    # Every option's dest matches an answer_question() parameter name, so the
    # parsed namespace can be forwarded as keyword arguments directly.
    answer = answer_question(**vars(options))
    print(f"Answer:{answer}")
| 151 | |