1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00

improved protocol

This commit is contained in:
Laura
2025-09-12 21:39:17 +02:00
parent 05007ac4fe
commit f29707365f
7 changed files with 122 additions and 121 deletions

26
chat.go
View File

@@ -283,7 +283,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
for iteration := range raw.Iterations {
debug("iteration %d of %d", iteration+1, raw.Iterations)
response.Send(StartChunk())
response.WriteChunk(NewChunk(ChunkStart, nil))
if len(request.Tools) > 0 && iteration == raw.Iterations-1 {
debug("no more tool calls")
@@ -298,7 +298,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
tool, message, err := RunCompletion(ctx, response, request)
if err != nil {
response.Send(ErrorChunk(err))
response.WriteChunk(NewChunk(ChunkError, err))
return
}
@@ -311,27 +311,27 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("got %q tool call", tool.Name)
response.Send(ToolChunk(tool))
response.WriteChunk(NewChunk(ChunkTool, tool))
switch tool.Name {
case "search_web":
err = HandleSearchWebTool(ctx, tool)
if err != nil {
response.Send(ErrorChunk(err))
response.WriteChunk(NewChunk(ChunkError, err))
return
}
case "fetch_contents":
err = HandleFetchContentsTool(ctx, tool)
if err != nil {
response.Send(ErrorChunk(err))
response.WriteChunk(NewChunk(ChunkError, err))
return
}
case "github_repository":
err = HandleGitHubRepositoryTool(ctx, tool)
if err != nil {
response.Send(ErrorChunk(err))
response.WriteChunk(NewChunk(ChunkError, err))
return
}
@@ -344,14 +344,14 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("finished tool call")
response.Send(ToolChunk(tool))
response.WriteChunk(NewChunk(ChunkTool, tool))
request.Messages = append(request.Messages,
tool.AsAssistantToolCall(message),
tool.AsToolMessage(),
)
response.Send(EndChunk())
response.WriteChunk(NewChunk(ChunkEnd, nil))
}
}
@@ -386,7 +386,7 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
if id == "" {
id = chunk.ID
response.Send(IDChunk(id))
response.WriteChunk(NewChunk(ChunkID, id))
}
if len(chunk.Choices) == 0 {
@@ -397,7 +397,7 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
delta := choice.Delta
if choice.FinishReason == openrouter.FinishReasonContentFilter {
response.Send(ErrorChunk(errors.New("stopped due to content_filter")))
response.WriteChunk(NewChunk(ChunkError, errors.New("stopped due to content_filter")))
return nil, "", nil
}
@@ -434,16 +434,16 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
if delta.Content != "" {
buf.WriteString(delta.Content)
response.Send(TextChunk(delta.Content))
response.WriteChunk(NewChunk(ChunkText, delta.Content))
} else if delta.Reasoning != nil {
response.Send(ReasoningChunk(*delta.Reasoning))
response.WriteChunk(NewChunk(ChunkReasoning, *delta.Reasoning))
} else if len(delta.Images) > 0 {
for _, image := range delta.Images {
if image.Type != openrouter.StreamImageTypeImageURL {
continue
}
response.Send(ImageChunk(image.ImageURL.URL))
response.WriteChunk(NewChunk(ChunkImage, image.ImageURL.URL))
}
}
}

2
go.mod
View File

@@ -7,6 +7,7 @@ require (
github.com/go-chi/chi/v5 v5.2.3
github.com/goccy/go-yaml v1.18.0
github.com/revrost/go-openrouter v0.2.4
github.com/vmihailenco/msgpack/v5 v5.4.1
golang.org/x/crypto v0.42.0
)
@@ -16,6 +17,7 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/rs/zerolog v1.34.0 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.36.0 // indirect
golang.org/x/term v0.35.0 // indirect
)

4
go.sum
View File

@@ -29,6 +29,10 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -125,8 +125,9 @@
</div>
</div>
<script src="lib/highlight.min.js"></script>
<script src="lib/msgpackr.min.js"></script>
<script src="lib/marked.min.js"></script>
<script src="lib/highlight.min.js"></script>
<script src="lib/morphdom.min.js"></script>
<script src="js/lib.js"></script>
<script src="js/storage.js"></script>

View File

