mirror of
https://github.com/coalaura/whiskr.git
synced 2025-09-08 17:06:42 +00:00
harden response streaming
This commit is contained in:
6
chat.go
6
chat.go
@@ -256,7 +256,9 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
debug("preparing stream")
|
||||
|
||||
response, err := NewStream(w)
|
||||
ctx := r.Context()
|
||||
|
||||
response, err := NewStream(w, ctx)
|
||||
if err != nil {
|
||||
RespondJson(w, http.StatusBadRequest, map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -267,8 +269,6 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
debug("handling request")
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
for iteration := range raw.Iterations {
|
||||
debug("iteration %d of %d", iteration+1, raw.Iterations)
|
||||
|
||||
|
66
stream.go
66
stream.go
@@ -1,9 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/revrost/go-openrouter"
|
||||
)
|
||||
@@ -15,41 +18,30 @@ type Chunk struct {
|
||||
|
||||
type Stream struct {
|
||||
wr http.ResponseWriter
|
||||
fl http.Flusher
|
||||
en *json.Encoder
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewStream(w http.ResponseWriter) (*Stream, error) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to create flusher")
|
||||
var pool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
}
|
||||
|
||||
func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
return &Stream{
|
||||
wr: w,
|
||||
fl: flusher,
|
||||
en: json.NewEncoder(w),
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Stream) Send(ch Chunk) error {
|
||||
debugIf(ch.Type == "error", "error: %v", ch.Text)
|
||||
|
||||
if err := s.en.Encode(ch); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := s.wr.Write([]byte("\n\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.fl.Flush()
|
||||
|
||||
return nil
|
||||
return WriteChunk(s.wr, s.ctx, ch)
|
||||
}
|
||||
|
||||
func ReasoningChunk(text string) Chunk {
|
||||
@@ -94,3 +86,39 @@ func GetErrorMessage(err error) string {
|
||||
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := pool.Get().(*bytes.Buffer)
|
||||
|
||||
buf.Reset()
|
||||
|
||||
defer pool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.Write([]byte("\n\n"))
|
||||
|
||||
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return errors.New("failed to create flusher")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
flusher.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user