1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00

improved protocol

This commit is contained in:
Laura
2025-09-12 21:39:17 +02:00
parent 05007ac4fe
commit f29707365f
7 changed files with 122 additions and 121 deletions

26
chat.go
View File

@@ -283,7 +283,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
for iteration := range raw.Iterations { for iteration := range raw.Iterations {
debug("iteration %d of %d", iteration+1, raw.Iterations) debug("iteration %d of %d", iteration+1, raw.Iterations)
response.Send(StartChunk()) response.WriteChunk(NewChunk(ChunkStart, nil))
if len(request.Tools) > 0 && iteration == raw.Iterations-1 { if len(request.Tools) > 0 && iteration == raw.Iterations-1 {
debug("no more tool calls") debug("no more tool calls")
@@ -298,7 +298,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
tool, message, err := RunCompletion(ctx, response, request) tool, message, err := RunCompletion(ctx, response, request)
if err != nil { if err != nil {
response.Send(ErrorChunk(err)) response.WriteChunk(NewChunk(ChunkError, err))
return return
} }
@@ -311,27 +311,27 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("got %q tool call", tool.Name) debug("got %q tool call", tool.Name)
response.Send(ToolChunk(tool)) response.WriteChunk(NewChunk(ChunkTool, tool))
switch tool.Name { switch tool.Name {
case "search_web": case "search_web":
err = HandleSearchWebTool(ctx, tool) err = HandleSearchWebTool(ctx, tool)
if err != nil { if err != nil {
response.Send(ErrorChunk(err)) response.WriteChunk(NewChunk(ChunkError, err))
return return
} }
case "fetch_contents": case "fetch_contents":
err = HandleFetchContentsTool(ctx, tool) err = HandleFetchContentsTool(ctx, tool)
if err != nil { if err != nil {
response.Send(ErrorChunk(err)) response.WriteChunk(NewChunk(ChunkError, err))
return return
} }
case "github_repository": case "github_repository":
err = HandleGitHubRepositoryTool(ctx, tool) err = HandleGitHubRepositoryTool(ctx, tool)
if err != nil { if err != nil {
response.Send(ErrorChunk(err)) response.WriteChunk(NewChunk(ChunkError, err))
return return
} }
@@ -344,14 +344,14 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("finished tool call") debug("finished tool call")
response.Send(ToolChunk(tool)) response.WriteChunk(NewChunk(ChunkTool, tool))
request.Messages = append(request.Messages, request.Messages = append(request.Messages,
tool.AsAssistantToolCall(message), tool.AsAssistantToolCall(message),
tool.AsToolMessage(), tool.AsToolMessage(),
) )
response.Send(EndChunk()) response.WriteChunk(NewChunk(ChunkEnd, nil))
} }
} }
@@ -386,7 +386,7 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
if id == "" { if id == "" {
id = chunk.ID id = chunk.ID
response.Send(IDChunk(id)) response.WriteChunk(NewChunk(ChunkID, id))
} }
if len(chunk.Choices) == 0 { if len(chunk.Choices) == 0 {
@@ -397,7 +397,7 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
delta := choice.Delta delta := choice.Delta
if choice.FinishReason == openrouter.FinishReasonContentFilter { if choice.FinishReason == openrouter.FinishReasonContentFilter {
response.Send(ErrorChunk(errors.New("stopped due to content_filter"))) response.WriteChunk(NewChunk(ChunkError, errors.New("stopped due to content_filter")))
return nil, "", nil return nil, "", nil
} }
@@ -434,16 +434,16 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
if delta.Content != "" { if delta.Content != "" {
buf.WriteString(delta.Content) buf.WriteString(delta.Content)
response.Send(TextChunk(delta.Content)) response.WriteChunk(NewChunk(ChunkText, delta.Content))
} else if delta.Reasoning != nil { } else if delta.Reasoning != nil {
response.Send(ReasoningChunk(*delta.Reasoning)) response.WriteChunk(NewChunk(ChunkReasoning, *delta.Reasoning))
} else if len(delta.Images) > 0 { } else if len(delta.Images) > 0 {
for _, image := range delta.Images { for _, image := range delta.Images {
if image.Type != openrouter.StreamImageTypeImageURL { if image.Type != openrouter.StreamImageTypeImageURL {
continue continue
} }
response.Send(ImageChunk(image.ImageURL.URL)) response.WriteChunk(NewChunk(ChunkImage, image.ImageURL.URL))
} }
} }
} }

2
go.mod
View File

@@ -7,6 +7,7 @@ require (
github.com/go-chi/chi/v5 v5.2.3 github.com/go-chi/chi/v5 v5.2.3
github.com/goccy/go-yaml v1.18.0 github.com/goccy/go-yaml v1.18.0
github.com/revrost/go-openrouter v0.2.4 github.com/revrost/go-openrouter v0.2.4
github.com/vmihailenco/msgpack/v5 v5.4.1
golang.org/x/crypto v0.42.0 golang.org/x/crypto v0.42.0
) )
@@ -16,6 +17,7 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/rs/zerolog v1.34.0 // indirect github.com/rs/zerolog v1.34.0 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.36.0 // indirect golang.org/x/sys v0.36.0 // indirect
golang.org/x/term v0.35.0 // indirect golang.org/x/term v0.35.0 // indirect
) )

