diff --git a/logger.go b/logger.go index 53bbf95..885f150 100644 --- a/logger.go +++ b/logger.go @@ -337,8 +337,11 @@ func (logger *Logger) AddHook(hook Hook) { logger.Hooks.Add(hook) } -func (logger *Logger) ReplaceHooks(hooks LevelHooks) { +// ReplaceHooks replaces the logger hooks and returns the old ones +func (logger *Logger) ReplaceHooks(hooks LevelHooks) LevelHooks { logger.mu.Lock() + oldHooks := logger.Hooks logger.Hooks = hooks logger.mu.Unlock() + return oldHooks } diff --git a/logrus_test.go b/logrus_test.go index 7a96686..f6db6e9 100644 --- a/logrus_test.go +++ b/logrus_test.go @@ -3,6 +3,7 @@ package logrus import ( "bytes" "encoding/json" + "io/ioutil" "strconv" "strings" "sync" @@ -421,20 +422,25 @@ func TestLoggingRaceWithHooksOnEntry(t *testing.T) { wg.Wait() } -func TestHooksReplace(t *testing.T) { +func TestReplaceHooks(t *testing.T) { old, cur := &TestHook{}, &TestHook{} logger := New() + logger.SetOutput(ioutil.Discard) logger.AddHook(old) hooks := make(LevelHooks) hooks.Add(cur) - logger.ReplaceHooks(hooks) + replaced := logger.ReplaceHooks(hooks) logger.Info("test") assert.Equal(t, old.Fired, false) assert.Equal(t, cur.Fired, true) + + logger.ReplaceHooks(replaced) + logger.Info("test") + assert.Equal(t, old.Fired, true) } // Compile test