mirror of
https://github.com/coalaura/whiskr.git
synced 2025-09-09 09:19:54 +00:00
better searching
This commit is contained in:
179
chat.go
179
chat.go
@@ -1,18 +1,28 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/revrost/go-openrouter"
|
||||
)
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Args string `json:"args"`
|
||||
Result string `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Text string `json:"text"`
|
||||
Role string `json:"role"`
|
||||
Text string `json:"text"`
|
||||
Tool *ToolCall `json:"tool"`
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
@@ -30,6 +40,27 @@ type Request struct {
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
func (t *ToolCall) AsToolCall() openrouter.ToolCall {
|
||||
return openrouter.ToolCall{
|
||||
ID: t.ID,
|
||||
Type: openrouter.ToolTypeFunction,
|
||||
Function: openrouter.FunctionCall{
|
||||
Name: t.Name,
|
||||
Arguments: t.Args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage {
|
||||
return openrouter.ChatCompletionMessage{
|
||||
Role: openrouter.ChatMessageRoleTool,
|
||||
ToolCallID: t.ID,
|
||||
Content: openrouter.Content{
|
||||
Text: t.Result,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
||||
var request openrouter.ChatCompletionRequest
|
||||
|
||||
@@ -67,10 +98,9 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if r.Search {
|
||||
request.Plugins = append(request.Plugins, openrouter.ChatCompletionPlugin{
|
||||
ID: openrouter.PluginIDWeb,
|
||||
})
|
||||
if model.Tools && r.Search {
|
||||
request.Tools = GetSearchTool()
|
||||
request.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
prompt, err := BuildPrompt(r.Prompt, model)
|
||||
@@ -83,16 +113,35 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
||||
}
|
||||
|
||||
for index, message := range r.Messages {
|
||||
if message.Role != openrouter.ChatMessageRoleSystem && message.Role != openrouter.ChatMessageRoleAssistant && message.Role != openrouter.ChatMessageRoleUser {
|
||||
switch message.Role {
|
||||
case "system", "user":
|
||||
request.Messages = append(request.Messages, openrouter.ChatCompletionMessage{
|
||||
Role: message.Role,
|
||||
Content: openrouter.Content{
|
||||
Text: message.Text,
|
||||
},
|
||||
})
|
||||
case "assistant":
|
||||
msg := openrouter.ChatCompletionMessage{
|
||||
Role: openrouter.ChatMessageRoleAssistant,
|
||||
Content: openrouter.Content{
|
||||
Text: message.Text,
|
||||
},
|
||||
}
|
||||
|
||||
tool := message.Tool
|
||||
if tool != nil {
|
||||
msg.ToolCalls = []openrouter.ToolCall{tool.AsToolCall()}
|
||||
|
||||
request.Messages = append(request.Messages, msg)
|
||||
|
||||
msg = tool.AsToolMessage()
|
||||
}
|
||||
|
||||
request.Messages = append(request.Messages, msg)
|
||||
default:
|
||||
return nil, fmt.Errorf("[%d] invalid role: %q", index+1, message.Role)
|
||||
}
|
||||
|
||||
request.Messages = append(request.Messages, openrouter.ChatCompletionMessage{
|
||||
Role: message.Role,
|
||||
Content: openrouter.Content{
|
||||
Text: message.Text,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return &request, nil
|
||||
@@ -119,26 +168,10 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
request.Stream = true
|
||||
request.Usage = &openrouter.IncludeUsage{
|
||||
Include: true,
|
||||
}
|
||||
|
||||
// DEBUG
|
||||
dump(request)
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
stream, err := OpenRouterStartStream(ctx, *request)
|
||||
if err != nil {
|
||||
RespondJson(w, http.StatusBadRequest, map[string]any{
|
||||
"error": GetErrorMessage(err),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
response, err := NewStream(w)
|
||||
if err != nil {
|
||||
RespondJson(w, http.StatusBadRequest, map[string]any{
|
||||
@@ -148,18 +181,74 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var id string
|
||||
ctx := r.Context()
|
||||
|
||||
for iteration := range MaxIterations {
|
||||
if iteration == MaxIterations-1 {
|
||||
request.Tools = nil
|
||||
request.ToolChoice = ""
|
||||
}
|
||||
|
||||
tool, message, err := RunCompletion(ctx, response, request)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if tool == nil || tool.Name != "search_internet" {
|
||||
return
|
||||
}
|
||||
|
||||
response.Send(ToolChunk(tool))
|
||||
|
||||
err = HandleSearchTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
response.Send(ToolChunk(tool))
|
||||
|
||||
request.Messages = append(request.Messages,
|
||||
openrouter.ChatCompletionMessage{
|
||||
Role: openrouter.ChatMessageRoleAssistant,
|
||||
Content: openrouter.Content{
|
||||
Text: message,
|
||||
},
|
||||
ToolCalls: []openrouter.ToolCall{tool.AsToolCall()},
|
||||
},
|
||||
tool.AsToolMessage(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func RunCompletion(ctx context.Context, response *Stream, request *openrouter.ChatCompletionRequest) (*ToolCall, string, error) {
|
||||
stream, err := OpenRouterStartStream(ctx, *request)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
var (
|
||||
id string
|
||||
result strings.Builder
|
||||
tool *ToolCall
|
||||
)
|
||||
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
break
|
||||
}
|
||||
|
||||
response.Send(ErrorChunk(err))
|
||||
log.Warning("stream error")
|
||||
log.WarningE(err)
|
||||
|
||||
return
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
@@ -180,15 +269,35 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
if choice.FinishReason == openrouter.FinishReasonContentFilter {
|
||||
response.Send(ErrorChunk(errors.New("stopped due to content_filter")))
|
||||
|
||||
return
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
calls := choice.Delta.ToolCalls
|
||||
|
||||
if len(calls) > 0 {
|
||||
call := calls[0]
|
||||
|
||||
if tool == nil {
|
||||
tool = &ToolCall{}
|
||||
}
|
||||
|
||||
tool.ID += call.ID
|
||||
tool.Name += call.Function.Name
|
||||
tool.Args += call.Function.Arguments
|
||||
} else if tool != nil {
|
||||
break
|
||||
}
|
||||
|
||||
content := choice.Delta.Content
|
||||
|
||||
if content != "" {
|
||||
result.WriteString(content)
|
||||
|
||||
response.Send(TextChunk(content))
|
||||
} else if choice.Delta.Reasoning != nil {
|
||||
response.Send(ReasoningChunk(*choice.Delta.Reasoning))
|
||||
}
|
||||
}
|
||||
|
||||
return tool, result.String(), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user