cloudflare/cloudflared

Public

mirrored from https://github.com/cloudflare/cloudflared · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2020.6.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

dbconnect/sql.go

318 lines · mode: code

1package dbconnect
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "net/url"
9 "reflect"
10 "strings"
11
12 "github.com/jmoiron/sqlx"
13 "github.com/pkg/errors"
14 "github.com/xo/dburl"
15
16 // SQL drivers self-register with the database/sql package.
17 // https://github.com/golang/go/wiki/SQLDrivers
18 _ "github.com/denisenkom/go-mssqldb"
19 _ "github.com/go-sql-driver/mysql"
20 _ "github.com/mattn/go-sqlite3"
21
22 "github.com/kshvakov/clickhouse"
23 "github.com/lib/pq"
24)
25
26// SQLClient is a Client that talks to a SQL database.
27type SQLClient struct {
28 Dialect string
29 driver *sqlx.DB
30}
31
32// NewSQLClient creates a SQL client based on its URL scheme.
33func NewSQLClient(ctx context.Context, originURL *url.URL) (Client, error) {
34 res, err := dburl.Parse(originURL.String())
35 if err != nil {
36 helpText := fmt.Sprintf("supported drivers: %+q, see documentation for more details: %s", sql.Drivers(), "https://godoc.org/github.com/xo/dburl")
37 return nil, fmt.Errorf("could not parse sql database url '%s': %s\n%s", originURL, err.Error(), helpText)
38 }
39
40 // Establishes the driver, but does not test the connection.
41 driver, err := sqlx.Open(res.Driver, res.DSN)
42 if err != nil {
43 return nil, fmt.Errorf("could not open sql driver %s: %s\n%s", res.Driver, err.Error(), res.DSN)
44 }
45
46 // Closes the driver, will occur when the context finishes.
47 go func() {
48 <-ctx.Done()
49 driver.Close()
50 }()
51
52 return &SQLClient{driver.DriverName(), driver}, nil
53}
54
55// Ping verifies a connection to the database is still alive.
56func (client *SQLClient) Ping(ctx context.Context) error {
57 return client.driver.PingContext(ctx)
58}
59
60// Submit queries or executes a command to the SQL database.
61func (client *SQLClient) Submit(ctx context.Context, cmd *Command) (interface{}, error) {
62 txx, err := cmd.ValidateSQL(client.Dialect)
63 if err != nil {
64 return nil, err
65 }
66
67 ctx, cancel := context.WithTimeout(ctx, cmd.Timeout)
68 defer cancel()
69
70 var res interface{}
71
72 // Get the next available sql.Conn and submit the Command.
73 err = sqlConn(ctx, client.driver, txx, func(conn *sql.Conn) error {
74 stmt := cmd.Statement
75 args := cmd.Arguments.Positional
76
77 if cmd.Mode == "query" {
78 res, err = sqlQuery(ctx, conn, stmt, args)
79 } else {
80 res, err = sqlExec(ctx, conn, stmt, args)
81 }
82
83 return err
84 })
85
86 return res, err
87}
88
89// ValidateSQL extends the contract of Command for SQL dialects:
90// mode is conformed, arguments are []sql.NamedArg, and isolation is a sql.IsolationLevel.
91//
92// When the command should not be wrapped in a transaction, *sql.TxOptions and error will both be nil.
93func (cmd *Command) ValidateSQL(dialect string) (*sql.TxOptions, error) {
94 err := cmd.Validate()
95 if err != nil {
96 return nil, err
97 }
98
99 mode, err := sqlMode(cmd.Mode)
100 if err != nil {
101 return nil, err
102 }
103
104 // Mutates Arguments to only use positional arguments with the type sql.NamedArg.
105 // This is a required by the sql.Driver before submitting arguments.
106 cmd.Arguments.sql(dialect)
107
108 iso, err := sqlIsolation(cmd.Isolation)
109 if err != nil {
110 return nil, err
111 }
112
113 // When isolation is out-of-range, this is indicative that no
114 // transaction should be executed and sql.TxOptions should be nil.
115 if iso < sql.LevelDefault {
116 return nil, nil
117 }
118
119 // In query mode, execute the transaction in read-only, unless it's Microsoft SQL
120 // which does not support that type of transaction.
121 readOnly := mode == "query" && dialect != "mssql"
122
123 return &sql.TxOptions{Isolation: iso, ReadOnly: readOnly}, nil
124}
125
126// sqlConn gets the next available sql.Conn in the connection pool and runs a function to use it.
127//
128// If the transaction options are nil, run the useIt function outside a transaction.
129// This is potentially an unsafe operation if the command does not clean up its state.
130func sqlConn(ctx context.Context, driver *sqlx.DB, txx *sql.TxOptions, useIt func(*sql.Conn) error) error {
131 conn, err := driver.Conn(ctx)
132 if err != nil {
133 return err
134 }
135 defer conn.Close()
136
137 // If transaction options are specified, begin and defer a rollback to catch errors.
138 var tx *sql.Tx
139 if txx != nil {
140 tx, err = conn.BeginTx(ctx, txx)
141 if err != nil {
142 return err
143 }
144 defer tx.Rollback()
145 }
146
147 err = useIt(conn)
148
149 // Check if useIt was successful and a transaction exists before committing.
150 if err == nil && tx != nil {
151 err = tx.Commit()
152 }
153
154 return err
155}
156
157// sqlQuery queries rows on a sql.Conn and returns an array of result objects.
158func sqlQuery(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) ([]map[string]interface{}, error) {
159 rows, err := conn.QueryContext(ctx, stmt, args...)
160 if err == nil {
161 return sqlRows(rows)
162 }
163 return nil, err
164}
165
166// sqlExec executes a command on a sql.Conn and returns the result of the operation.
167func sqlExec(ctx context.Context, conn *sql.Conn, stmt string, args []interface{}) (sqlResult, error) {
168 exec, err := conn.ExecContext(ctx, stmt, args...)
169 if err == nil {
170 return sqlResultFrom(exec), nil
171 }
172 return sqlResult{}, err
173}
174
175// sql mutates Arguments to contain a positional []sql.NamedArg.
176//
177// The actual return type is []interface{} due to the native Golang
178// function signatures for sql.Exec and sql.Query being generic.
179func (args *Arguments) sql(dialect string) {
180 result := args.Positional
181
182 for i, val := range result {
183 result[i] = sqlArg("", val, dialect)
184 }
185
186 for key, val := range args.Named {
187 result = append(result, sqlArg(key, val, dialect))
188 }
189
190 args.Positional = result
191 args.Named = map[string]interface{}{}
192}
193
194// sqlArg creates a sql.NamedArg from a key-value pair and an optional dialect.
195//
196// Certain dialects will need to wrap objects, such as arrays, to conform its driver requirements.
197func sqlArg(key, val interface{}, dialect string) sql.NamedArg {
198 switch reflect.ValueOf(val).Kind() {
199
200 // PostgreSQL and Clickhouse require arrays to be wrapped before
201 // being inserted into the driver interface.
202 case reflect.Slice, reflect.Array:
203 switch dialect {
204 case "postgres":
205 val = pq.Array(val)
206 case "clickhouse":
207 val = clickhouse.Array(val)
208 }
209 }
210
211 return sql.Named(fmt.Sprint(key), val)
212}
213
// sqlIsolation tries to match a string to a sql.IsolationLevel.
//
// Matching rules:
//   - "none" (in any letter case) returns -1. ValidateSQL treats a negative
//     level as "run without a transaction".
//   - An empty string selects sql.LevelDefault.
//   - Otherwise the string is compared case-insensitively against each level's
//     String() name, with '_' treated as a space. So "read_committed" and
//     "Read Committed" both select sql.LevelReadCommitted.
//
// Any other string yields -1 and an error.
func sqlIsolation(str string) (sql.IsolationLevel, error) {
	// Case-insensitive, consistent with how the named levels are matched below.
	if strings.EqualFold(str, "none") {
		return sql.IsolationLevel(-1), nil
	}
	if str == "" {
		return sql.LevelDefault, nil
	}

	name := strings.ReplaceAll(str, "_", " ")
	for iso := sql.LevelDefault; iso <= sql.LevelLinearizable; iso++ {
		if strings.EqualFold(iso.String(), name) {
			return iso, nil
		}
	}
	return -1, fmt.Errorf("cannot provide an invalid sql isolation level: '%s'", str)
}
230
// sqlMode accepts exactly "query" or "exec" (case-sensitive) and returns it
// unchanged. Any other value yields "" and an error.
func sqlMode(str string) (string, error) {
	if str != "query" && str != "exec" {
		return "", fmt.Errorf("cannot provide invalid sql mode: '%s'", str)
	}
	return str, nil
}
240
241// sqlRows scans through a SQL result set and returns an array of objects.
242func sqlRows(rows *sql.Rows) ([]map[string]interface{}, error) {
243 columns, err := rows.Columns()
244 if err != nil {
245 return nil, errors.Wrap(err, "could not extract columns from result")
246 }
247 defer rows.Close()
248
249 types, err := rows.ColumnTypes()
250 if err != nil {
251 // Some drivers do not support type extraction, so fail silently and continue.
252 types = make([]*sql.ColumnType, len(columns))
253 }
254
255 values := make([]interface{}, len(columns))
256 pointers := make([]interface{}, len(columns))
257
258 var results []map[string]interface{}
259 for rows.Next() {
260 for i := range columns {
261 pointers[i] = &values[i]
262 }
263 rows.Scan(pointers...)
264
265 // Convert a row, an array of values, into an object where
266 // each key is the name of its respective column.
267 entry := make(map[string]interface{})
268 for i, col := range columns {
269 entry[col] = sqlValue(values[i], types[i])
270 }
271 results = append(results, entry)
272 }
273
274 return results, nil
275}
276
// sqlValue handles special cases where sql.Rows does not return a "human-readable" object.
//
// Non-[]byte values are returned unchanged. A []byte is returned as the decoded
// value when its entire content is valid JSON, and as a string otherwise.
// JSON numbers are kept as json.Number rather than float64, so that large
// integers and exact decimals (e.g. "12345678901234567890" or "1.50") are not
// rounded. col is currently unused.
func sqlValue(val interface{}, col *sql.ColumnType) interface{} {
	raw, ok := val.([]byte)
	if !ok {
		return val
	}
	text := string(raw)

	// Opportunistically check for embedded JSON and convert it to a first-class object.
	// json.Valid rejects trailing data, which a Decoder would silently ignore.
	if json.Valid(raw) {
		decoder := json.NewDecoder(strings.NewReader(text))
		decoder.UseNumber()
		var embedded interface{}
		if decoder.Decode(&embedded) == nil {
			return embedded
		}
	}

	// STOR-604: investigate a way to coerce PostgreSQL arrays '{a, b, ...}' into JSON.
	// Although easy with strings, it becomes more difficult with special types like INET[].

	return text
}
295
// sqlResult is a JSON-marshalable snapshot of the counters in a sql.Result.
// A field is -1 when the driver could not report it.
type sqlResult struct {
	LastInsertId int64 `json:"last_insert_id"`
	RowsAffected int64 `json:"rows_affected"`
}

// sqlResultFrom copies LastInsertId and RowsAffected out of res.
//
// Not every database driver supports both values, and an unsupported one is
// reported as an error by the accessor. Such errors are not passed on to the
// caller. The corresponding field is set to -1 instead, and it still appears
// in the marshaled JSON.
func sqlResultFrom(res sql.Result) sqlResult {
	out := sqlResult{LastInsertId: -1, RowsAffected: -1}
	if id, idErr := res.LastInsertId(); idErr == nil {
		out.LastInsertId = id
	}
	if affected, rowsErr := res.RowsAffected(); rowsErr == nil {
		out.RowsAffected = affected
	}
	return out
}
319