mirror of
https://github.com/coalaura/whiskr.git
synced 2025-09-08 17:06:42 +00:00
harden response streaming
This commit is contained in:
6
chat.go
6
chat.go
@@ -256,7 +256,9 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
debug("preparing stream")
|
debug("preparing stream")
|
||||||
|
|
||||||
response, err := NewStream(w)
|
ctx := r.Context()
|
||||||
|
|
||||||
|
response, err := NewStream(w, ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RespondJson(w, http.StatusBadRequest, map[string]any{
|
RespondJson(w, http.StatusBadRequest, map[string]any{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
@@ -267,8 +269,6 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
debug("handling request")
|
debug("handling request")
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
for iteration := range raw.Iterations {
|
for iteration := range raw.Iterations {
|
||||||
debug("iteration %d of %d", iteration+1, raw.Iterations)
|
debug("iteration %d of %d", iteration+1, raw.Iterations)
|
||||||
|
|
||||||
|
72
stream.go
72
stream.go
@@ -1,9 +1,12 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/revrost/go-openrouter"
|
"github.com/revrost/go-openrouter"
|
||||||
)
|
)
|
||||||
@@ -14,42 +17,31 @@ type Chunk struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Stream struct {
|
type Stream struct {
|
||||||
wr http.ResponseWriter
|
wr http.ResponseWriter
|
||||||
fl http.Flusher
|
ctx context.Context
|
||||||
en *json.Encoder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStream(w http.ResponseWriter) (*Stream, error) {
|
var pool = sync.Pool{
|
||||||
flusher, ok := w.(http.Flusher)
|
New: func() interface{} {
|
||||||
if !ok {
|
return &bytes.Buffer{}
|
||||||
return nil, errors.New("failed to create flusher")
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
|
||||||
w.Header().Set("Content-Type", "text/event-stream")
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
w.Header().Set("Cache-Control", "no-cache")
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
w.Header().Set("Connection", "keep-alive")
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
|
||||||
return &Stream{
|
return &Stream{
|
||||||
wr: w,
|
wr: w,
|
||||||
fl: flusher,
|
ctx: ctx,
|
||||||
en: json.NewEncoder(w),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stream) Send(ch Chunk) error {
|
func (s *Stream) Send(ch Chunk) error {
|
||||||
debugIf(ch.Type == "error", "error: %v", ch.Text)
|
debugIf(ch.Type == "error", "error: %v", ch.Text)
|
||||||
|
|
||||||
if err := s.en.Encode(ch); err != nil {
|
return WriteChunk(s.wr, s.ctx, ch)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := s.wr.Write([]byte("\n\n")); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.fl.Flush()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReasoningChunk(text string) Chunk {
|
func ReasoningChunk(text string) Chunk {
|
||||||
@@ -94,3 +86,39 @@ func GetErrorMessage(err error) string {
|
|||||||
|
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := pool.Get().(*bytes.Buffer)
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
|
||||||
|
defer pool.Put(buf)
|
||||||
|
|
||||||
|
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Write([]byte("\n\n"))
|
||||||
|
|
||||||
|
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("failed to create flusher")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
flusher.Flush()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user