4
go.sum
View File

@@ -29,6 +29,10 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -125,8 +125,9 @@
</div> </div>
</div> </div>
<script src="lib/highlight.min.js"></script> <script src="lib/msgpackr.min.js"></script>
<script src="lib/marked.min.js"></script> <script src="lib/marked.min.js"></script>
<script src="lib/highlight.min.js"></script>
<script src="lib/morphdom.min.js"></script> <script src="lib/morphdom.min.js"></script>
<script src="js/lib.js"></script> <script src="js/lib.js"></script>
<script src="js/storage.js"></script> <script src="js/storage.js"></script>

View File

@@ -1,4 +1,15 @@
(() => { (() => {
// Lookup table from the numeric type byte that starts every binary stream
// frame to the chunk type name the chunk handlers switch on. A byte that is
// not listed here yields `undefined`, which the stream reader treats as a
// bad chunk type.
const ChunkType = {
0: "start",
1: "id",
2: "reason",
3: "text",
4: "image",
5: "tool",
6: "error",
7: "end",
};
const $version = document.getElementById("version"), const $version = document.getElementById("version"),
$total = document.getElementById("total"), $total = document.getElementById("total"),
$title = document.getElementById("title"), $title = document.getElementById("title"),
@@ -975,50 +986,59 @@
throw new Error(err?.error || response.statusText); throw new Error(err?.error || response.statusText);
} }
const reader = response.body.getReader(), const reader = response.body.getReader();
decoder = new TextDecoder();
let buffer = ""; let buffer = new Uint8Array();
while (true) { while (true) {
const { value, done } = await reader.read(); const { value, done } = await reader.read();
if (done) break; if (done) {
break;
}
buffer += decoder.decode(value, { const read = new Uint8Array(buffer.length + value.length);
stream: true,
});
while (true) { read.set(buffer);
const idx = buffer.indexOf("\n\n"); read.set(value, buffer.length);
if (idx === -1) { buffer = read;
break;
}
const frame = buffer.slice(0, idx).trim(); while (buffer.length >= 5) {
buffer = buffer.slice(idx + 2); const type = ChunkType[buffer[0]],
length = buffer[1] | (buffer[2] << 8) | (buffer[3] << 16) | (buffer[4] << 24);
if (!type) {
console.warn("bad chunk type", type);
buffer = buffer.slice(5 + length);
if (!frame) {
continue; continue;
} }
let chunk; if (buffer.length < 5 + length) {
break;
}
try { let data;
chunk = JSON.parse(frame);
if (!chunk) { if (length > 0) {
throw new Error("invalid chunk"); const packed = buffer.slice(5, 5 + length);
try {
data = msgpackr.unpack(packed);
} catch (err) {
console.warn("bad chunk data", packed);
console.warn(err);
} }
} catch (err) {
console.warn("bad frame", frame);
console.warn(err);
} }
if (chunk) { buffer = buffer.slice(5 + length);
callback(chunk);
} callback({
type: type,
data: data,
});
} }
} }
} catch (err) { } catch (err) {
@@ -1030,7 +1050,7 @@
callback({ callback({
type: "error", type: "error",
text: err.message, data: err.message,
}); });
} finally { } finally {
callback(aborted ? "aborted" : "done"); callback(aborted ? "aborted" : "done");
@@ -1201,36 +1221,36 @@
break; break;
case "id": case "id":
generationID = chunk.text; generationID = chunk.data;
break; break;
case "tool": case "tool":
message.setState("tooling"); message.setState("tooling");
message.setTool(chunk.text); message.setTool(chunk.data);
if (chunk.text.done) { if (chunk.data?.done) {
totalCost += chunk.text.cost || 0; totalCost += chunk.data.cost || 0;
finish(); finish();
} }
break; break;
case "image": case "image":
message.addImage(chunk.text); message.addImage(chunk.data);
break; break;
case "reason": case "reason":
message.setState("reasoning"); message.setState("reasoning");
message.addReasoning(chunk.text); message.addReasoning(chunk.data);
break; break;
case "text": case "text":
message.setState("receiving"); message.setState("receiving");
message.addText(chunk.text); message.addText(chunk.data);
break; break;
case "error": case "error":
message.setError(chunk.text); message.setError(chunk.data);
break; break;
} }