@@ -1,4 +1,15 @@
(() => {
const ChunkType = {
0: "start",
1: "id",
2: "reason",
3: "text",
4: "image",
5: "tool",
6: "error",
7: "end",
};
const $version = document.getElementById("version"),
$total = document.getElementById("total"),
$title = document.getElementById("title"),
@@ -975,50 +986,59 @@
throw new Error(err?.error || response.statusText);
}
const reader = response.body.getReader(),
decoder = new TextDecoder();
const reader = response.body.getReader();
let buffer = "";
let buffer = new Uint8Array();
while (true) {
const { value, done } = await reader.read();
if (done) break;
if (done) {
break;
}
buffer += decoder.decode(value, {
stream: true,
});
const read = new Uint8Array(buffer.length + value.length);
while (true) {
const idx = buffer.indexOf("\n\n");
read.set(buffer);
read.set(value, buffer.length);
if (idx === -1) {
break;
}
buffer = read;
const frame = buffer.slice(0, idx).trim();
buffer = buffer.slice(idx + 2);
while (buffer.length >= 5) {
const type = ChunkType[buffer[0]],
length = buffer[1] | (buffer[2] << 8) | (buffer[3] << 16) | (buffer[4] << 24);
if (!type) {
console.warn("bad chunk type", type);
buffer = buffer.slice(5 + length);
if (!frame) {
continue;
}
let chunk;
if (buffer.length < 5 + length) {
break;
}
try {
chunk = JSON.parse(frame);
let data;
if (!chunk) {
throw new Error("invalid chunk");
if (length > 0) {
const packed = buffer.slice(5, 5 + length);
try {
data = msgpackr.unpack(packed);
} catch (err) {
console.warn("bad chunk data", packed);
console.warn(err);
}
} catch (err) {
console.warn("bad frame", frame);
console.warn(err);
}
if (chunk) {
callback(chunk);
}
buffer = buffer.slice(5 + length);
callback({
type: type,
data: data,
});
}
}
} catch (err) {
@@ -1030,7 +1050,7 @@
callback({
type: "error",
text: err.message,
data: err.message,
});
} finally {
callback(aborted ? "aborted" : "done");
@@ -1201,36 +1221,36 @@
break;
case "id":
generationID = chunk.text;
generationID = chunk.data;
break;
case "tool":
message.setState("tooling");
message.setTool(chunk.text);
message.setTool(chunk.data);
if (chunk.text.done) {
totalCost += chunk.text.cost || 0;
if (chunk.data?.done) {
totalCost += chunk.data.cost || 0;
finish();
}
break;
case "image":
message.addImage(chunk.text);
message.addImage(chunk.data);
break;
case "reason":
message.setState("reasoning");
message.addReasoning(chunk.text);
message.addReasoning(chunk.data);
break;
case "text":
message.setState("receiving");
message.addText(chunk.text);
message.addText(chunk.data);
break;
case "error":
message.setError(chunk.text);
message.setError(chunk.data);
break;
}

2
static/lib/msgpackr.min.js vendored Normal file

File diff suppressed because one or more lines are too long

116
stream.go
View File

@@ -3,17 +3,31 @@ package main
import (
"bytes"
"context"
"encoding/json"
"encoding/binary"
"errors"
"net/http"
"sync"
"github.com/revrost/go-openrouter"
"github.com/vmihailenco/msgpack/v5"
)
const (
ChunkStart ChunkType = 0
ChunkID ChunkType = 1
ChunkReasoning ChunkType = 2
ChunkText ChunkType = 3
ChunkImage ChunkType = 4
ChunkTool ChunkType = 5
ChunkError ChunkType = 6
ChunkEnd ChunkType = 7
)
type ChunkType uint8
type Chunk struct {
Type string `json:"type"`
Text any `json:"text,omitempty"`
Type ChunkType
Data any
}
type Stream struct {
@@ -46,63 +60,10 @@ func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
}, nil
}
func (s *Stream) Send(ch Chunk) error {
debugIf(ch.Type == "error", "error: %v", ch.Text)
return WriteChunk(s.wr, s.ctx, ch)
}
func StartChunk() Chunk {
return Chunk{
Type: "start",
}
}
func IDChunk(id string) Chunk {
return Chunk{
Type: "id",
Text: id,
}
}
func ReasoningChunk(text string) Chunk {
return Chunk{
Type: "reason",
Text: text,
}
}
func TextChunk(text string) Chunk {
return Chunk{
Type: "text",
Text: CleanChunk(text),
}
}
func ImageChunk(image string) Chunk {
return Chunk{
Type: "image",
Text: image,
}
}
func ToolChunk(tool *ToolCall) Chunk {
return Chunk{
Type: "tool",
Text: tool,
}
}
func ErrorChunk(err error) Chunk {
return Chunk{
Type: "error",
Text: GetErrorMessage(err),
}
}
func EndChunk() Chunk {
return Chunk{
Type: "end",
func NewChunk(typ ChunkType, data any) *Chunk {
return &Chunk{
Type: typ,
Data: data,
}
}
@@ -114,32 +75,43 @@ func GetErrorMessage(err error) string {
return err.Error()
}
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
if err := ctx.Err(); err != nil {
func (s *Stream) WriteChunk(chunk *Chunk) error {
debugIf(chunk.Type == ChunkError, "error: %v", chunk.Data)
if err := s.ctx.Err(); err != nil {
return err
}
buf := GetFreeBuffer()
defer pool.Put(buf)
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
binary.Write(buf, binary.LittleEndian, chunk.Type)
if chunk.Data != nil {
data, err := msgpack.Marshal(chunk.Data)
if err != nil {
return err
}
binary.Write(buf, binary.LittleEndian, uint32(len(data)))
buf.Write(data)
} else {
binary.Write(buf, binary.LittleEndian, uint32(0))
}
if _, err := s.wr.Write(buf.Bytes()); err != nil {
return err
}
buf.Write([]byte("\n\n"))
if _, err := w.Write(buf.Bytes()); err != nil {
return err
}
flusher, ok := w.(http.Flusher)
flusher, ok := s.wr.(http.Flusher)
if !ok {
return errors.New("failed to create flusher")
}
select {
case <-ctx.Done():
return ctx.Err()
case <-s.ctx.Done():
return s.ctx.Err()
default:
flusher.Flush()