mirror of
https://github.com/coalaura/whiskr.git
synced 2025-09-09 09:19:54 +00:00
harden response streaming
This commit is contained in:
72
stream.go
72
stream.go
@@ -1,9 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/revrost/go-openrouter"
|
||||
)
|
||||
@@ -14,42 +17,31 @@ type Chunk struct {
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
wr http.ResponseWriter
|
||||
fl http.Flusher
|
||||
en *json.Encoder
|
||||
wr http.ResponseWriter
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewStream(w http.ResponseWriter) (*Stream, error) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to create flusher")
|
||||
}
|
||||
var pool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
}
|
||||
|
||||
func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
|
||||
return &Stream{
|
||||
wr: w,
|
||||
fl: flusher,
|
||||
en: json.NewEncoder(w),
|
||||
wr: w,
|
||||
ctx: ctx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Stream) Send(ch Chunk) error {
|
||||
debugIf(ch.Type == "error", "error: %v", ch.Text)
|
||||
|
||||
if err := s.en.Encode(ch); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := s.wr.Write([]byte("\n\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.fl.Flush()
|
||||
|
||||
return nil
|
||||
return WriteChunk(s.wr, s.ctx, ch)
|
||||
}
|
||||
|
||||
func ReasoningChunk(text string) Chunk {
|
||||
@@ -94,3 +86,39 @@ func GetErrorMessage(err error) string {
|
||||
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := pool.Get().(*bytes.Buffer)
|
||||
|
||||
buf.Reset()
|
||||
|
||||
defer pool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.Write([]byte("\n\n"))
|
||||
|
||||
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return errors.New("failed to create flusher")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
flusher.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user