openai/openai-python

Public

Mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.9.4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

examples/semanticsearch/semanticsearch.py

125 lines · mode: code

1#!/usr/bin/env python
2import openai
3import argparse
4import logging
5import sys
6from typing import List
7
# Log records are written to stderr, prefixed with a timestamp and the
# process id.
formatter = logging.Formatter(fmt="[%(asctime)s] [%(process)d] %(message)s")
handler = logging.StreamHandler(stream=sys.stderr)
handler.setFormatter(formatter)

# getLogger() without a name returns the root logger, so the handler is
# attached at the top of the logger hierarchy.
logger = logging.getLogger()
logger.addHandler(handler)
13
# Prompt template: the document comes first, the query is the final part of
# the text, so the query's tokens are always the trailing tokens of the prompt.
DEFAULT_COND_LOGP_TEMPLATE = (
    "<|endoftext|>{document}\n\n---\n\nThe above passage is related to: {query}"
)
# Mean token log-probabilities are multiplied by this factor to form a score.
SCORE_MULTIPLIER = 100.0


class SearchScorer:
    """Score a (document, query) pair from the log-probabilities of an echoed prompt.

    The document and query are substituted into ``cond_logp_template`` to build
    the prompt text (the "context").  Given a completion choice that echoes
    that prompt together with per-token log-probabilities, :meth:`get_score`
    averages the log-probabilities of the trailing tokens that cover the query.
    """

    def __init__(
        self,
        *,
        document: str,
        query: str,
        cond_logp_template: str = DEFAULT_COND_LOGP_TEMPLATE,
    ):
        """Store the inputs and pre-render the prompt text.

        Args:
            document: Text substituted for ``{document}`` in the template.
            query: Text substituted for ``{query}`` in the template.
            cond_logp_template: ``str.format`` template containing the
                ``{document}`` and ``{query}`` placeholders.
        """
        self.document = document
        self.query = query
        self.cond_logp_template = cond_logp_template
        self.context = self.cond_logp_template.format(
            document=self.document, query=self.query
        )

    def get_context(self) -> str:
        """Return the rendered prompt text for this document/query pair."""
        return self.context

    def get_score(self, choice) -> float:
        """Return the scaled mean log-probability of the query's tokens.

        Args:
            choice: A completion choice exposing ``text``,
                ``logprobs.tokens`` and ``logprobs.token_logprobs``, where
                ``text`` is the echoed prompt.

        Returns:
            The mean of the log-probabilities of the trailing tokens that
            cover the query, multiplied by ``SCORE_MULTIPLIER``.

        Raises:
            ValueError: If the query is empty, so there are no tokens to score.
            RuntimeError: If ``choice`` is inconsistent with this scorer's
                context (different text, token lengths that do not add up, or
                a query tail that does not fit behind the document).
        """
        if not self.query:
            # Without this guard the tail below is empty and the mean would
            # fail with an opaque ZeroDivisionError.
            raise ValueError("cannot score an empty query: no tokens to average")

        # Explicit raises instead of asserts: these validate data returned by
        # the API and must not disappear when Python runs with -O.
        if choice.text != self.context:
            raise RuntimeError("choice.text does not match the scorer's context")

        logprobs: List[float] = choice.logprobs.token_logprobs
        tokens: List[str] = choice.logprobs.tokens
        text_len = sum(len(token) for token in tokens)
        if text_len != len(self.context):
            raise RuntimeError(
                f"text_len={text_len}, len(self.context)={len(self.context)}"
            )

        # Walk backwards from the last token until the collected tokens span
        # at least len(query) characters; `last_used` ends up as the index of
        # the first token that belongs to the scored tail.
        total_len = 0
        last_used = len(tokens)
        while total_len < len(self.query):
            if last_used <= 0:
                raise RuntimeError(
                    f"ran out of tokens before covering the query: "
                    f"total_len={total_len}, len(self.query)={len(self.query)}"
                )
            last_used -= 1
            total_len += len(tokens[last_used])

        # The scored tail plus the document must fit in the text that follows
        # the template's literal prefix; with the default template this means
        # the tail does not overlap the document.
        max_len = len(self.context) - self.cond_logp_template.index("{document}")
        if total_len + len(self.document) > max_len:
            raise RuntimeError(
                f"query tokens overlap the document: total_len={total_len}, "
                f"len(self.document)={len(self.document)}, max_len={max_len}"
            )

        logits: List[float] = logprobs[last_used:]
        return sum(logits) / len(logits) * SCORE_MULTIPLIER
