1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-09-08 17:06:42 +00:00

harden response streaming

This commit is contained in:
2025-08-29 19:26:55 +02:00
parent 58aa250abe
commit 3d629c93c5
2 changed files with 53 additions and 25 deletions

View File

@@ -256,7 +256,9 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("preparing stream") debug("preparing stream")
response, err := NewStream(w) ctx := r.Context()
response, err := NewStream(w, ctx)
if err != nil { if err != nil {
RespondJson(w, http.StatusBadRequest, map[string]any{ RespondJson(w, http.StatusBadRequest, map[string]any{
"error": err.Error(), "error": err.Error(),
@@ -267,8 +269,6 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("handling request") debug("handling request")
ctx := r.Context()
for iteration := range raw.Iterations { for iteration := range raw.Iterations {
debug("iteration %d of %d", iteration+1, raw.Iterations) debug("iteration %d of %d", iteration+1, raw.Iterations)

View File

@@ -1,9 +1,12 @@
package main package main
import ( import (
"bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net/http" "net/http"
"sync"
"github.com/revrost/go-openrouter" "github.com/revrost/go-openrouter"
) )
@@ -14,42 +17,31 @@ type Chunk struct {
} }
type Stream struct { type Stream struct {
wr http.ResponseWriter wr http.ResponseWriter
fl http.Flusher ctx context.Context
en *json.Encoder
} }
func NewStream(w http.ResponseWriter) (*Stream, error) { var pool = sync.Pool{
flusher, ok := w.(http.Flusher) New: func() interface{} {
if !ok { return &bytes.Buffer{}
return nil, errors.New("failed to create flusher") },
} }
func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive") w.Header().Set("Connection", "keep-alive")
return &Stream{ return &Stream{
wr: w, wr: w,
fl: flusher, ctx: ctx,
en: json.NewEncoder(w),
}, nil }, nil
} }
func (s *Stream) Send(ch Chunk) error { func (s *Stream) Send(ch Chunk) error {
debugIf(ch.Type == "error", "error: %v", ch.Text) debugIf(ch.Type == "error", "error: %v", ch.Text)
if err := s.en.Encode(ch); err != nil { return WriteChunk(s.wr, s.ctx, ch)
return err
}
if _, err := s.wr.Write([]byte("\n\n")); err != nil {
return err
}
s.fl.Flush()
return nil
} }
func ReasoningChunk(text string) Chunk { func ReasoningChunk(text string) Chunk {
@@ -94,3 +86,39 @@ func GetErrorMessage(err error) string {
return err.Error() return err.Error()
} }
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
if err := ctx.Err(); err != nil {
return err
}
buf := pool.Get().(*bytes.Buffer)
buf.Reset()
defer pool.Put(buf)
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
return err
}
buf.Write([]byte("\n\n"))
if _, err := w.Write(buf.Bytes()); err != nil {
return err
}
flusher, ok := w.(http.Flusher)
if !ok {
return errors.New("failed to create flusher")
}
select {
case <-ctx.Done():
return ctx.Err()
default:
flusher.Flush()
return nil
}
}