openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
examples/semanticsearch/semanticsearch.py
125 lines · mode: code
| 1 | #!/usr/bin/env python |
| 2 | import openai |
| 3 | import argparse |
| 4 | import logging |
| 5 | import sys |
| 6 | from typing import List |
| 7 | |
# Every record is prefixed with its timestamp and process id.
formatter = logging.Formatter("[%(asctime)s] [%(process)d] %(message)s")
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)

# getLogger() with no name returns the root logger; attaching the stderr
# handler here makes log output visible. main() raises its level with -v/-vv.
logger = logging.getLogger()
logger.addHandler(handler)
| 13 | |
# Prompt used for conditional log-probability scoring. The document comes
# first and the query last, so the trailing query tokens are scored in the
# context of the document. "{document}" and "{query}" are str.format fields.
# The literal text before "{document}" must contain no format fields, because
# SearchScorer.get_score uses the raw index of "{document}" as an offset.
DEFAULT_COND_LOGP_TEMPLATE: str = (
    "<|endoftext|>{document}\n\n---\n\nThe above passage is related to: {query}"
)
# Scale factor applied to the mean per-token log-probability of the query.
SCORE_MULTIPLIER: float = 100.0
| 18 | |
| 19 | |
class SearchScorer:
    """Score how strongly a query relates to a document using model log-probs.

    The document and query are rendered into ``cond_logp_template``. The
    score is the mean log-probability of the trailing tokens that spell out
    the query, multiplied by ``SCORE_MULTIPLIER``. ``get_score`` expects a
    completion choice that echoes the rendered context together with
    per-token log-probabilities (as requested in ``semantic_search``).
    """

    def __init__(
        self, *, document, query, cond_logp_template=DEFAULT_COND_LOGP_TEMPLATE
    ):
        self.document = document
        self.query = query
        self.cond_logp_template = cond_logp_template
        # The exact prompt text sent to the model for this document/query pair.
        self.context = self.cond_logp_template.format(
            document=self.document, query=self.query
        )

    def get_context(self):
        """Return the rendered prompt for this document/query pair."""
        return self.context

    def get_score(self, choice) -> float:
        """Return the scaled mean log-probability of the query's tokens.

        Args:
            choice: A completion choice for the prompt from ``get_context()``.
                ``choice.text`` must equal that prompt.
                ``choice.logprobs.tokens`` and
                ``choice.logprobs.token_logprobs`` are indexed in parallel,
                and the tokens must concatenate to the prompt's length.

        Returns:
            The mean log-probability of the trailing tokens covering the
            query, times ``SCORE_MULTIPLIER``.

        Raises:
            RuntimeError: If the echoed text or token lengths do not match the
                context, if the tokens run out before covering the query, or
                if the query's tokens would reach back into the document.
            ValueError: If the query is empty, so there are no tokens to
                average.
        """
        # Explicit checks instead of `assert`, so they still run under `python -O`.
        if choice.text != self.context:
            raise RuntimeError("choice.text does not match the scored context")
        logprobs: List[float] = choice.logprobs.token_logprobs
        tokens: List[str] = choice.logprobs.tokens
        text_len = sum(len(token) for token in tokens)
        if text_len != len(self.context):
            raise RuntimeError(
                f"text_len={text_len}, len(self.context)={len(self.context)}"
            )
        # Walk backwards from the last token until the consumed tokens cover at
        # least len(query) characters; those trailing tokens are the ones scored.
        total_len = 0
        last_used = len(tokens)
        while total_len < len(self.query):
            if last_used <= 0:
                raise RuntimeError("ran out of tokens while locating the query")
            total_len += len(tokens[last_used - 1])
            last_used -= 1
        # Number of characters from the start of the document to the end of
        # the context. This assumes the template text before "{document}" has
        # no format fields. The scored span plus the document must fit in it,
        # i.e. the query's tokens must not overlap the document.
        max_len = len(self.context) - self.cond_logp_template.index("{document}")
        if total_len + len(self.document) > max_len:
            raise RuntimeError("query tokens overlap the document")
        logits: List[float] = logprobs[last_used:]
        if not logits:
            # Only reachable with an empty query; previously a ZeroDivisionError.
            raise ValueError("query must be non-empty to compute a score")
        return sum(logits) / len(logits) * SCORE_MULTIPLIER
