diff --git a/sse-encoder.go b/sse-encoder.go index 9ca9d7a..7afacd4 100644 --- a/sse-encoder.go +++ b/sse-encoder.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "reflect" + "strconv" "strings" ) @@ -19,6 +20,9 @@ import ( const ContentType = "text/event-stream" +var contentType = []string{ContentType} +var noCache = []string{"no-cache"} + type Event struct { Event string Id string @@ -26,63 +30,66 @@ type Event struct { Data interface{} } -func Encode(w io.Writer, event Event) error { +func Encode(writer io.Writer, event Event) error { + w := checkWriter(writer) writeId(w, event.Id) writeEvent(w, event.Event) writeRetry(w, event.Retry) return writeData(w, event.Data) } -func writeId(w io.Writer, id string) { +func writeId(w stringWriter, id string) { if len(id) > 0 { - w.Write([]byte("id: ")) - w.Write([]byte(escape(id))) - w.Write([]byte("\n")) + w.WriteString("id: ") + w.WriteString(escape(id)) + w.WriteString("\n") } } -func writeEvent(w io.Writer, event string) { +func writeEvent(w stringWriter, event string) { if len(event) > 0 { - w.Write([]byte("event: ")) - w.Write([]byte(escape(event))) - w.Write([]byte("\n")) + w.WriteString("event: ") + w.WriteString(escape(event)) + w.WriteString("\n") } } -func writeRetry(w io.Writer, retry uint) { +func writeRetry(w stringWriter, retry uint) { if retry > 0 { - fmt.Fprintf(w, "retry: %d\n", retry) + w.WriteString("retry: ") + w.WriteString(strconv.FormatUint(uint64(retry), 10)) + w.WriteString("\n") } } -func writeData(w io.Writer, data interface{}) error { - w.Write([]byte("data: ")) - switch typeOfData(data) { +func writeData(w stringWriter, data interface{}) error { + w.WriteString("data: ") + switch kindOfData(data) { case reflect.Struct, reflect.Slice, reflect.Map: err := json.NewEncoder(w).Encode(data) if err != nil { return err } - w.Write([]byte("\n")) + w.WriteString("\n") default: text := fmt.Sprint(data) - w.Write([]byte(escape(text))) - w.Write([]byte("\n\n")) + w.WriteString(escape(text)) + w.WriteString("\n\n") } return nil } func (r Event) Write(w http.ResponseWriter) error { header := w.Header() - header.Set("Content-Type", ContentType) + header["Content-Type"] = contentType if _, exist := header["Cache-Control"]; !exist { - header.Set("Cache-Control", "no-cache") + header["Cache-Control"] = noCache } return Encode(w, r) } -func typeOfData(data interface{}) reflect.Kind { +func kindOfData(data interface{}) reflect.Kind { value := reflect.ValueOf(data) valueType := value.Kind() if valueType == reflect.Ptr { diff --git a/writer.go b/writer.go new file mode 100644 index 0000000..6f9806c --- /dev/null +++ b/writer.go @@ -0,0 +1,24 @@ +package sse + +import "io" + +type stringWriter interface { + io.Writer + WriteString(string) (int, error) +} + +type stringWrapper struct { + io.Writer +} + +func (w stringWrapper) WriteString(str string) (int, error) { + return w.Writer.Write([]byte(str)) +} + +func checkWriter(writer io.Writer) stringWriter { + if w, ok := writer.(stringWriter); ok { + return w + } else { + return stringWrapper{writer} + } +}