1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-09-09 09:19:54 +00:00

better searching

This commit is contained in:
Laura
2025-08-14 03:53:14 +02:00
parent 8a790df2af
commit c740cd293d
14 changed files with 582 additions and 143 deletions

179
chat.go
View File

@@ -1,18 +1,28 @@
package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/revrost/go-openrouter"
)
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Args string `json:"args"`
Result string `json:"result,omitempty"`
}
type Message struct {
Role string `json:"role"`
Text string `json:"text"`
Role string `json:"role"`
Text string `json:"text"`
Tool *ToolCall `json:"tool"`
}
type Reasoning struct {
@@ -30,6 +40,27 @@ type Request struct {
Messages []Message `json:"messages"`
}
func (t *ToolCall) AsToolCall() openrouter.ToolCall {
return openrouter.ToolCall{
ID: t.ID,
Type: openrouter.ToolTypeFunction,
Function: openrouter.FunctionCall{
Name: t.Name,
Arguments: t.Args,
},
}
}
func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage {
return openrouter.ChatCompletionMessage{
Role: openrouter.ChatMessageRoleTool,
ToolCallID: t.ID,
Content: openrouter.Content{
Text: t.Result,
},
}
}
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
var request openrouter.ChatCompletionRequest
@@ -67,10 +98,9 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
}
}
if r.Search {
request.Plugins = append(request.Plugins, openrouter.ChatCompletionPlugin{
ID: openrouter.PluginIDWeb,
})
if model.Tools && r.Search {
request.Tools = GetSearchTool()
request.ToolChoice = "auto"
}
prompt, err := BuildPrompt(r.Prompt, model)
@@ -83,16 +113,35 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
}
for index, message := range r.Messages {
if message.Role != openrouter.ChatMessageRoleSystem && message.Role != openrouter.ChatMessageRoleAssistant && message.Role != openrouter.ChatMessageRoleUser {
switch message.Role {
case "system", "user":
request.Messages = append(request.Messages, openrouter.ChatCompletionMessage{
Role: message.Role,
Content: openrouter.Content{
Text: message.Text,
},
})
case "assistant":
msg := openrouter.ChatCompletionMessage{
Role: openrouter.ChatMessageRoleAssistant,
Content: openrouter.Content{
Text: message.Text,
},
}
tool := message.Tool
if tool != nil {
msg.ToolCalls = []openrouter.ToolCall{tool.AsToolCall()}
request.Messages = append(request.Messages, msg)
msg = tool.AsToolMessage()
}
request.Messages = append(request.Messages, msg)
default:
return nil, fmt.Errorf("[%d] invalid role: %q", index+1, message.Role)
}
request.Messages = append(request.Messages, openrouter.ChatCompletionMessage{
Role: message.Role,
Content: openrouter.Content{
Text: message.Text,
},
})
}
return &request, nil
@@ -119,26 +168,10 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
}
request.Stream = true
request.Usage = &openrouter.IncludeUsage{
Include: true,
}
// DEBUG
dump(request)
ctx := r.Context()
stream, err := OpenRouterStartStream(ctx, *request)
if err != nil {
RespondJson(w, http.StatusBadRequest, map[string]any{
"error": GetErrorMessage(err),
})
return
}
defer stream.Close()
response, err := NewStream(w)
if err != nil {
RespondJson(w, http.StatusBadRequest, map[string]any{
@@ -148,18 +181,74 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
return
}
var id string
ctx := r.Context()
for iteration := range MaxIterations {
if iteration == MaxIterations-1 {
request.Tools = nil
request.ToolChoice = ""
}
tool, message, err := RunCompletion(ctx, response, request)
if err != nil {
response.Send(ErrorChunk(err))
return
}
if tool == nil || tool.Name != "search_internet" {
return
}
response.Send(ToolChunk(tool))
err = HandleSearchTool(ctx, tool)
if err != nil {
response.Send(ErrorChunk(err))
return
}
response.Send(ToolChunk(tool))
request.Messages = append(request.Messages,
openrouter.ChatCompletionMessage{
Role: openrouter.ChatMessageRoleAssistant,
Content: openrouter.Content{
Text: message,
},
ToolCalls: []openrouter.ToolCall{tool.AsToolCall()},
},
tool.AsToolMessage(),
)
}
}
func RunCompletion(ctx context.Context, response *Stream, request *openrouter.ChatCompletionRequest) (*ToolCall, string, error) {
stream, err := OpenRouterStartStream(ctx, *request)
if err != nil {
return nil, "", err
}
defer stream.Close()
var (
id string
result strings.Builder
tool *ToolCall
)
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return
break
}
response.Send(ErrorChunk(err))
log.Warning("stream error")
log.WarningE(err)
return
return nil, "", err
}
if id == "" {
@@ -180,15 +269,35 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
if choice.FinishReason == openrouter.FinishReasonContentFilter {
response.Send(ErrorChunk(errors.New("stopped due to content_filter")))
return
return nil, "", nil
}
calls := choice.Delta.ToolCalls
if len(calls) > 0 {
call := calls[0]
if tool == nil {
tool = &ToolCall{}
}
tool.ID += call.ID
tool.Name += call.Function.Name
tool.Args += call.Function.Arguments
} else if tool != nil {
break
}
content := choice.Delta.Content
if content != "" {
result.WriteString(content)
response.Send(TextChunk(content))
} else if choice.Delta.Reasoning != nil {
response.Send(ReasoningChunk(*choice.Delta.Reasoning))
}
}
return tool, result.String(), nil
}