From b2ca5d8f521597a28e8dc0703b9b2a8c72d9866a Mon Sep 17 00:00:00 2001 From: Artemiy Ryabinkov Date: Fri, 13 Sep 2019 17:26:09 +0300 Subject: [PATCH] validate all addresses resolved from hostname Signed-off-by: Artemiy Ryabinkov --- config.go | 9 ++++++++- pgconn.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/config.go b/config.go index 2ec6ae3f..57e65e13 100644 --- a/config.go +++ b/config.go @@ -37,6 +37,7 @@ type Config struct { Password string TLSConfig *tls.Config // nil disables TLS DialFunc DialFunc // e.g. net.Dialer.DialContext + LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) @@ -77,7 +78,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10) } else { network = "tcp" - address = fmt.Sprintf("%s:%d", host, port) + address = net.JoinHostPort(host, strconv.Itoa(int(port))) } return network, address } @@ -190,6 +191,8 @@ func ParseConfig(connString string) (*Config, error) { config.DialFunc = defaultDialer.DialContext } + config.LookupFunc = makeDefaultResolver().LookupHost + notRuntimeParams := map[string]struct{}{ "host": struct{}{}, "port": struct{}{}, @@ -495,6 +498,10 @@ func makeDefaultDialer() *net.Dialer { return &net.Dialer{KeepAlive: 5 * time.Minute} } +func makeDefaultResolver() *net.Resolver { + return net.DefaultResolver +} + func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc { return func(r io.Reader, w io.Writer) Frontend { cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen}) diff --git a/pgconn.go b/pgconn.go index 5c01d1dc..db2ebe73 100644 --- a/pgconn.go +++ b/pgconn.go @@ -43,6 +43,9 @@ type Notification struct { // DialFunc is a function that can be used to connect to a PostgreSQL server. type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +// LookupFunc is a function that can be used to lookup IPs addrs from host. +type LookupFunc func(ctx context.Context, host string) (addrs []string, err error) + // BuildFrontendFunc is a function that can be used to create Frontend implementation for connection. type BuildFrontendFunc func(r io.Reader, w io.Writer) Frontend @@ -123,6 +126,15 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err } fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) + if err != nil { + return nil, &connectError{config: config, msg: "hostname resolving error", err: err} + } + + if len(fallbackConfigs) == 0 { + return nil, &connectError{config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} + } + for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { @@ -147,6 +159,27 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err return pgConn, nil } +func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { + var configs []*FallbackConfig + + for _, fb := range fallbacks { + ips, err := lookupFn(ctx, fb.Host) + if err != nil { + return nil, err + } + + for _, ip := range ips { + configs = append(configs, &FallbackConfig{ + Host: ip, + Port: fb.Port, + TLSConfig: fb.TLSConfig, + }) + } + } + + return configs, nil +} + func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config