| 53 | |
| 54 | |
def semantic_search(engine, query, documents):
    """Score ``documents`` against ``query`` on the client side.

    All documents, plus an empty-document baseline, are scored in a single
    batched ``openai.Completion.create`` call (``echo=True``,
    ``max_tokens=0``, ``logprobs=0``) using ``SearchScorer``. Each reported
    score is the document's score minus the baseline's. A positive value
    means the document raised the query's average per-token log-probability
    compared with no document.

    Args:
        engine: Name of the engine to score with.
        query: The search query. Must be non-empty.
        documents: List of document strings to score.

    Returns:
        A dict ``{"object": "list", "data": [...], "model": ...}``. ``data``
        has one ``{"object": "search_result", "document": i, "score": s}``
        entry per document, in input order, with ``s`` rounded to 3 decimals.

    Raises:
        ValueError: If ``query`` is empty. This is checked before any API
            request, because an empty query can never be scored.
        RuntimeError: If the API returns a different number of choices than
            prompts sent.
    """
    if not query:
        raise ValueError("query must be a non-empty string")
    # Index 0 is an empty document, used as the baseline for normalization.
    scorers = [
        SearchScorer(document=document, query=query) for document in [""] + documents
    ]
    completion = openai.Completion.create(
        engine=engine,
        prompt=[scorer.get_context() for scorer in scorers],
        max_tokens=0,
        logprobs=0,
        echo=True,
    )
    # Match choices to prompts by their index, not by response order.
    choices = sorted(completion.choices, key=lambda choice: choice.index)
    if len(scorers) != len(choices):
        raise RuntimeError(f"len(scorers)={len(scorers)} len(data)={len(choices)}")
    baseline, *document_scores = [
        scorer.get_score(choice) for scorer, choice in zip(scorers, choices)
    ]
    return {
        "object": "list",
        "data": [
            {
                "object": "search_result",
                "document": document_idx,
                "score": round(score - baseline, 3),
            }
            for document_idx, score in enumerate(document_scores)
        ],
        "model": completion.model,
    }
| 88 | |
| 89 | |
def main():
    """Command-line entry point.

    Parses the arguments and sets the root logger's level from the ``-v``
    count (one ``-v`` -> INFO, two or more -> DEBUG). It then runs either the
    server-side engine search (``--server-side``) or the client-side
    ``semantic_search``, prints the result and returns exit status 0.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument(
        "-v", "--verbose", action="count", dest="verbosity", default=0, help="Set verbosity."
    )
    parser.add_argument("-e", "--engine", default="ada")
    parser.add_argument("-q", "--query", required=True)
    parser.add_argument("-d", "--document", action="append", required=True)
    parser.add_argument("-s", "--server-side", action="store_true")
    args = parser.parse_args()

    # The two conditions are mutually exclusive, so checking DEBUG first is equivalent.
    if args.verbosity >= 2:
        logger.setLevel(logging.DEBUG)
    elif args.verbosity == 1:
        logger.setLevel(logging.INFO)

    if not args.server_side:
        result = semantic_search(args.engine, query=args.query, documents=args.document)
        print(f"[client-side semantic search] {result}")
        return 0

    engine = openai.Engine(id=args.engine)
    search_response = engine.search(query=args.query, documents=args.document)
    result = search_response.to_dict_recursive()
    print(f"[server-side semantic search] {result}")
    return 0
| 122 | |
| 123 | |
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (same as sys.exit).
    raise SystemExit(main())
| 126 | |