2
0

validate all addresses resolved from hostname

Signed-off-by: Artemiy Ryabinkov <getlag@ya.ru>
This commit is contained in:
Artemiy Ryabinkov
2019-09-13 17:26:09 +03:00
parent 51cf0d5480
commit b2ca5d8f52
2 changed files with 41 additions and 1 deletions
+8 -1
View File
@@ -37,6 +37,7 @@ type Config struct {
Password string
TLSConfig *tls.Config // nil disables TLS
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
@@ -77,7 +78,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
network = "tcp"
address = fmt.Sprintf("%s:%d", host, port)
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
return network, address
}
@@ -190,6 +191,8 @@ func ParseConfig(connString string) (*Config, error) {
config.DialFunc = defaultDialer.DialContext
}
config.LookupFunc = makeDefaultResolver().LookupHost
notRuntimeParams := map[string]struct{}{
"host": struct{}{},
"port": struct{}{},
@@ -495,6 +498,10 @@ func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultResolver() *net.Resolver {
return net.DefaultResolver
}
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
return func(r io.Reader, w io.Writer) Frontend {
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
+33
View File
@@ -43,6 +43,9 @@ type Notification struct {
// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)
// LookupFunc is a function that can be used to lookup IPs addrs from host.
type LookupFunc func(ctx context.Context, host string) (addrs []string, err error)
// BuildFrontendFunc is a function that can be used to create Frontend implementation for connection.
type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend
@@ -123,6 +126,15 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
}
fallbackConfigs = append(fallbackConfigs, config.Fallbacks...)
fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs)
if err != nil {
return nil, &connectError{config: config, msg: "hostname resolving error", err: err}
}
if len(fallbackConfigs) == 0 {
return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")}
}
for _, fc := range fallbackConfigs {
pgConn, err = connect(ctx, config, fc)
if err == nil {
@@ -147,6 +159,27 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
return pgConn, nil
}
func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) {
var configs []*FallbackConfig
for _, fb := range fallbacks {
ips, err := lookupFn(ctx, fb.Host)
if err != nil {
return nil, err
}
for _, ip := range ips {
configs = append(configs, &FallbackConfig{
Host: ip,
Port: fb.Port,
TLSConfig: fb.TLSConfig,
})
}
}
return configs, nil
}
func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) {
pgConn := new(PgConn)
pgConn.config = config