2
0

replace dsn parser with simple parser, rather than regex

This commit is contained in:
Joshua Barone
2019-09-12 10:13:13 -05:00
parent 856c67a8c8
commit 2d9d8dc52a
2 changed files with 149 additions and 18 deletions
+61 -18
View File
@@ -17,7 +17,6 @@ import (
"os/user" "os/user"
"path/filepath" "path/filepath"
"reflect" "reflect"
"regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -1062,7 +1061,7 @@ func ParseURI(uri string) (ConnConfig, error) {
return cp, nil return cp, nil
} }
var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`) var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
// ParseDSN parses a database DSN (data source name) into a ConnConfig // ParseDSN parses a database DSN (data source name) into a ConnConfig
// //
@@ -1078,35 +1077,79 @@ var dsnRegexp = regexp.MustCompile(`([a-zA-Z_]+)=((?:"[^"]+")|(?:[^ ]+))`)
func ParseDSN(s string) (ConnConfig, error) { func ParseDSN(s string) (ConnConfig, error) {
var cp ConnConfig var cp ConnConfig
m := dsnRegexp.FindAllStringSubmatch(s, -1)
tlsArgs := configTLSArgs{} tlsArgs := configTLSArgs{}
cp.RuntimeParams = make(map[string]string) cp.RuntimeParams = make(map[string]string)
var hostval, portval string var hostval, portval string
for _, b := range m { for len(s) > 0 {
switch b[1] { var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return cp, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return cp, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
switch key {
case "user": case "user":
cp.User = b[2] cp.User = val
case "password": case "password":
cp.Password = b[2] cp.Password = val
case "host": case "host":
hostval = b[2] hostval = val
case "port": case "port":
portval = b[2] portval = val
case "dbname": case "dbname":
cp.Database = b[2] cp.Database = val
case "sslmode": case "sslmode":
tlsArgs.sslMode = b[2] tlsArgs.sslMode = val
case "sslrootcert": case "sslrootcert":
tlsArgs.sslRootCert = b[2] tlsArgs.sslRootCert = val
case "sslcert": case "sslcert":
tlsArgs.sslCert = b[2] tlsArgs.sslCert = val
case "sslkey": case "sslkey":
tlsArgs.sslKey = b[2] tlsArgs.sslKey = val
case "connect_timeout": case "connect_timeout":
timeout, err := strconv.ParseInt(b[2], 10, 64) timeout, err := strconv.ParseInt(val, 10, 64)
if err != nil { if err != nil {
return cp, err return cp, err
} }
@@ -1114,12 +1157,12 @@ func ParseDSN(s string) (ConnConfig, error) {
d.Timeout = time.Duration(timeout) * time.Second d.Timeout = time.Duration(timeout) * time.Second
cp.Dial = d.Dial cp.Dial = d.Dial
case "target_session_attrs": case "target_session_attrs":
cp.TargetSessionAttrs = TargetSessionType(b[2]) cp.TargetSessionAttrs = TargetSessionType(val)
if err := cp.TargetSessionAttrs.isValid(); err != nil { if err := cp.TargetSessionAttrs.isValid(); err != nil {
return cp, err return cp, err
} }
default: default:
cp.RuntimeParams[b[1]] = b[2] cp.RuntimeParams[key] = val
} }
} }
+88
View File
@@ -717,6 +717,38 @@ func TestParseDSN(t *testing.T) {
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{
url: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb",
connParams: pgx.ConnConfig{
User: "jack's",
Password: "secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb",
connParams: pgx.ConnConfig{
User: "jack",
Password: "sooper\\secret",
Host: "localhost",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{ {
url: "user=jack host=localhost port=5432 dbname=mydb", url: "user=jack host=localhost port=5432 dbname=mydb",
connParams: pgx.ConnConfig{ connParams: pgx.ConnConfig{
@@ -822,6 +854,62 @@ func TestParseDSN(t *testing.T) {
TargetSessionAttrs: pgx.ReadWriteTargetSession, TargetSessionAttrs: pgx.ReadWriteTargetSession,
}, },
}, },
{
url: "user='jack' host='localhost' dbname='mydb'",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user='jack\\'s' host='localhost' dbname='mydb'",
connParams: pgx.ConnConfig{
User: "jack's",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user='jack' password='' host='localhost' dbname='mydb'",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
{
url: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb'",
connParams: pgx.ConnConfig{
User: "jack",
Host: "localhost",
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
UseFallbackTLS: true,
FallbackTLSConfig: nil,
RuntimeParams: map[string]string{},
},
},
} }
for i, tt := range tests { for i, tt := range tests {