1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00
Files
whiskr/chat.go

601 lines
13 KiB
Go
Raw Normal View History

2025-08-05 03:56:23 +02:00
package main
import (
2025-09-17 01:53:02 +02:00
"bytes"
2025-08-14 03:53:14 +02:00
"context"
2025-08-05 03:56:23 +02:00
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
2025-08-14 03:53:14 +02:00
"strings"
2025-11-03 03:32:19 +01:00
"time"
2025-08-05 03:56:23 +02:00
"github.com/revrost/go-openrouter"
)
2025-11-18 18:48:40 +01:00
type (
	// ToolReasoning is an encrypted reasoning detail that arrived in the same
	// stream delta as a tool call. AsAssistantToolCall sends it back as an
	// encrypted reasoning detail on the replayed assistant message.
	ToolReasoning struct {
		Format    string `msgpack:"format"`
		Encrypted string `msgpack:"encrypted"`
	}

	// ToolCall is a single tool invocation requested by the model.
	//
	// ID, Name and Args (the raw argument string) are accumulated from stream
	// deltas. Result holds the tool's output and Done is set once the tool has
	// been handled. Invalid marks calls to unknown tools. Cost is not written in
	// this file (presumably set by tool handlers; verify). Reasoning carries any
	// encrypted reasoning that accompanied the call.
	ToolCall struct {
		ID        string         `msgpack:"id"`
		Name      string         `msgpack:"name"`
		Args      string         `msgpack:"args"`
		Result    string         `msgpack:"result,omitempty"`
		Done      bool           `msgpack:"done,omitempty"`
		Invalid   bool           `msgpack:"invalid,omitempty"`
		Cost      float64        `msgpack:"cost,omitempty"`
		Reasoning *ToolReasoning `msgpack:"reasoning,omitempty"`
	}

	// TextFile is a text file attached to a user message. Parse limits Name
	// to 512 bytes and Content to 4 MiB.
	TextFile struct {
		Name    string `json:"name"`
		Content string `json:"content"`
	}

	// Message is one conversation entry sent by the client.
	//
	// Parse handles the roles "system", "user" and "assistant" and skips any
	// other role. Files are only used for user messages. Tool and Images are
	// only used for assistant messages.
	Message struct {
		Role   string     `json:"role"`
		Text   string     `json:"text"`
		Tool   *ToolCall  `json:"tool"`
		Files  []TextFile `json:"files"`
		Images []string   `json:"images"`
	}

	// Reasoning selects the reasoning budget. Effort "high", "medium" or
	// "low" takes precedence; otherwise Tokens must be within 1-1048576.
	Reasoning struct {
		Effort string `json:"effort"`
		Tokens int    `json:"tokens"`
	}

	// Tools toggles optional features: JSON response format and web search.
	Tools struct {
		JSON   bool `json:"json"`
		Search bool `json:"search"`
	}

	// Metadata is client environment information passed to BuildPrompt.
	Metadata struct {
		Timezone string `json:"timezone"`
		Platform string `json:"platform"`
	}

	// Request is the JSON body of a chat request as sent by the client.
	//
	// Parse enforces Temperature in 0-2 and Iterations in 1-50. Provider may
	// be "throughput", "latency" or "price"; any other value leaves provider
	// sorting unset.
	Request struct {
		Prompt      string    `json:"prompt"`
		Model       string    `json:"model"`
		Provider    string    `json:"provider"`
		Temperature float64   `json:"temperature"`
		Iterations  int64     `json:"iterations"`
		Tools       Tools     `json:"tools"`
		Reasoning   Reasoning `json:"reasoning"`
		Metadata    Metadata  `json:"metadata"`
		Messages    []Message `json:"messages"`
	}
)
2025-08-16 13:53:55 +02:00
func (t *ToolCall) AsAssistantToolCall(content string) openrouter.ChatCompletionMessage {
// Some models require there to be content
if content == "" {
content = " "
}
2025-11-18 18:48:40 +01:00
call := openrouter.ChatCompletionMessage{
2025-08-16 13:53:55 +02:00
Role: openrouter.ChatMessageRoleAssistant,
Content: openrouter.Content{
Text: content,
},
ToolCalls: []openrouter.ToolCall{
{
ID: t.ID,
Type: openrouter.ToolTypeFunction,
Function: openrouter.FunctionCall{
Name: t.Name,
Arguments: t.Args,
},
},
2025-08-14 03:53:14 +02:00
},
}
2025-11-18 18:48:40 +01:00
if t.Reasoning != nil {
call.ReasoningDetails = []openrouter.ChatCompletionReasoningDetails{
{
Type: openrouter.ReasoningDetailsTypeEncrypted,
Data: t.Reasoning.Encrypted,
ID: t.ID,
Format: t.Reasoning.Format,
Index: 0,
},
}
}
return call
2025-08-14 03:53:14 +02:00
}
// AsToolMessage renders the tool's Result as a tool-role message that answers
// the call identified by t.ID.
func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage {
	var msg openrouter.ChatCompletionMessage

	msg.Role = openrouter.ChatMessageRoleTool
	msg.ToolCallID = t.ID
	msg.Content.Text = t.Result

	return msg
}
2025-10-25 15:09:35 +02:00
// AddToolPrompt appends the internal tool-usage system prompt for the given
// zero-based iteration to request and reports whether a message was appended.
//
// It does nothing and returns false when the request offers no tools. On the
// final iteration (iteration == r.Iterations-1) the tools and tool choice are
// removed from the request; the prompt is still rendered and appended.
//
// The template receives "total" (iterations left after this one) and
// "remaining" (total - 1). If rendering fails, nothing is appended and false is
// returned. Callers use the return value to decide whether to strip the
// message again, so a partial prompt must not be left behind.
func (r *Request) AddToolPrompt(request *openrouter.ChatCompletionRequest, iteration int64) bool {
	if len(request.Tools) == 0 {
		return false
	}

	if iteration == r.Iterations-1 {
		debug("no more tool calls")

		request.Tools = nil
		request.ToolChoice = ""
	}

	// iterations - 1
	total := r.Iterations - (iteration + 1)

	var tools bytes.Buffer

	err := InternalToolsTmpl.Execute(&tools, map[string]any{
		"total":     total,
		"remaining": total - 1,
	})
	if err != nil {
		debug("failed to render tool prompt: %v", err)

		return false
	}

	request.Messages = append(request.Messages, openrouter.SystemMessage(tools.String()))

	return true
}
2025-10-25 15:03:08 +02:00
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
var request openrouter.ChatCompletionRequest
2025-08-05 03:56:23 +02:00
model := GetModel(r.Model)
if model == nil {
2025-10-25 15:03:08 +02:00
return nil, fmt.Errorf("unknown model: %q", r.Model)
2025-08-05 03:56:23 +02:00
}
request.Model = r.Model
2025-09-11 23:25:58 +02:00
request.Modalities = []openrouter.ChatCompletionModality{
openrouter.ModalityText,
}
if env.Settings.ImageGeneration && model.Images {
request.Modalities = append(request.Modalities, openrouter.ModalityImage)
}
2025-10-03 01:30:40 +02:00
request.Transforms = append(request.Transforms, env.Settings.Transformation)
2025-08-23 15:19:43 +02:00
if r.Iterations < 1 || r.Iterations > 50 {
2025-10-25 15:03:08 +02:00
return nil, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
2025-08-23 15:19:43 +02:00
}
2025-08-10 22:39:18 +02:00
if r.Temperature < 0 || r.Temperature > 2 {
2025-10-25 15:03:08 +02:00
return nil, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
2025-08-05 03:56:23 +02:00
}
request.Temperature = float32(r.Temperature)
2025-08-10 22:32:40 +02:00
if model.Reasoning {
request.Reasoning = &openrouter.ChatCompletionReasoning{}
switch r.Reasoning.Effort {
case "high", "medium", "low":
request.Reasoning.Effort = &r.Reasoning.Effort
default:
if r.Reasoning.Tokens <= 0 || r.Reasoning.Tokens > 1024*1024 {
2025-10-25 15:03:08 +02:00
return nil, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
2025-08-10 22:32:40 +02:00
}
request.Reasoning.MaxTokens = &r.Reasoning.Tokens
}
}
2025-11-09 22:12:32 +01:00
switch r.Provider {
case "throughput":
request.Provider = &openrouter.ChatProvider{
Sort: openrouter.ProviderSortingThroughput,
}
case "latency":
request.Provider = &openrouter.ChatProvider{
Sort: openrouter.ProviderSortingLatency,
}
case "price":
request.Provider = &openrouter.ChatProvider{
Sort: openrouter.ProviderSortingPrice,
}
}
2025-08-28 16:37:48 +02:00
if model.JSON && r.Tools.JSON {
2025-08-11 00:15:58 +02:00
request.ResponseFormat = &openrouter.ChatCompletionResponseFormat{
Type: openrouter.ChatCompletionResponseFormatTypeJSONObject,
}
}
2025-08-28 16:37:48 +02:00
prompt, err := BuildPrompt(r.Prompt, r.Metadata, model)
if err != nil {
2025-10-25 15:03:08 +02:00
return nil, err
}
if prompt != "" {
2025-11-04 00:23:29 +01:00
prompt += "\n\n" + InternalGeneralPrompt
2025-11-06 03:59:21 +01:00
request.Messages = append(request.Messages, openrouter.SystemMessage(prompt))
}
2025-11-04 00:23:29 +01:00
2025-08-30 15:06:49 +02:00
if model.Tools && r.Tools.Search && env.Tokens.Exa != "" && r.Iterations > 1 {
2025-08-16 13:53:55 +02:00
request.Tools = GetSearchTools()
request.ToolChoice = "auto"
}
2025-08-25 22:45:03 +02:00
for _, message := range r.Messages {
message.Text = strings.ReplaceAll(message.Text, "\r", "")
2025-08-14 03:53:14 +02:00
switch message.Role {
2025-08-15 03:38:24 +02:00
case "system":
2025-08-14 03:53:14 +02:00
request.Messages = append(request.Messages, openrouter.ChatCompletionMessage{
Role: message.Role,
Content: openrouter.Content{
Text: message.Text,
},
})
2025-08-15 03:38:24 +02:00
case "user":
2025-11-11 16:24:53 +01:00
var (
content openrouter.Content
multi bool
last = -1
)
2025-08-15 03:38:24 +02:00
if model.Vision && strings.Contains(message.Text, "![") {
content.Multi = SplitImagePairs(message.Text)
2025-11-11 16:24:53 +01:00
multi = true
if content.Multi[len(content.Multi)-1].Type == openrouter.ChatMessagePartTypeText {
last = len(content.Multi) - 1
}
2025-08-15 03:38:24 +02:00
} else {
content.Text = message.Text
}
2025-08-18 03:47:37 +02:00
if len(message.Files) > 0 {
for i, file := range message.Files {
if len(file.Name) > 512 {
2025-10-25 15:03:08 +02:00
return nil, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
2025-08-18 03:47:37 +02:00
} else if len(file.Content) > 4*1024*1024 {
2025-10-25 15:03:08 +02:00
return nil, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
2025-08-18 03:47:37 +02:00
}
lines := strings.Count(file.Content, "\n") + 1
2025-11-11 16:24:53 +01:00
entry := fmt.Sprintf(
2025-11-07 18:45:49 +01:00
"FILE %q LINES %d\n<<CONTENT>>\n%s\n<<END>>",
file.Name,
lines,
file.Content,
)
2025-11-11 16:24:53 +01:00
if multi {
if last != -1 {
if content.Multi[last].Text != "" {
content.Multi[last].Text += "\n\n"
}
content.Multi[last].Text += entry
} else {
content.Multi = append(content.Multi, openrouter.ChatMessagePart{
Type: openrouter.ChatMessagePartTypeText,
Text: entry,
})
}
} else {
if content.Text != "" {
content.Text += "\n\n"
}
content.Text += entry
}
2025-08-18 03:47:37 +02:00
}
}
2025-08-15 03:38:24 +02:00
request.Messages = append(request.Messages, openrouter.ChatCompletionMessage{
Role: message.Role,
Content: content,
})
2025-08-14 03:53:14 +02:00
case "assistant":
msg := openrouter.ChatCompletionMessage{
Role: openrouter.ChatMessageRoleAssistant,
Content: openrouter.Content{
Text: message.Text,
},
}
2025-11-30 21:31:26 +01:00
for index, image := range message.Images {
msg.Images = append(msg.Images, openrouter.ChatCompletionImage{
Index: index,
Type: openrouter.StreamImageTypeImageURL,
ImageURL: openrouter.ChatCompletionImageURL{
URL: image,
},
})
}
2025-08-14 03:53:14 +02:00
tool := message.Tool
if tool != nil {
2025-08-16 13:53:55 +02:00
msg = tool.AsAssistantToolCall(message.Text)
2025-08-14 03:53:14 +02:00
request.Messages = append(request.Messages, msg)
msg = tool.AsToolMessage()
}
request.Messages = append(request.Messages, msg)
2025-08-05 03:56:23 +02:00
}
}
2025-10-25 15:03:08 +02:00
return &request, nil
2025-08-05 03:56:23 +02:00
}
2025-10-25 15:09:35 +02:00
func ParseChatRequest(r *http.Request) (*Request, *openrouter.ChatCompletionRequest, error) {
2025-08-05 03:56:23 +02:00
var raw Request
if err := json.NewDecoder(r.Body).Decode(&raw); err != nil {
2025-10-25 15:09:35 +02:00
return nil, nil, err
}
request, err := raw.Parse()
if err != nil {
return nil, nil, err
}
request.Stream = true
return &raw, request, nil
}
func HandleDump(w http.ResponseWriter, r *http.Request) {
debug("parsing dump")
raw, request, err := ParseChatRequest(r)
if err != nil {
2025-08-05 03:56:23 +02:00
RespondJson(w, http.StatusBadRequest, map[string]any{
"error": err.Error(),
})
return
}
2025-10-25 15:09:35 +02:00
raw.AddToolPrompt(request, 0)
RespondJson(w, http.StatusOK, map[string]any{
"request": request,
})
}
func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("parsing chat")
raw, request, err := ParseChatRequest(r)
2025-08-05 03:56:23 +02:00
if err != nil {
RespondJson(w, http.StatusBadRequest, map[string]any{
"error": err.Error(),
})
return
}
2025-08-15 03:38:24 +02:00
debug("preparing stream")
2025-08-29 19:26:55 +02:00
ctx := r.Context()
response, err := NewStream(w, ctx)
2025-08-05 03:56:23 +02:00
if err != nil {
RespondJson(w, http.StatusBadRequest, map[string]any{
2025-08-14 03:53:14 +02:00
"error": err.Error(),
2025-08-05 03:56:23 +02:00
})
return
}
2025-08-14 17:08:45 +02:00
debug("handling request")
2025-11-03 03:32:19 +01:00
go func() {
ticker := time.NewTicker(5 * time.Second)
2025-11-03 03:32:19 +01:00
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
response.WriteChunk(NewChunk(ChunkAlive, nil))
}
}
}()
2025-08-23 15:19:43 +02:00
for iteration := range raw.Iterations {
debug("iteration %d of %d", iteration+1, raw.Iterations)
2025-08-14 17:08:45 +02:00
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkStart, nil))
2025-08-31 23:41:28 +02:00
2025-10-25 15:09:35 +02:00
hasToolMessage := raw.AddToolPrompt(request, iteration)
2025-08-05 03:56:23 +02:00
2025-08-25 22:45:03 +02:00
dump("chat.json", request)
2025-08-16 13:53:55 +02:00
2025-08-14 03:53:14 +02:00
tool, message, err := RunCompletion(ctx, response, request)
if err != nil {
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkError, err))
2025-08-14 03:53:14 +02:00
return
}
2025-08-14 17:08:45 +02:00
if tool == nil {
debug("no tool call, done")
2025-08-14 03:53:14 +02:00
return
}
2025-08-14 17:08:45 +02:00
debug("got %q tool call", tool.Name)
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkTool, tool))
2025-08-14 03:53:14 +02:00
2025-08-14 17:08:45 +02:00
switch tool.Name {
case "search_web":
err = HandleSearchWebTool(ctx, tool)
if err != nil {
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkError, err))
2025-08-14 03:53:14 +02:00
2025-08-14 17:08:45 +02:00
return
}
case "fetch_contents":
err = HandleFetchContentsTool(ctx, tool)
if err != nil {
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkError, err))
2025-08-14 17:08:45 +02:00
2025-08-25 18:37:30 +02:00
return
}
case "github_repository":
err = HandleGitHubRepositoryTool(ctx, tool)
if err != nil {
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkError, err))
2025-08-25 18:37:30 +02:00
2025-08-14 17:08:45 +02:00
return
}
default:
2025-08-30 00:25:48 +02:00
tool.Invalid = true
tool.Result = "error: invalid tool call"
2025-08-14 03:53:14 +02:00
}
2025-08-14 17:08:45 +02:00
tool.Done = true
debug("finished tool call")
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkTool, tool))
2025-08-14 03:53:14 +02:00
2025-10-25 15:03:08 +02:00
if hasToolMessage {
request.Messages = request.Messages[:len(request.Messages)-1]
}
2025-08-14 03:53:14 +02:00
request.Messages = append(request.Messages,
2025-08-16 13:53:55 +02:00
tool.AsAssistantToolCall(message),
2025-08-14 03:53:14 +02:00
tool.AsToolMessage(),
)
2025-08-31 23:41:28 +02:00
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkEnd, nil))
2025-08-05 03:56:23 +02:00
}
2025-08-14 03:53:14 +02:00
}
2025-08-05 03:56:23 +02:00
2025-08-14 03:53:14 +02:00
func RunCompletion(ctx context.Context, response *Stream, request *openrouter.ChatCompletionRequest) (*ToolCall, string, error) {
stream, err := OpenRouterStartStream(ctx, *request)
if err != nil {
2025-11-01 23:19:27 +01:00
return nil, "", fmt.Errorf("stream.start: %v", err)
2025-08-14 03:53:14 +02:00
}
defer stream.Close()
var (
2025-11-11 14:37:09 +01:00
id string
open int
close int
reasoning bool
tool *ToolCall
2025-08-14 03:53:14 +02:00
)
2025-08-11 15:43:00 +02:00
2025-08-31 23:46:22 +02:00
buf := GetFreeBuffer()
defer pool.Put(buf)
2025-08-05 03:56:23 +02:00
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
2025-08-14 03:53:14 +02:00
break
2025-08-05 03:56:23 +02:00
}
2025-11-01 23:19:27 +01:00
return nil, "", fmt.Errorf("stream.receive: %v", err)
2025-08-05 03:56:23 +02:00
}
2025-08-11 15:43:00 +02:00
if id == "" {
id = chunk.ID
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkID, id))
2025-08-11 15:43:00 +02:00
}
2025-08-05 03:56:23 +02:00
if len(chunk.Choices) == 0 {
continue
}
choice := chunk.Choices[0]
2025-09-12 14:34:08 +02:00
delta := choice.Delta
2025-08-05 03:56:23 +02:00
if choice.FinishReason == openrouter.FinishReasonContentFilter {
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkError, errors.New("stopped due to content_filter")))
2025-08-05 03:56:23 +02:00
2025-08-14 03:53:14 +02:00
return nil, "", nil
}
2025-09-12 14:34:08 +02:00
calls := delta.ToolCalls
2025-08-14 03:53:14 +02:00
if len(calls) > 0 {
call := calls[0]
2025-09-12 15:11:10 +02:00
if open > 0 && open == close {
continue
}
2025-08-14 03:53:14 +02:00
if tool == nil {
tool = &ToolCall{}
}
2025-09-12 15:11:10 +02:00
if call.ID != "" && !strings.HasSuffix(tool.ID, call.ID) {
tool.ID += call.ID
}
if call.Function.Name != "" && !strings.HasSuffix(tool.Name, call.Function.Name) {
tool.Name += call.Function.Name
}
2025-11-18 18:48:40 +01:00
if len(delta.ReasoningDetails) != 0 && tool.Reasoning == nil {
for _, details := range delta.ReasoningDetails {
if details.Type != openrouter.ReasoningDetailsTypeEncrypted {
continue
}
tool.Reasoning = &ToolReasoning{
Format: details.Format,
Encrypted: details.Data,
}
}
}
2025-09-12 15:11:10 +02:00
open += strings.Count(call.Function.Arguments, "{")
close += strings.Count(call.Function.Arguments, "}")
2025-08-14 03:53:14 +02:00
tool.Args += call.Function.Arguments
} else if tool != nil {
break
2025-08-05 03:56:23 +02:00
}
2025-09-12 14:34:08 +02:00
if delta.Content != "" {
buf.WriteString(delta.Content)
2025-08-05 03:56:23 +02:00
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkText, delta.Content))
2025-09-12 14:34:08 +02:00
} else if delta.Reasoning != nil {
2025-11-11 14:37:09 +01:00
if !reasoning && len(delta.ReasoningDetails) != 0 {
reasoning = true
response.WriteChunk(NewChunk(ChunkReasoningType, delta.ReasoningDetails[0].Type))
}
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkReasoning, *delta.Reasoning))
2025-09-12 14:34:08 +02:00
} else if len(delta.Images) > 0 {
for _, image := range delta.Images {
if image.Type != openrouter.StreamImageTypeImageURL {
continue
}
2025-08-14 03:53:14 +02:00
2025-09-12 21:39:17 +02:00
response.WriteChunk(NewChunk(ChunkImage, image.ImageURL.URL))
2025-09-12 14:34:08 +02:00
}
2025-08-05 03:56:23 +02:00
}
}
2025-08-14 03:53:14 +02:00
2025-08-31 23:46:22 +02:00
return tool, buf.String(), nil
2025-08-05 03:56:23 +02:00
}