From 3d629c93c5185be2bf824d4a04746acf413900a1 Mon Sep 17 00:00:00 2001 From: Laura Date: Fri, 29 Aug 2025 19:26:55 +0200 Subject: [PATCH] harden response streaming --- chat.go | 6 ++--- stream.go | 72 ++++++++++++++++++++++++++++++++++++++----------------- 2 files changed, 53 insertions(+), 25 deletions(-) diff --git a/chat.go b/chat.go index 601418d..6dca2f1 100644 --- a/chat.go +++ b/chat.go @@ -256,7 +256,9 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { debug("preparing stream") - response, err := NewStream(w) + ctx := r.Context() + + response, err := NewStream(w, ctx) if err != nil { RespondJson(w, http.StatusBadRequest, map[string]any{ "error": err.Error(), @@ -267,8 +269,6 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { debug("handling request") - ctx := r.Context() - for iteration := range raw.Iterations { debug("iteration %d of %d", iteration+1, raw.Iterations) diff --git a/stream.go b/stream.go index a50181c..8338fc5 100644 --- a/stream.go +++ b/stream.go @@ -1,9 +1,12 @@ package main import ( + "bytes" + "context" "encoding/json" "errors" "net/http" + "sync" "github.com/revrost/go-openrouter" ) @@ -14,42 +17,31 @@ type Chunk struct { } type Stream struct { - wr http.ResponseWriter - fl http.Flusher - en *json.Encoder + wr http.ResponseWriter + ctx context.Context } -func NewStream(w http.ResponseWriter) (*Stream, error) { - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("failed to create flusher") - } +var pool = sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, +} +func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") return &Stream{ - wr: w, - fl: flusher, - en: json.NewEncoder(w), + wr: w, + ctx: ctx, }, nil } func (s *Stream) Send(ch Chunk) error { debugIf(ch.Type == "error", "error: %v", ch.Text) - if err := s.en.Encode(ch); err != nil { - return err - } - - if _, err := s.wr.Write([]byte("\n\n")); err != nil { - return err - } - - s.fl.Flush() - - return nil + return WriteChunk(s.wr, s.ctx, ch) } func ReasoningChunk(text string) Chunk { @@ -94,3 +86,39 @@ func GetErrorMessage(err error) string { return err.Error() } + +func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error { + if err := ctx.Err(); err != nil { + return err + } + + buf := pool.Get().(*bytes.Buffer) + + buf.Reset() + + defer pool.Put(buf) + + if err := json.NewEncoder(buf).Encode(chunk); err != nil { + return err + } + + buf.Write([]byte("\n\n")) + + if _, err := w.Write(buf.Bytes()); err != nil { + return err + } + + flusher, ok := w.(http.Flusher) + if !ok { + return errors.New("failed to create flusher") + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + flusher.Flush() + + return nil + } +}