openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.12.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

examples/codex/backtranslation.py

189 lines · mode: code

from typing import List, Tuple, Union

import openai
from smokey import Smokey


def get_candidates(
    prompt: str,
    stop: List[str],
    temperature: float,
    priming_prefix: str,
    engine: str,
    n: int = 5,
) -> List[str]:
    """
    Sample ``n`` completions of ``prompt`` and restore the priming prefix on each.

    The request does not set ``echo``, so each choice's ``text`` holds only the
    newly generated continuation. ``priming_prefix`` is prepended to every
    choice to rebuild a complete answer. The caller is therefore expected to end
    ``prompt`` with that same prefix (e.g. ``"SELECT"``).

    :param prompt: Fully formatted prompt sent to the Completion endpoint.
    :param stop: Stop sequences that terminate each completion.
    :param temperature: Sampling temperature.
    :param priming_prefix: Text prepended to every returned completion.
    :param engine: Name of the engine used for sampling.
    :param n: How many completions to request.
    :return: One string per returned choice, in the order the API returned them.
    """
    # Fixed sampling settings shared by every call.
    sampling_options = dict(
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        max_tokens=150,
    )
    completion = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        stop=stop,
        n=n,
        temperature=temperature,
        **sampling_options,
    )
    return [priming_prefix + choice.text for choice in completion.choices]


def rindex(lst: List, value: str) -> int:
    """
    Return the index of the last occurrence of ``value`` in ``lst``.

    :param lst: The list to search in.
    :param value: The value to search for.
    :return: The index of the last occurrence of the value.
    :raises ValueError: If ``value`` does not occur in ``lst``. The message
        refers to it as the answer start token, because that is what
        ``eval_candidate`` searches for with this helper.
    """
    try:
        # The first hit in the reversed copy is the last hit in the original.
        index_from_end = lst[::-1].index(value)
    except ValueError:
        # list.index's generic "x is not in list" error adds nothing; suppress
        # it so the traceback shows only the meaningful error.
        raise ValueError(
            f"Answer start token `{value}` not found in the eval template"
        ) from None
    return len(lst) - index_from_end - 1


