1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00

better tool prompt, fixes

This commit is contained in:
2025-09-17 01:53:02 +02:00
parent 386cfe3362
commit 22bba895e6
4 changed files with 58 additions and 26 deletions

66
chat.go
View File

@@ -1,6 +1,7 @@
package main
import (
"bytes"
"context"
"encoding/json"
"errors"
@@ -13,13 +14,13 @@ import (
)
type ToolCall struct {
ID string `json:"id"`
Name string `json:"name"`
Args string `json:"args"`
Result string `json:"result,omitempty"`
Done bool `json:"done,omitempty"`
Invalid bool `json:"invalid,omitempty"`
Cost float64 `json:"cost,omitempty"`
ID string `msgpack:"id"`
Name string `msgpack:"name"`
Args string `msgpack:"args"`
Result string `msgpack:"result,omitempty"`
Done bool `msgpack:"done,omitempty"`
Invalid bool `msgpack:"invalid,omitempty"`
Cost float64 `msgpack:"cost,omitempty"`
}
type TextFile struct {
@@ -94,12 +95,15 @@ func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage {
}
}
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
var request openrouter.ChatCompletionRequest
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
var (
request openrouter.ChatCompletionRequest
toolIndex int
)
model, ok := ModelMap[r.Model]
if !ok {
return nil, fmt.Errorf("unknown model: %q", r.Model)
return nil, 0, fmt.Errorf("unknown model: %q", r.Model)
}
request.Model = r.Model
@@ -113,11 +117,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
}
if r.Iterations < 1 || r.Iterations > 50 {
return nil, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
return nil, 0, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
}
if r.Temperature < 0 || r.Temperature > 2 {
return nil, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
return nil, 0, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
}
request.Temperature = float32(r.Temperature)
@@ -130,7 +134,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
request.Reasoning.Effort = &r.Reasoning.Effort
default:
if r.Reasoning.Tokens <= 0 || r.Reasoning.Tokens > 1024*1024 {
return nil, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
return nil, 0, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
}
request.Reasoning.MaxTokens = &r.Reasoning.Tokens
@@ -145,7 +149,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
prompt, err := BuildPrompt(r.Prompt, r.Metadata, model)
if err != nil {
return nil, err
return nil, 0, err
}
if prompt != "" {
@@ -156,9 +160,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
request.Tools = GetSearchTools()
request.ToolChoice = "auto"
toolIndex = len(request.Messages)
request.Messages = append(
request.Messages,
openrouter.SystemMessage(fmt.Sprintf(InternalToolsPrompt, r.Iterations-1)),
openrouter.SystemMessage(""),
)
}
@@ -194,9 +200,9 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
for i, file := range message.Files {
if len(file.Name) > 512 {
return nil, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
return nil, 0, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
} else if len(file.Content) > 4*1024*1024 {
return nil, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
return nil, 0, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
}
lines := strings.Count(file.Content, "\n") + 1
@@ -238,7 +244,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
}
}
return &request, nil
return &request, toolIndex, nil
}
func HandleChat(w http.ResponseWriter, r *http.Request) {
@@ -254,7 +260,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
return
}
request, err := raw.Parse()
request, toolIndex, err := raw.Parse()
if err != nil {
RespondJson(w, http.StatusBadRequest, map[string]any{
"error": err.Error(),
@@ -285,13 +291,25 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
response.WriteChunk(NewChunk(ChunkStart, nil))
if len(request.Tools) > 0 && iteration == raw.Iterations-1 {
debug("no more tool calls")
if len(request.Tools) > 0 {
if iteration == raw.Iterations-1 {
debug("no more tool calls")
request.Tools = nil
request.ToolChoice = ""
request.Tools = nil
request.ToolChoice = ""
}
request.Messages = append(request.Messages, openrouter.SystemMessage("You have reached the maximum number of tool calls for this conversation. Provide your final response based on the information you have gathered."))
// iterations - 1
total := raw.Iterations - (iteration + 1)
var tools bytes.Buffer
InternalToolsTmpl.Execute(&tools, map[string]any{
"total": total,
"remaining": total - 1,
})
request.Messages[toolIndex].Content.Text = tools.String()
}
dump("chat.json", request)