1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00

improved protocol

This commit is contained in:
Laura
2025-09-12 21:39:17 +02:00
parent 05007ac4fe
commit f29707365f
7 changed files with 122 additions and 121 deletions

116
stream.go
View File

@@ -3,17 +3,31 @@ package main
import (
"bytes"
"context"
"encoding/json"
"encoding/binary"
"errors"
"net/http"
"sync"
"github.com/revrost/go-openrouter"
"github.com/vmihailenco/msgpack/v5"
)
const (
ChunkStart ChunkType = 0
ChunkID ChunkType = 1
ChunkReasoning ChunkType = 2
ChunkText ChunkType = 3
ChunkImage ChunkType = 4
ChunkTool ChunkType = 5
ChunkError ChunkType = 6
ChunkEnd ChunkType = 7
)
type ChunkType uint8
type Chunk struct {
Type string `json:"type"`
Text any `json:"text,omitempty"`
Type ChunkType
Data any
}
type Stream struct {
@@ -46,63 +60,10 @@ func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
}, nil
}
func (s *Stream) Send(ch Chunk) error {
debugIf(ch.Type == "error", "error: %v", ch.Text)
return WriteChunk(s.wr, s.ctx, ch)
}
func StartChunk() Chunk {
return Chunk{
Type: "start",
}
}
func IDChunk(id string) Chunk {
return Chunk{
Type: "id",
Text: id,
}
}
func ReasoningChunk(text string) Chunk {
return Chunk{
Type: "reason",
Text: text,
}
}
func TextChunk(text string) Chunk {
return Chunk{
Type: "text",
Text: CleanChunk(text),
}
}
func ImageChunk(image string) Chunk {
return Chunk{
Type: "image",
Text: image,
}
}
func ToolChunk(tool *ToolCall) Chunk {
return Chunk{
Type: "tool",
Text: tool,
}
}
func ErrorChunk(err error) Chunk {
return Chunk{
Type: "error",
Text: GetErrorMessage(err),
}
}
func EndChunk() Chunk {
return Chunk{
Type: "end",
func NewChunk(typ ChunkType, data any) *Chunk {
return &Chunk{
Type: typ,
Data: data,
}
}
@@ -114,32 +75,43 @@ func GetErrorMessage(err error) string {
return err.Error()
}
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
if err := ctx.Err(); err != nil {
func (s *Stream) WriteChunk(chunk *Chunk) error {
debugIf(chunk.Type == ChunkError, "error: %v", chunk.Data)
if err := s.ctx.Err(); err != nil {
return err
}
buf := GetFreeBuffer()
defer pool.Put(buf)
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
binary.Write(buf, binary.LittleEndian, chunk.Type)
if chunk.Data != nil {
data, err := msgpack.Marshal(chunk.Data)
if err != nil {
return err
}
binary.Write(buf, binary.LittleEndian, uint32(len(data)))
buf.Write(data)
} else {
binary.Write(buf, binary.LittleEndian, uint32(0))
}
if _, err := s.wr.Write(buf.Bytes()); err != nil {
return err
}
buf.Write([]byte("\n\n"))
if _, err := w.Write(buf.Bytes()); err != nil {
return err
}
flusher, ok := w.(http.Flusher)
flusher, ok := s.wr.(http.Flusher)
if !ok {
return errors.New("failed to create flusher")
}
select {
case <-ctx.Done():
return ctx.Err()
case <-s.ctx.Done():
return s.ctx.Err()
default:
flusher.Flush()