def eval_candidate(
    candidate_answer: str,
    original_instruction: str,
    eval_template: str,
    answer_start_token: str,
    engine: str,
) -> float:
    """
    Score a candidate by the mean log probability of the original instruction.

    ``eval_template`` is formatted with the candidate answer and the original
    instruction, in that order. The result is sent to the Completion endpoint
    with ``echo=True``, ``max_tokens=0`` and ``logprobs=1``, so the response
    carries per-token log probabilities for the formatted prompt itself.

    The score is the mean log probability of every token after the *last*
    occurrence of ``answer_start_token``. It measures how well the candidate
    "backtranslates" into the original instruction.

    :param candidate_answer: The candidate answer to evaluate.
    :param original_instruction: The original instruction to reconstruct.
    :param eval_template: Template with two positional ``{}`` fields: the
        candidate answer, then the original instruction.
    :param answer_start_token: Token marking where the instruction begins in
        the formatted template; its last occurrence is used.
    :param engine: The engine to use for the evaluation.
    :return: The mean log probability (higher is better).
    :raises ValueError: If ``answer_start_token`` is not among the echoed
        tokens, or if no tokens follow it, leaving nothing to average.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=eval_template.format(candidate_answer, original_instruction),
        temperature=0,
        max_tokens=0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        logprobs=1,
        echo=True,
    )
    logprobs_info = response["choices"][0]["logprobs"]

    # NOTE(review): tokens are compared by exact equality, which assumes the
    # tokenizer emits answer_start_token as a standalone token. Also, if the
    # instruction itself contains that token, only the text after it is scored.
    answer_start = rindex(logprobs_info["tokens"], answer_start_token)
    scored_logprobs = logprobs_info["token_logprobs"][answer_start + 1 :]
    if not scored_logprobs:
        # Averaging an empty slice would raise a bare ZeroDivisionError;
        # report the actual problem instead.
        raise ValueError(
            f"No tokens follow answer start token `{answer_start_token}`; "
            "nothing to score"
        )
    return sum(scored_logprobs) / len(scored_logprobs)


def backtranslation(
    prompt_template: str,
    additional_info: str,
    instruction: str,
    eval_template: str,
    priming_prefix: str = "SELECT",
    stop1: List[str] = ["#", ";"],
    answer_start_token: str = "--",
    n: int = 5,
    temperature: float = 0.5,
    return_all_results: bool = False,
    engine: str = "davinci-codex",
) -> Union[str, List[Tuple[str, float]]]:
    """
    Generate SQL queries for a natural language instruction and pick the best.

    Each candidate is scored with ``eval_candidate``: the average log
    probability of reproducing the exact original instruction when the engine
    is prompted for a natural language explanation of the candidate query.

    :param prompt_template: Generation template with three positional ``{}``
        fields: ``additional_info``, ``instruction``, ``priming_prefix``.
    :param additional_info: Additional information to include in the prompt
        (SQL tables and their properties).
    :param instruction: The instruction in natural language.
    :param eval_template: Evaluation template passed to ``eval_candidate``.
    :param priming_prefix: Text the generated query starts with. It fills the
        template's third field and is prepended to every candidate.
    :param stop1: Stop sequences that end each generated query. The list
        default is shared between calls; it is not modified here.
    :param answer_start_token: The token that marks the start of the natural
        language explanation in ``eval_template``.
    :param n: The number of candidates to generate; must be at least 1.
    :param temperature: The temperature of the generation.
    :param return_all_results: Whether to return all scored candidates instead
        of only the best query.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best-scoring SQL query. If ``return_all_results`` is true,
        returns a list of ``(query, score)`` tuples sorted best first instead.
    :raises ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        # With no candidates there is nothing to rank; fail before any API call.
        raise ValueError(f"n must be at least 1, got {n}")

    prompt = prompt_template.format(additional_info, instruction, priming_prefix)
    responses = get_candidates(
        prompt, stop1, temperature, priming_prefix, engine=engine, n=n
    )

    # Score exactly the completions the API returned rather than indexing 0..n-1.
    candidates: List[Tuple[str, float]] = []
    for response in responses:
        quality = eval_candidate(
            response,
            instruction,
            eval_template,
            answer_start_token,
            engine=engine,
        )
        candidates.append((response, quality))

    # Highest score first; the sort is stable, so ties keep generation order.
    candidates.sort(key=lambda scored: scored[1], reverse=True)
    if return_all_results:
        return candidates
    return candidates[0][0]


def main(
    nl_query: str = "Return the name of each department that had more than 10 employees in June 2021",
    eval_template: str = "{};\n-- Explanation of the above query in human readable format\n-- {}",
    table_definitions: str = "# Employee(id, name, department_id)\n# Department(id, name, address)\n# Salary_Payments(id, employee_id, amount, date)\n",
    prompt_template: str = "### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
    n: int = 3,
    temperature: float = 0.3,
    engine: str = "davinci-codex",
):
    """
    Generate SQL candidates for a natural language query and print the best one.

    Candidates are ranked by ``backtranslation``, which scores each one by how
    likely the engine is to explain it back with the original query.

    :param nl_query: The natural language query.
    :param eval_template: Evaluation template with two ``{}`` fields: the
        candidate SQL query, then the natural language query.
    :param table_definitions: The definitions of the tables used in the query;
        inserted into the generation prompt.
    :param prompt_template: Generation template with three ``{}`` fields:
        table definitions, natural language query, and the ``"SELECT"``
        priming prefix.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param engine: The engine to use for the generation and evaluation.
    :return: None. The best SQL query is printed rather than returned.
    """
    result = backtranslation(
        prompt_template,
        table_definitions,
        nl_query,
        eval_template,
        priming_prefix="SELECT",
        temperature=temperature,
        n=n,
        engine=engine,
    )
    print(result)


if __name__ == "__main__":
    # Script entry point: hand main to Smokey.
    # NOTE(review): presumably Smokey exposes main's keyword arguments as
    # command-line options and invokes it — confirm against the smokey package.
    Smokey(main)