-
Notifications
You must be signed in to change notification settings - Fork 379
Expand file tree
/
Copy pathmongo.go
More file actions
261 lines (212 loc) · 7.08 KB
/
mongo.go
File metadata and controls
261 lines (212 loc) · 7.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
// Copyright (C) MongoDB, Inc. 2014-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package util
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
const (
// InvalidDBChars lists the characters that ValidateDBName rejects when they
// appear anywhere in a database name.
InvalidDBChars = "/\\. \"\x00$"
// InvalidCollectionChars lists the characters that ValidateCollectionGrammar
// rejects when they appear anywhere in a collection name.
InvalidCollectionChars = "$\x00"
// DefaultHost is the host CreateConnectionAddrs falls back to when it is
// given an empty host.
DefaultHost = "localhost"
// DefaultPort is the port CreateConnectionAddrs appends to DefaultHost when
// it is given neither a host nor a port.
DefaultPort = "27017"
)
// SplitHostArg separates a host argument of the form "setName/host1,host2"
// into its comma-separated list of hosts and its replica set name, returned
// in that order. Everything before the first "/" is the set name; when the
// argument contains no "/", the set name is empty and the whole argument is
// treated as the host list.
func SplitHostArg(connString string) ([]string, string) {
	hostList, replSet := connString, ""
	if cut := strings.IndexByte(connString, '/'); cut >= 0 {
		replSet = connString[:cut]
		hostList = connString[cut+1:]
	}

	// Splitting an empty host list yields a single empty entry, so an
	// argument ending in "/" produces []string{""}.
	return strings.Split(hostList, ","), replSet
}
// CreateConnectionAddrs turns a host argument and an optional port into the
// list of individual addresses to connect to.
//
// An empty host is replaced by DefaultHost, with DefaultPort attached when no
// port was given either. The host argument is then split with SplitHostArg
// (any replica set name is discarded), and a non-empty port is appended to
// every resulting address.
func CreateConnectionAddrs(host, port string) []string {
	havePort := port != ""

	// fall back to the defaults when no host was supplied
	if host == "" {
		host = DefaultHost
		if !havePort {
			host += fmt.Sprintf(":%v", DefaultPort)
		}
	}

	// break the host argument into its individual nodes
	nodes, _ := SplitHostArg(host)
	if !havePort {
		return nodes
	}

	// an explicit port applies to every node
	for i := range nodes {
		nodes[i] = fmt.Sprintf("%v:%v", nodes[i], port)
	}
	return nodes
}
// BuildURI assembles a "mongodb://" URI from host and port arguments. The
// host argument may carry a replica set name in front of a "/", as understood
// by SplitHostArg; when present it is emitted as the replicaSet query option.
//
// Empty seedlist entries become "localhost", and a non-empty port is appended
// to every entry that does not already contain a ":".
func BuildURI(host, port string) string {
	hosts, replSet := SplitHostArg(host)

	for i, h := range hosts {
		// an empty entry means localhost
		if h == "" {
			h = "localhost"
		}
		// only entries without their own port receive the provided one
		if port != "" && !strings.Contains(h, ":") {
			h += ":" + port
		}
		hosts[i] = h
	}

	joined := strings.Join(hosts, ",")
	if replSet == "" {
		return fmt.Sprintf("mongodb://%s/", joined)
	}
	return fmt.Sprintf("mongodb://%s/?replicaSet=%s", joined, replSet)
}
// SplitNamespace breaks a namespace into its database and collection parts,
// returned in that order. The split happens at the first "."; a namespace
// without a "." is returned as the database with an empty collection.
func SplitNamespace(namespace string) (string, string) {
	dot := strings.IndexByte(namespace, '.')
	if dot < 0 {
		// no separator: the whole namespace is the database name
		return namespace, ""
	}
	return namespace[:dot], namespace[dot+1:]
}
// SplitAndValidateNamespace checks a namespace with ValidateFullNamespace and,
// if it passes, splits it with SplitNamespace into a database and collection,
// returned in that order. When validation fails, both strings are empty and
// the returned error describes why the namespace was rejected.
func SplitAndValidateNamespace(namespace string) (string, string, error) {
	err := ValidateFullNamespace(namespace)
	if err == nil {
		db, coll := SplitNamespace(namespace)
		return db, coll, nil
	}
	return "", "", fmt.Errorf("namespace '%v' is not valid: %v", namespace, err)
}
// ValidateFullNamespace validates a full mongodb namespace (a database name,
// optionally followed by "." and a collection name). It returns an error
// describing the first problem found, or nil if the namespace is valid.
//
// The checks, in order, are:
//   - the namespace is not empty;
//   - the namespace is at most 122 bytes long;
//   - the first "." is neither the first nor the last byte of the namespace;
//   - the part before the first "." passes ValidateDBName;
//   - the part after the first ".", if non-empty, passes
//     ValidateCollectionName.
func ValidateFullNamespace(namespace string) error {
	// Reject the empty namespace explicitly: strings.Index returns -1 for it,
	// which equals len(namespace)-1, so the trailing-dot check below would
	// otherwise report a misleading "ends with a '.'" error.
	if namespace == "" {
		return fmt.Errorf("namespace cannot be an empty string")
	}

	// the namespace must be shorter than 123 bytes
	if len(namespace) > 122 {
		return fmt.Errorf("namespace %v is too long (>= 123 bytes)", namespace)
	}

	// find the first instance of "." in the namespace
	firstDotIndex := strings.Index(namespace, ".")

	// the namespace cannot begin with a dot
	if firstDotIndex == 0 {
		return fmt.Errorf("namespace %v begins with a '.'", namespace)
	}

	// the first dot cannot be the final byte of the namespace
	if firstDotIndex == len(namespace)-1 {
		return fmt.Errorf("namespace %v ends with a '.'", namespace)
	}

	// split the namespace at the first dot, if there is one
	database, collection := namespace, ""
	if firstDotIndex != -1 {
		database = namespace[:firstDotIndex]
		collection = namespace[firstDotIndex+1:]
	}

	// validate the database name
	if err := ValidateDBName(database); err != nil {
		return fmt.Errorf("database name is invalid: %v", err)
	}

	// validate the collection name, if one was given
	if collection != "" {
		if err := ValidateCollectionName(collection); err != nil {
			return fmt.Errorf("collection name is invalid: %v", err)
		}
	}

	// the namespace is valid
	return nil
}
// ValidateDBName checks that a string is usable as a mongodb database name.
// It returns an error if the name is longer than 63 bytes or contains any
// character listed in InvalidDBChars, and nil otherwise.
func ValidateDBName(database string) error {
	// the limit is measured in bytes
	if len(database) >= 64 {
		return fmt.Errorf("db name '%v' is longer than 63 characters", database)
	}

	// report the first forbidden character, in InvalidDBChars order, that
	// occurs in the name
	for _, r := range InvalidDBChars {
		if strings.IndexRune(database, r) >= 0 {
			return fmt.Errorf("illegal character '%c' found in db name '%v'", r, database)
		}
	}

	return nil
}
// ValidateCollectionName checks that a string is usable as a mongodb
// collection name. Names beginning with "system." are rejected outright; any
// other name is passed to ValidateCollectionGrammar and its result returned.
func ValidateCollectionName(collection string) error {
	if !strings.HasPrefix(collection, "system.") {
		return ValidateCollectionGrammar(collection)
	}
	return fmt.Errorf("collection name '%v' is not allowed to begin with 'system.'", collection)
}
// ValidateCollectionGrammar checks a collection name for emptiness and for
// characters listed in InvalidCollectionChars. Unlike ValidateCollectionName
// it does not reject names beginning with "system.", so it can be used by
// code that manipulates system collections.
func ValidateCollectionGrammar(collection string) error {
	if collection == "" {
		return fmt.Errorf("collection name cannot be an empty string")
	}

	// report the first forbidden character, in InvalidCollectionChars order,
	// that occurs in the name
	for _, r := range InvalidCollectionChars {
		if strings.IndexRune(collection, r) != -1 {
			return fmt.Errorf("illegal character '%c' found in '%v'", r, collection)
		}
	}

	return nil
}
// IsConnectionAuthenticated reports whether the given client has at least one
// authenticated user. It runs the connectionStatus command against the
// "admin" database and checks whether authInfo.authenticatedUsers in the
// reply is non-empty.
//
// An error is returned if the command fails or if its reply cannot be
// decoded; in both cases the boolean result is false.
func IsConnectionAuthenticated(ctx context.Context, conn *mongo.Client) (bool, error) {
	res := conn.Database("admin").RunCommand(
		ctx,
		// keyed fields keep the literal clean under go vet's composites check
		bson.D{{Key: "connectionStatus", Value: 1}},
	)
	if err := res.Err(); err != nil {
		return false, errors.Wrap(err, "failed to query for connection information")
	}

	// only the part of the reply needed for the check is decoded
	var body struct {
		AuthInfo struct {
			AuthenticatedUsers []bson.D `bson:"authenticatedUsers"`
		} `bson:"authInfo"`
	}
	if err := res.Decode(&body); err != nil {
		// The raw reply is attached purely as diagnostic context for the
		// decode failure, so an error from retrieving it is deliberately
		// ignored.
		raw, _ := res.Raw()
		return false, errors.Wrapf(err, "failed to decode connection information (%+v)", raw)
	}

	return len(body.AuthInfo.AuthenticatedUsers) > 0, nil
}