openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
examples/semanticsearch/semanticsearch.py
126 lines · mode: code
| 1 | #!/usr/bin/env python |
| 2 | import argparse |
| 3 | import logging |
| 4 | import sys |
| 5 | from typing import List |
| 6 | |
| 7 | import openai |
| 8 | |
# Root-logger setup: attach a stderr handler whose records are prefixed with
# the timestamp and the id of the emitting process.
handler = logging.StreamHandler(stream=sys.stderr)
formatter = logging.Formatter(fmt="[%(asctime)s] [%(process)d] %(message)s")
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
| 14 | |
# Prompt template rendered for every (document, query) pair: the document
# comes first and the query is the very last thing in the prompt.
DEFAULT_COND_LOGP_TEMPLATE = (
    "<|endoftext|>{document}\n\n---\n\nThe above passage is related to: {query}"
)
# Factor applied to the mean per-token log-probability to produce a score.
SCORE_MULTIPLIER = 100.0


class SearchScorer:
    """Score one document against a query from an echoed completion.

    The prompt is built once, at construction time, by formatting
    ``cond_logp_template`` with the document and the query.  ``get_score``
    then averages the log-probabilities of the trailing tokens that cover
    the query.
    """

    def __init__(
        self, *, document, query, cond_logp_template=DEFAULT_COND_LOGP_TEMPLATE
    ):
        """Store the inputs and render the prompt.

        Args:
            document: Text substituted for ``{document}`` in the template.
            query: Text substituted for ``{query}`` in the template.
            cond_logp_template: ``str.format`` template with ``{document}``
                and ``{query}`` fields.
        """
        self.document = document
        self.query = query
        self.cond_logp_template = cond_logp_template
        self.context = self.cond_logp_template.format(
            document=self.document, query=self.query
        )

    def get_context(self):
        """Return the rendered prompt for this document/query pair."""
        return self.context

    def get_score(self, choice) -> float:
        """Return the scaled mean log-probability of the query tokens.

        Tokens are consumed from the end of the echoed prompt until their
        combined character length is at least ``len(self.query)``; the
        log-probabilities of exactly those tokens are averaged and multiplied
        by ``SCORE_MULTIPLIER``.  If a token straddles the start of the query
        it is included in full.

        Args:
            choice: Object exposing ``text``, ``logprobs.tokens`` and
                ``logprobs.token_logprobs``.

        Returns:
            The score, or ``0.0`` when no token is selected (empty query),
            instead of failing with a division by zero.

        Raises:
            AssertionError: If ``choice.text`` differs from the prompt, the
                tokens run out before the query is covered, or the selected
                tokens would reach into the document.
            RuntimeError: If the token lengths do not add up to the prompt
                length.
        """
        # Explicit raises rather than ``assert`` statements: these checks must
        # survive ``python -O``.  AssertionError is kept so that the exception
        # type seen by callers does not change.
        if choice.text != self.context:
            raise AssertionError("choice.text does not match the scorer context")

        logprobs: List[float] = choice.logprobs.token_logprobs
        tokens = choice.logprobs.tokens
        text_len = sum(len(token) for token in tokens)
        if text_len != len(self.context):
            raise RuntimeError(
                f"text_len={text_len}, len(self.context)={len(self.context)}"
            )

        # Walk backwards from the last token until the query is covered.
        covered = 0
        last_used = len(tokens)
        while covered < len(self.query):
            if last_used <= 0:
                # Without this guard the index below would go negative and
                # silently wrap around to the end of the token list.
                raise AssertionError("ran out of tokens before covering the query")
            last_used -= 1
            covered += len(tokens[last_used])

        # Everything from the "{document}" placeholder onwards must be able to
        # hold both the document and the selected tail.
        max_len = len(self.context) - self.cond_logp_template.index("{document}")
        if covered + len(self.document) > max_len:
            raise AssertionError(
                f"covered={covered} + len(self.document)={len(self.document)} "
                f"exceeds max_len={max_len}"
            )

        logits: List[float] = logprobs[last_used:]
        if not logits:
            # Empty query: no tokens to average, so avoid dividing by zero.
            return 0.0
        return sum(logits) / len(logits) * SCORE_MULTIPLIER
