mirror of
https://github.com/coalaura/whiskr.git
synced 2025-12-02 20:22:52 +00:00
improved protocol
This commit is contained in:
26
chat.go
26
chat.go
@@ -283,7 +283,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
for iteration := range raw.Iterations {
|
||||
debug("iteration %d of %d", iteration+1, raw.Iterations)
|
||||
|
||||
response.Send(StartChunk())
|
||||
response.WriteChunk(NewChunk(ChunkStart, nil))
|
||||
|
||||
if len(request.Tools) > 0 && iteration == raw.Iterations-1 {
|
||||
debug("no more tool calls")
|
||||
@@ -298,7 +298,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
tool, message, err := RunCompletion(ctx, response, request)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
@@ -311,27 +311,27 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
debug("got %q tool call", tool.Name)
|
||||
|
||||
response.Send(ToolChunk(tool))
|
||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||
|
||||
switch tool.Name {
|
||||
case "search_web":
|
||||
err = HandleSearchWebTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
case "fetch_contents":
|
||||
err = HandleFetchContentsTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
case "github_repository":
|
||||
err = HandleGitHubRepositoryTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
@@ -344,14 +344,14 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
debug("finished tool call")
|
||||
|
||||
response.Send(ToolChunk(tool))
|
||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||
|
||||
request.Messages = append(request.Messages,
|
||||
tool.AsAssistantToolCall(message),
|
||||
tool.AsToolMessage(),
|
||||
)
|
||||
|
||||
response.Send(EndChunk())
|
||||
response.WriteChunk(NewChunk(ChunkEnd, nil))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -386,7 +386,7 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
|
||||
if id == "" {
|
||||
id = chunk.ID
|
||||
|
||||
response.Send(IDChunk(id))
|
||||
response.WriteChunk(NewChunk(ChunkID, id))
|
||||
}
|
||||
|
||||
if len(chunk.Choices) == 0 {
|
||||
@@ -397,7 +397,7 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
|
||||
delta := choice.Delta
|
||||
|
||||
if choice.FinishReason == openrouter.FinishReasonContentFilter {
|
||||
response.Send(ErrorChunk(errors.New("stopped due to content_filter")))
|
||||
response.WriteChunk(NewChunk(ChunkError, errors.New("stopped due to content_filter")))
|
||||
|
||||
return nil, "", nil
|
||||
}
|
||||
@@ -434,16 +434,16 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
|
||||
if delta.Content != "" {
|
||||
buf.WriteString(delta.Content)
|
||||
|
||||
response.Send(TextChunk(delta.Content))
|
||||
response.WriteChunk(NewChunk(ChunkText, delta.Content))
|
||||
} else if delta.Reasoning != nil {
|
||||
response.Send(ReasoningChunk(*delta.Reasoning))
|
||||
response.WriteChunk(NewChunk(ChunkReasoning, *delta.Reasoning))
|
||||
} else if len(delta.Images) > 0 {
|
||||
for _, image := range delta.Images {
|
||||
if image.Type != openrouter.StreamImageTypeImageURL {
|
||||
continue
|
||||
}
|
||||
|
||||
response.Send(ImageChunk(image.ImageURL.URL))
|
||||
response.WriteChunk(NewChunk(ChunkImage, image.ImageURL.URL))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -7,6 +7,7 @@ require (
|
||||
github.com/go-chi/chi/v5 v5.2.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/revrost/go-openrouter v0.2.4
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1
|
||||
golang.org/x/crypto v0.42.0
|
||||
)
|
||||
|
||||
@@ -16,6 +17,7 @@ require (
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/rs/zerolog v1.34.0 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
golang.org/x/sys v0.36.0 // indirect
|
||||
golang.org/x/term v0.35.0 // indirect
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -29,6 +29,10 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
||||
@@ -125,8 +125,9 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script src="lib/highlight.min.js"></script>
|
||||
<script src="lib/msgpackr.min.js"></script>
|
||||
<script src="lib/marked.min.js"></script>
|
||||
<script src="lib/highlight.min.js"></script>
|
||||
<script src="lib/morphdom.min.js"></script>
|
||||
<script src="js/lib.js"></script>
|
||||
<script src="js/storage.js"></script>
|
||||
|
||||
@@ -1,4 +1,15 @@
|
||||
(() => {
|
||||
const ChunkType = {
|
||||
0: "start",
|
||||
1: "id",
|
||||
2: "reason",
|
||||
3: "text",
|
||||
4: "image",
|
||||
5: "tool",
|
||||
6: "error",
|
||||
7: "end",
|
||||
};
|
||||
|
||||
const $version = document.getElementById("version"),
|
||||
$total = document.getElementById("total"),
|
||||
$title = document.getElementById("title"),
|
||||
@@ -975,50 +986,59 @@
|
||||
throw new Error(err?.error || response.statusText);
|
||||
}
|
||||
|
||||
const reader = response.body.getReader(),
|
||||
decoder = new TextDecoder();
|
||||
const reader = response.body.getReader();
|
||||
|
||||
let buffer = "";
|
||||
let buffer = new Uint8Array();
|
||||
|
||||
while (true) {
|
||||
const { value, done } = await reader.read();
|
||||
|
||||
if (done) break;
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, {
|
||||
stream: true,
|
||||
});
|
||||
const read = new Uint8Array(buffer.length + value.length);
|
||||
|
||||
while (true) {
|
||||
const idx = buffer.indexOf("\n\n");
|
||||
read.set(buffer);
|
||||
read.set(value, buffer.length);
|
||||
|
||||
if (idx === -1) {
|
||||
break;
|
||||
}
|
||||
buffer = read;
|
||||
|
||||
const frame = buffer.slice(0, idx).trim();
|
||||
buffer = buffer.slice(idx + 2);
|
||||
while (buffer.length >= 5) {
|
||||
const type = ChunkType[buffer[0]],
|
||||
length = buffer[1] | (buffer[2] << 8) | (buffer[3] << 16) | (buffer[4] << 24);
|
||||
|
||||
if (!type) {
|
||||
console.warn("bad chunk type", type);
|
||||
|
||||
buffer = buffer.slice(5 + length);
|
||||
|
||||
if (!frame) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let chunk;
|
||||
if (buffer.length < 5 + length) {
|
||||
break;
|
||||
}
|
||||
|
||||
try {
|
||||
chunk = JSON.parse(frame);
|
||||
let data;
|
||||
|
||||
if (!chunk) {
|
||||
throw new Error("invalid chunk");
|
||||
if (length > 0) {
|
||||
const packed = buffer.slice(5, 5 + length);
|
||||
|
||||
try {
|
||||
data = msgpackr.unpack(packed);
|
||||
} catch (err) {
|
||||
console.warn("bad chunk data", packed);
|
||||
console.warn(err);
|
||||
}
|
||||
} catch (err) {
|
||||
console.warn("bad frame", frame);
|
||||
console.warn(err);
|
||||
}
|
||||
|
||||
if (chunk) {
|
||||
callback(chunk);
|
||||
}
|
||||
buffer = buffer.slice(5 + length);
|
||||
|
||||
callback({
|
||||
type: type,
|
||||
data: data,
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
@@ -1030,7 +1050,7 @@
|
||||
|
||||
callback({
|
||||
type: "error",
|
||||
text: err.message,
|
||||
data: err.message,
|
||||
});
|
||||
} finally {
|
||||
callback(aborted ? "aborted" : "done");
|
||||
@@ -1201,36 +1221,36 @@
|
||||
|
||||
break;
|
||||
case "id":
|
||||
generationID = chunk.text;
|
||||
generationID = chunk.data;
|
||||
|
||||
break;
|
||||
case "tool":
|
||||
message.setState("tooling");
|
||||
message.setTool(chunk.text);
|
||||
message.setTool(chunk.data);
|
||||
|
||||
if (chunk.text.done) {
|
||||
totalCost += chunk.text.cost || 0;
|
||||
if (chunk.data?.done) {
|
||||
totalCost += chunk.data.cost || 0;
|
||||
|
||||
finish();
|
||||
}
|
||||
|
||||
break;
|
||||
case "image":
|
||||
message.addImage(chunk.text);
|
||||
message.addImage(chunk.data);
|
||||
|
||||
break;
|
||||
case "reason":
|
||||
message.setState("reasoning");
|
||||
message.addReasoning(chunk.text);
|
||||
message.addReasoning(chunk.data);
|
||||
|
||||
break;
|
||||
case "text":
|
||||
message.setState("receiving");
|
||||
message.addText(chunk.text);
|
||||
message.addText(chunk.data);
|
||||
|
||||
break;
|
||||
case "error":
|
||||
message.setError(chunk.text);
|
||||
message.setError(chunk.data);
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
2
static/lib/msgpackr.min.js
vendored
Normal file
2
static/lib/msgpackr.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
116
stream.go
116
stream.go
@@ -3,17 +3,31 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/revrost/go-openrouter"
|
||||
"github.com/vmihailenco/msgpack/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
ChunkStart ChunkType = 0
|
||||
ChunkID ChunkType = 1
|
||||
ChunkReasoning ChunkType = 2
|
||||
ChunkText ChunkType = 3
|
||||
ChunkImage ChunkType = 4
|
||||
ChunkTool ChunkType = 5
|
||||
ChunkError ChunkType = 6
|
||||
ChunkEnd ChunkType = 7
|
||||
)
|
||||
|
||||
type ChunkType uint8
|
||||
|
||||
type Chunk struct {
|
||||
Type string `json:"type"`
|
||||
Text any `json:"text,omitempty"`
|
||||
Type ChunkType
|
||||
Data any
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
@@ -46,63 +60,10 @@ func NewStream(w http.ResponseWriter, ctx context.Context) (*Stream, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Stream) Send(ch Chunk) error {
|
||||
debugIf(ch.Type == "error", "error: %v", ch.Text)
|
||||
|
||||
return WriteChunk(s.wr, s.ctx, ch)
|
||||
}
|
||||
|
||||
func StartChunk() Chunk {
|
||||
return Chunk{
|
||||
Type: "start",
|
||||
}
|
||||
}
|
||||
|
||||
func IDChunk(id string) Chunk {
|
||||
return Chunk{
|
||||
Type: "id",
|
||||
Text: id,
|
||||
}
|
||||
}
|
||||
|
||||
func ReasoningChunk(text string) Chunk {
|
||||
return Chunk{
|
||||
Type: "reason",
|
||||
Text: text,
|
||||
}
|
||||
}
|
||||
|
||||
func TextChunk(text string) Chunk {
|
||||
return Chunk{
|
||||
Type: "text",
|
||||
Text: CleanChunk(text),
|
||||
}
|
||||
}
|
||||
|
||||
func ImageChunk(image string) Chunk {
|
||||
return Chunk{
|
||||
Type: "image",
|
||||
Text: image,
|
||||
}
|
||||
}
|
||||
|
||||
func ToolChunk(tool *ToolCall) Chunk {
|
||||
return Chunk{
|
||||
Type: "tool",
|
||||
Text: tool,
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorChunk(err error) Chunk {
|
||||
return Chunk{
|
||||
Type: "error",
|
||||
Text: GetErrorMessage(err),
|
||||
}
|
||||
}
|
||||
|
||||
func EndChunk() Chunk {
|
||||
return Chunk{
|
||||
Type: "end",
|
||||
func NewChunk(typ ChunkType, data any) *Chunk {
|
||||
return &Chunk{
|
||||
Type: typ,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,32 +75,43 @@ func GetErrorMessage(err error) string {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func WriteChunk(w http.ResponseWriter, ctx context.Context, chunk any) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
func (s *Stream) WriteChunk(chunk *Chunk) error {
|
||||
debugIf(chunk.Type == ChunkError, "error: %v", chunk.Data)
|
||||
|
||||
if err := s.ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := GetFreeBuffer()
|
||||
defer pool.Put(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(chunk); err != nil {
|
||||
binary.Write(buf, binary.LittleEndian, chunk.Type)
|
||||
|
||||
if chunk.Data != nil {
|
||||
data, err := msgpack.Marshal(chunk.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
binary.Write(buf, binary.LittleEndian, uint32(len(data)))
|
||||
|
||||
buf.Write(data)
|
||||
} else {
|
||||
binary.Write(buf, binary.LittleEndian, uint32(0))
|
||||
}
|
||||
|
||||
if _, err := s.wr.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf.Write([]byte("\n\n"))
|
||||
|
||||
if _, err := w.Write(buf.Bytes()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
flusher, ok := s.wr.(http.Flusher)
|
||||
if !ok {
|
||||
return errors.New("failed to create flusher")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-s.ctx.Done():
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
flusher.Flush()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user