2
static/lib/msgpackr.min.js vendored Normal file

File diff suppressed because one or more lines are too long

116
stream.go
View File

@@ -3,17 +3,31 @@ package main
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/binary"
"errors" "errors"
"net/http" "net/http"
"sync" "sync"
"github.com/revrost/go-openrouter" "github.com/revrost/go-openrouter"
"github.com/vmihailenco/msgpack/v5"
) )
// ChunkType identifies the kind of frame written to the response stream.
// It occupies a single byte.
type ChunkType uint8

// Chunk type identifiers. The numeric value of each constant is what gets
// written to the stream, so entries must keep their position in this list;
// add new types at the end only.
const (
	ChunkStart     ChunkType = iota // 0
	ChunkID                         // 1
	ChunkReasoning                  // 2
	ChunkText                       // 3
	ChunkImage                      // 4
	ChunkTool                       // 5
	ChunkError                      // 6
	ChunkEnd                        // 7
)
type Chunk struct { type Chunk struct {
Type string `json:"type"` Type ChunkType
Text any `json:"text,omitempty"` Data any
} }
type Stream struct { type Stream struct {
@@ -46,63 +60,10 @@ func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
}, nil }, nil
} }
func (s *Stream) Send(ch Chunk) error { func NewChunk(typ ChunkType, data any) *Chunk {
debugIf(ch.Type == "error", "error: %v", ch.Text) return &Chunk{
Type: typ,
return WriteChunk(s.wr, s.ctx, ch) Data: data,
}
func StartChunk() Chunk {
return Chunk{
Type: "start",
}
}
func IDChunk(id string) Chunk {
return Chunk{
Type: "id",
Text: id,
}
}
func ReasoningChunk(text string) Chunk {
return Chunk{
Type: "reason",
Text: text,
}
}
func TextChunk(text string) Chunk {
return Chunk{
Type: "text",
Text: CleanChunk(text),
}
}
func ImageChunk(image string) Chunk {
return Chunk{
Type: "image",
Text: image,
}
}
func ToolChunk(tool *ToolCall) Chunk {
return Chunk{
Type: "tool",
Text: tool,
}
}
func ErrorChunk(err error) Chunk {
return Chunk{
Type: "error",
Text: GetErrorMessage(err),
}
}
func EndChunk() Chunk {
return Chunk{
Type: "end",
} }
} }
@@ -114,32 +75,43 @@ func GetErrorMessage(err error) string {
return err.Error() return err.Error()
} }
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error { func (s *Stream) WriteChunk(chunk *Chunk) error {
if err := ctx.Err(); err != nil { debugIf(chunk.Type == ChunkError, "error: %v", chunk.Data)
if err := s.ctx.Err(); err != nil {
return err return err
} }
buf := GetFreeBuffer() buf := GetFreeBuffer()
defer pool.Put(buf) defer pool.Put(buf)
if err := json.NewEncoder(buf).Encode(chunk); err != nil { binary.Write(buf, binary.LittleEndian, chunk.Type)
if chunk.Data != nil {
data, err := msgpack.Marshal(chunk.Data)
if err != nil {
return err
}
binary.Write(buf, binary.LittleEndian, uint32(len(data)))
buf.Write(data)
} else {
binary.Write(buf, binary.LittleEndian, uint32(0))
}
if _, err := s.wr.Write(buf.Bytes()); err != nil {
return err return err
} }
buf.Write([]byte("\n\n")) flusher, ok := s.wr.(http.Flusher)
if _, err := w.Write(buf.Bytes()); err != nil {
return err
}
flusher, ok := w.(http.Flusher)
if !ok { if !ok {
return errors.New("failed to create flusher") return errors.New("failed to create flusher")
} }
select { select {
case <-ctx.Done(): case <-s.ctx.Done():
return ctx.Err() return s.ctx.Err()
default: default:
flusher.Flush() flusher.Flush()