| 54 | |
| 55 | |
def semantic_search(engine, query, documents):
    """Rank ``documents`` against ``query`` using completion log-probabilities.

    One prompt per document, plus one for an empty baseline document, is sent
    in a single ``openai.Completion.create`` call.  Each document's score is
    its ``SearchScorer`` score minus the baseline's score.

    Args:
        engine: Passed through as the ``engine`` argument of the request.
        query: Query text shared by every prompt.
        documents: List of document strings.

    Returns:
        A dict with ``"object": "list"``, ``"model"`` taken from the
        response, and ``"data"``: one ``search_result`` entry per document
        holding its index in ``documents`` and its score rounded to three
        decimals.

    Raises:
        AssertionError: If the number of returned choices differs from the
            number of prompts sent.
    """
    # The empty document goes first; its score is the baseline subtracted
    # from every real document's score.
    scorers = [
        SearchScorer(document=document, query=query) for document in [""] + documents
    ]
    # The scorers read the prompt text and per-token log-probabilities back
    # from each choice, hence echo=True with no newly generated tokens.
    completion = openai.Completion.create(
        engine=engine,
        prompt=[scorer.get_context() for scorer in scorers],
        max_tokens=0,
        logprobs=0,
        echo=True,
    )
    # Restore prompt order so choice i lines up with scorer i.
    choices = sorted(completion.choices, key=lambda choice: choice.index)
    # An explicit raise (not ``assert``) so the check survives ``python -O``;
    # otherwise zip() below would silently truncate on a short response.
    if len(scorers) != len(choices):
        raise AssertionError(f"len(scorers)={len(scorers)} len(data)={len(choices)}")

    scores = [scorer.get_score(choice) for scorer, choice in zip(scorers, choices)]
    baseline = scores[0]
    relative_scores = [score - baseline for score in scores[1:]]
    return {
        "object": "list",
        "data": [
            {
                "object": "search_result",
                "document": document_idx,
                "score": round(score, 3),
            }
            for document_idx, score in enumerate(relative_scores)
        ],
        "model": completion.model,
    }
| 89 | |
| 90 | |
def main():
    """Parse the command line, run one search and print the result.

    ``-s/--server-side`` sends the query and documents to the engine's
    ``search`` method; otherwise ``semantic_search`` is used.  ``-v`` may be
    repeated to raise the root logger's level (once: INFO, twice or more:
    DEBUG).

    Returns:
        Always ``0``.
    """
    cli = argparse.ArgumentParser(description=None)
    cli.add_argument(
        "-v",
        "--verbose",
        action="count",
        dest="verbosity",
        default=0,
        help="Set verbosity.",
    )
    cli.add_argument("-e", "--engine", default="ada")
    cli.add_argument("-q", "--query", required=True)
    cli.add_argument("-d", "--document", action="append", required=True)
    cli.add_argument("-s", "--server-side", action="store_true")
    options = cli.parse_args()

    # verbosity is a non-negative count, so these two branches are exclusive.
    if options.verbosity >= 2:
        logger.setLevel(logging.DEBUG)
    elif options.verbosity == 1:
        logger.setLevel(logging.INFO)

    if not options.server_side:
        result = semantic_search(
            options.engine, query=options.query, documents=options.document
        )
        print(f"[client-side semantic search] {result}")
        return 0

    response = openai.Engine(id=options.engine).search(
        query=options.query, documents=options.document
    )
    result = response.to_dict_recursive()
    print(f"[server-side semantic search] {result}")
    return 0
| 123 | |
| 124 | |
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
| 127 | |