diff --git a/config.go b/config.go index 299d4784..7ed99096 100644 --- a/config.go +++ b/config.go @@ -62,6 +62,35 @@ type Config struct { createdByParseConfig bool // Used to enforce created by ParseConfig rule. } +// Copy returns a deep copy of the config that is safe to use and modify. +// The only exception is the TLSConfig field: +// according to the tls.Config docs it must not be modified after creation. +func (c *Config) Copy() *Config { + newConf := new(Config) + *newConf = *c + if newConf.TLSConfig != nil { + newConf.TLSConfig = c.TLSConfig.Clone() + } + if newConf.RuntimeParams != nil { + newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) + for k, v := range c.RuntimeParams { + newConf.RuntimeParams[k] = v + } + } + if newConf.Fallbacks != nil { + newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) + for i, fallback := range c.Fallbacks { + newFallback := new(FallbackConfig) + *newFallback = *fallback + if newFallback.TLSConfig != nil { + newFallback.TLSConfig = fallback.TLSConfig.Clone() + } + newConf.Fallbacks[i] = newFallback + } + } + return newConf +} + // FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a // network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections. type FallbackConfig struct { diff --git a/config_test.go b/config_test.go index 515ea6d3..ebe627b1 100644 --- a/config_test.go +++ b/config_test.go @@ -1,6 +1,7 @@ package pgconn_test import ( + "context" "crypto/tls" "fmt" "io/ioutil" @@ -527,6 +528,44 @@ func TestParseConfig(t *testing.T) { } } +func TestConfigCopyReturnsEqualConfig(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") +} + +func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) { + connString := "postgres://jack:secret@localhost:5432/mydb?application_name=pgxtest&search_path=myschema&connect_timeout=5" + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assertConfigsEqual(t, original, copied, "Test Config.Copy() returns equal config") + + copied.Port = uint16(5433) + copied.RuntimeParams["foo"] = "bar" + copied.Fallbacks[0].Port = uint16(5433) + + assert.Equal(t, uint16(5432), original.Port) + assert.Equal(t, "", original.RuntimeParams["foo"]) + assert.Equal(t, uint16(5432), original.Fallbacks[0].Port) +} + +func TestConfigCopyCanBeUsedToConnect(t *testing.T) { + connString := os.Getenv("PGX_TEST_CONN_STRING") + original, err := pgconn.ParseConfig(connString) + require.NoError(t, err) + + copied := original.Copy() + assert.NotPanics(t, func() { + _, err = pgconn.ConnectConfig(context.Background(), copied) + }) + assert.NoError(t, err) +} + func assertConfigsEqual(t *testing.T, expected, actual *pgconn.Config, testName string) { if !assert.NotNil(t, expected) { return