53
54
def semantic_search(engine, query, documents):
    """Rank documents against a query using client-side log-probability scoring.

    One prompt is built per document, plus one for an empty document that
    serves as a baseline.  All prompts are sent in a single completion request
    and each document's score is reported relative to the baseline.

    Args:
        engine: Engine identifier passed to ``openai.Completion.create``.
        query: The search query.
        documents: Iterable of document strings (any iterable, not only a
            list).

    Returns:
        A dict with ``"object": "list"``, the model name under ``"model"``,
        and under ``"data"`` one ``search_result`` entry per document holding
        its position in ``documents`` and its baseline-relative score rounded
        to 3 decimals.

    Raises:
        RuntimeError: If the number of returned choices differs from the
            number of prompts sent.
    """
    # Add an empty document as baseline; unpacking accepts any iterable.
    scorers = [
        SearchScorer(document=document, query=query)
        for document in ["", *documents]
    ]
    # NOTE(review): max_tokens=0 with echo=True is expected to return the
    # prompts themselves with token logprobs and no new text - confirm
    # against the API reference.
    completion = openai.Completion.create(
        engine=engine,
        prompt=[scorer.get_context() for scorer in scorers],
        max_tokens=0,
        logprobs=0,
        echo=True,
    )
    # Put the choices back in prompt order so index 0 is the empty document.
    choices = sorted(completion.choices, key=lambda choice: choice.index)
    if len(scorers) != len(choices):
        # Explicit raise rather than assert: under -O an assert is removed and
        # zip() below would silently truncate to the shorter sequence.
        raise RuntimeError(f"len(scorers)={len(scorers)} len(data)={len(choices)}")

    raw_scores = [
        scorer.get_score(choice) for scorer, choice in zip(scorers, choices)
    ]
    # Subtract the score of the empty document from every real document.
    baseline, *document_scores = raw_scores
    return {
        "object": "list",
        "data": [
            {
                "object": "search_result",
                "document": document_idx,
                "score": round(score - baseline, 3),
            }
            for document_idx, score in enumerate(document_scores)
        ],
        "model": completion.model,
    }
88
89
def main():
    """Parse command-line options, run one semantic search and print the result.

    With ``--server-side`` the search is delegated to the engine's ``search``
    method; otherwise it is computed locally by ``semantic_search``.

    Returns:
        0, used as the process exit status.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "-v", "--verbose",
        action="count", dest="verbosity", default=0,
        help="Set verbosity.",
    )
    parser.add_argument("-e", "--engine", default="ada")
    parser.add_argument("-q", "--query", required=True)
    parser.add_argument("-d", "--document", action="append", required=True)
    parser.add_argument("-s", "--server-side", action="store_true")
    options = parser.parse_args()

    # -v enables INFO, -vv (or more) enables DEBUG; without -v the level is
    # left untouched.
    if options.verbosity >= 2:
        logger.setLevel(logging.DEBUG)
    elif options.verbosity == 1:
        logger.setLevel(logging.INFO)

    if not options.server_side:
        result = semantic_search(
            options.engine, query=options.query, documents=options.document
        )
        print(f"[client-side semantic search] {result}")
        return 0

    engine = openai.Engine(id=options.engine)
    response = engine.search(query=options.query, documents=options.document)
    result = response.to_dict_recursive()
    print(f"[server-side semantic search] {result}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
126