diff --git a/chat.go b/chat.go index e9a3472..4b2a8ff 100644 --- a/chat.go +++ b/chat.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "encoding/json" "errors" @@ -13,13 +14,13 @@ import ( ) type ToolCall struct { - ID string `json:"id"` - Name string `json:"name"` - Args string `json:"args"` - Result string `json:"result,omitempty"` - Done bool `json:"done,omitempty"` - Invalid bool `json:"invalid,omitempty"` - Cost float64 `json:"cost,omitempty"` + ID string `msgpack:"id"` + Name string `msgpack:"name"` + Args string `msgpack:"args"` + Result string `msgpack:"result,omitempty"` + Done bool `msgpack:"done,omitempty"` + Invalid bool `msgpack:"invalid,omitempty"` + Cost float64 `msgpack:"cost,omitempty"` } type TextFile struct { @@ -94,12 +95,15 @@ func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage { } } -func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { - var request openrouter.ChatCompletionRequest +func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) { + var ( + request openrouter.ChatCompletionRequest + toolIndex int + ) model, ok := ModelMap[r.Model] if !ok { - return nil, fmt.Errorf("unknown model: %q", r.Model) + return nil, 0, fmt.Errorf("unknown model: %q", r.Model) } request.Model = r.Model @@ -113,11 +117,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { } if r.Iterations < 1 || r.Iterations > 50 { - return nil, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations) + return nil, 0, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations) } if r.Temperature < 0 || r.Temperature > 2 { - return nil, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature) + return nil, 0, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature) } request.Temperature = float32(r.Temperature) @@ -130,7 +134,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { request.Reasoning.Effort = &r.Reasoning.Effort default: if r.Reasoning.Tokens <= 0 || r.Reasoning.Tokens > 1024*1024 { - return nil, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens) + return nil, 0, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens) } request.Reasoning.MaxTokens = &r.Reasoning.Tokens @@ -145,7 +149,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { prompt, err := BuildPrompt(r.Prompt, r.Metadata, model) if err != nil { - return nil, err + return nil, 0, err } if prompt != "" { @@ -156,9 +160,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { request.Tools = GetSearchTools() request.ToolChoice = "auto" + toolIndex = len(request.Messages) + request.Messages = append( request.Messages, - openrouter.SystemMessage(fmt.Sprintf(InternalToolsPrompt, r.Iterations-1)), + openrouter.SystemMessage(""), ) } @@ -194,9 +200,9 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { for i, file := range message.Files { if len(file.Name) > 512 { - return nil, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i) + return nil, 0, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i) } else if len(file.Content) > 4*1024*1024 { - return nil, fmt.Errorf("file %d is invalid (too big, max 4MB)", i) + return nil, 0, fmt.Errorf("file %d is invalid (too big, max 4MB)", i) } lines := strings.Count(file.Content, "\n") + 1 @@ -238,7 +244,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { } } - return &request, nil + return &request, toolIndex, nil } func HandleChat(w http.ResponseWriter, r *http.Request) { @@ -254,7 +260,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { return } - request, err := raw.Parse() + request, toolIndex, err := raw.Parse() if err != nil { RespondJson(w, http.StatusBadRequest, map[string]any{ "error": err.Error(), @@ -285,13 +291,25 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { response.WriteChunk(NewChunk(ChunkStart, nil)) - if len(request.Tools) > 0 && iteration == raw.Iterations-1 { - debug("no more tool calls") + if len(request.Tools) > 0 { + if iteration == raw.Iterations-1 { + debug("no more tool calls") - request.Tools = nil - request.ToolChoice = "" + request.Tools = nil + request.ToolChoice = "" + } - request.Messages = append(request.Messages, openrouter.SystemMessage("You have reached the maximum number of tool calls for this conversation. Provide your final response based on the information you have gathered.")) + // iterations - 1 + total := raw.Iterations - (iteration + 1) + + var tools bytes.Buffer + + InternalToolsTmpl.Execute(&tools, map[string]any{ + "total": total, + "remaining": total - 1, + }) + + request.Messages[toolIndex].Content.Text = tools.String() } dump("chat.json", request) diff --git a/internal/tools.txt b/internal/tools.txt index b488c2d..d419434 100644 --- a/internal/tools.txt +++ b/internal/tools.txt @@ -1,5 +1,13 @@ # Tool Use -Use at most 1 tool call per turn. You have %d turns with tool calls total. Choose your tool strategically - you may only get one chance. +{{if eq .total 0}} +No more tool calls available. Provide your final response now based on the information you have gathered. +{{else if eq .total 1}} +You have exactly 1 tool call available. After using it, you must provide your final response in the next turn based on the information you have gathered. +{{else}} +You have {{.total}} tool calls remaining. If you use a tool now, you will have {{.remaining}} more tool call(s) available in subsequent turns. +{{end}} +{{- if gt .total 0}} +Use at most 1 tool call per turn (not multiple tools in the same turn).{{- if gt .total 1 }} After calling a tool, you will receive its results and can then decide whether to call another tool or provide your final response.{{- end }} **search_web({query, num_results?, intent?, recency?, domains?})** - Search the live web via Exa AI. Craft concise, specific queries (2-8 words typically work best). @@ -28,4 +36,5 @@ Use at most 1 tool call per turn. You have %d turns with tool calls total. Choos - Get comprehensive repo overview via GitHub API: description, README content, file structure - Returns top-level files/directories with raw content links for direct access - Use when you need to understand project structure, setup instructions, or codebase overview -- More efficient than searching for "repo_name GitHub" when you know the exact owner/repo \ No newline at end of file +- More efficient than searching for "repo_name GitHub" when you know the exact owner/repo +{{end}} \ No newline at end of file diff --git a/prompts.go b/prompts.go index 9a0cda1..243b431 100644 --- a/prompts.go +++ b/prompts.go @@ -32,6 +32,8 @@ var ( //go:embed internal/tools.txt InternalToolsPrompt string + InternalToolsTmpl *template.Template + //go:embed internal/title.txt InternalTitlePrompt string @@ -42,6 +44,7 @@ var ( ) func init() { + InternalToolsTmpl = NewTemplate("internal-tools", InternalToolsPrompt) InternalTitleTmpl = NewTemplate("internal-title", InternalTitlePrompt) var err error diff --git a/static/js/chat.js b/static/js/chat.js index a37953b..0f4613a 100644 --- a/static/js/chat.js +++ b/static/js/chat.js @@ -1191,6 +1191,8 @@ chunk => { stopLoadingTimeout(); + console.log("chunk", chunk); + if (chunk === "aborted") { chatController = null;