openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.11.6

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

examples/semanticsearch/semanticsearch.py

126lines · modecode

1#!/usr/bin/env python
2import argparse
3import logging
4import sys
5from typing import List
6
7import openai
8
# Root-logger setup: every record is written to stderr, prefixed with the
# timestamp and the id of the emitting process.
formatter = logging.Formatter(fmt="[%(asctime)s] [%(process)d] %(message)s")
handler = logging.StreamHandler(stream=sys.stderr)
handler.setFormatter(formatter)

# getLogger() without a name returns the root logger.
logger = logging.getLogger()
logger.addHandler(handler)
14
# Prompt template used to score one document against a query. The query is
# the last thing in the prompt, so its tokens are the trailing tokens of the
# text sent to the model.
DEFAULT_COND_LOGP_TEMPLATE = (
    "<|endoftext|>{document}\n\n---\n\nThe above passage is related to: {query}"
)
# Factor applied to the mean query-token log-probability in get_score().
SCORE_MULTIPLIER = 100.0


class SearchScorer:
    """Score how well a single document matches a query.

    The scorer renders ``cond_logp_template`` with the document and the query
    to obtain the prompt (``context``).  Given the echoed completion for that
    prompt, :meth:`get_score` averages the log-probabilities of the trailing
    tokens that cover the query.
    """

    def __init__(
        self, *, document, query, cond_logp_template=DEFAULT_COND_LOGP_TEMPLATE
    ):
        """Store the inputs and render the prompt once.

        Args:
            document: Text of the document to score.
            query: Text of the search query.
            cond_logp_template: ``str.format`` template with ``{document}``
                and ``{query}`` placeholders.
        """
        self.document = document
        self.query = query
        self.cond_logp_template = cond_logp_template
        self.context = self.cond_logp_template.format(
            document=self.document, query=self.query
        )

    def get_context(self):
        """Return the rendered prompt for this document/query pair."""
        return self.context

    def get_score(self, choice) -> float:
        """Compute the score for the completion ``choice`` of this prompt.

        Args:
            choice: Object exposing ``text``, ``logprobs.tokens`` and
                ``logprobs.token_logprobs`` for the echoed prompt.

        Returns:
            Mean log-probability of the trailing tokens that cover the query,
            multiplied by ``SCORE_MULTIPLIER``.

        Raises:
            RuntimeError: If the completion does not line up with the prompt
                (different text, token lengths that do not add up, or query
                tokens that would reach into the document).
            ValueError: If there is no query token to average over, e.g. the
                query is empty.
        """
        # Explicit raises instead of asserts: these validate external data
        # and must not disappear when Python runs with -O.
        if choice.text != self.context:
            raise RuntimeError(
                "completion text does not match the prompt built by this scorer"
            )

        token_logprobs = choice.logprobs.token_logprobs
        tokens = choice.logprobs.tokens

        text_len = sum(len(token) for token in tokens)
        if text_len != len(self.context):
            raise RuntimeError(
                f"text_len={text_len}, len(self.context)={len(self.context)}"
            )

        # Walk backwards from the last token until the collected tokens hold
        # at least as many characters as the query.
        covered = 0
        first_query_token = len(tokens)
        while covered < len(self.query):
            if first_query_token == 0:
                raise RuntimeError("ran out of tokens before covering the query")
            first_query_token -= 1
            covered += len(tokens[first_query_token])

        # Everything after the template text that precedes "{document}" has to
        # fit both the document and the tokens attributed to the query.
        tail_room = len(self.context) - self.cond_logp_template.index("{document}")
        if covered + len(self.document) > tail_room:
            raise RuntimeError(
                f"query tokens overlap the document: covered={covered}, "
                f"len(self.document)={len(self.document)}, tail_room={tail_room}"
            )

        query_logprobs = token_logprobs[first_query_token:]
        if not query_logprobs:
            # Previously surfaced as a bare ZeroDivisionError.
            raise ValueError("no query tokens to score; is the query empty?")
        return sum(query_logprobs) / len(query_logprobs) * SCORE_MULTIPLIER
54
55
def semantic_search(engine, query, documents):
    """Rank ``documents`` against ``query`` using completion log-probabilities.

    One prompt is built per document, plus one for an empty document that
    serves as the baseline.  All prompts are sent in a single
    ``openai.Completion.create`` call with ``echo=True`` and ``max_tokens=0``.

    Args:
        engine: Engine identifier passed to ``openai.Completion.create``.
        query: Search query text.
        documents: Iterable of document strings (any iterable, not only a
            list).

    Returns:
        A dict with ``"object": "list"``, the completion's ``model`` and a
        ``"data"`` list holding one ``search_result`` per document: its
        position in ``documents`` and its score relative to the empty
        document, rounded to three decimals.

    Raises:
        RuntimeError: If the number of returned choices differs from the
            number of prompts sent.
    """
    # The empty document goes first; its score is the baseline that is
    # subtracted from every real document's score below.
    scorers = [
        SearchScorer(document=document, query=query)
        for document in ["", *documents]
    ]
    completion = openai.Completion.create(
        engine=engine,
        prompt=[scorer.get_context() for scorer in scorers],
        max_tokens=0,
        logprobs=0,
        echo=True,
    )

    # Restore prompt order so choices line up with scorers and the baseline
    # is at position 0.
    choices = sorted(completion.choices, key=lambda choice: choice.index)
    if len(scorers) != len(choices):
        # Explicit raise: the check validates the API response and must
        # survive running Python with -O.
        raise RuntimeError(f"len(scorers)={len(scorers)} len(data)={len(choices)}")

    raw_scores = [
        scorer.get_score(choice) for scorer, choice in zip(scorers, choices)
    ]
    baseline = raw_scores[0]
    relative_scores = [score - baseline for score in raw_scores[1:]]

    return {
        "object": "list",
        "data": [
            {
                "object": "search_result",
                "document": document_idx,
                "score": round(score, 3),
            }
            for document_idx, score in enumerate(relative_scores)
        ],
        "model": completion.model,
    }
89
90
def main(argv=None):
    """Command-line entry point for the semantic search example.

    Args:
        argv: Argument list to parse.  ``None`` (the default) lets argparse
            read the process's command-line arguments, so existing
            ``main()`` calls behave as before.

    Returns:
        ``0`` once the search result has been printed.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        dest="verbosity",
        default=0,
        help="Set verbosity.",
    )
    parser.add_argument("-e", "--engine", default="ada")
    parser.add_argument("-q", "--query", required=True)
    # May be given several times; each occurrence adds one document.
    parser.add_argument("-d", "--document", action="append", required=True)
    parser.add_argument("-s", "--server-side", action="store_true")
    args = parser.parse_args(argv)

    # -v selects INFO, -vv (or more) selects DEBUG; otherwise the level is
    # left untouched.
    if args.verbosity == 1:
        logger.setLevel(logging.INFO)
    elif args.verbosity >= 2:
        logger.setLevel(logging.DEBUG)

    if args.server_side:
        # Let the API's search endpoint do the ranking.
        resp = openai.Engine(id=args.engine).search(
            query=args.query, documents=args.document
        )
        resp = resp.to_dict_recursive()
        print(f"[server-side semantic search] {resp}")
    else:
        # Rank locally from completion log-probabilities.
        resp = semantic_search(args.engine, query=args.query, documents=args.document)
        print(f"[client-side semantic search] {resp}")

    return 0
123
124
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
127