From f7fc793d40d555247236c3e78c42463641642af9 Mon Sep 17 00:00:00 2001
From: Laura
Date: Sat, 25 Oct 2025 15:03:08 +0200
Subject: [PATCH] tool message last

---
 chat.go | 42 ++++++++++++++++++++----------------------
 1 file changed, 20 insertions(+), 22 deletions(-)

diff --git a/chat.go b/chat.go
index 5fa3c38..2f07948 100644
--- a/chat.go
+++ b/chat.go
@@ -95,15 +95,12 @@ func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage {
 	}
 }
 
-func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
-	var (
-		request   openrouter.ChatCompletionRequest
-		toolIndex int
-	)
+func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
+	var request openrouter.ChatCompletionRequest
 
 	model := GetModel(r.Model)
 	if model == nil {
-		return nil, 0, fmt.Errorf("unknown model: %q", r.Model)
+		return nil, fmt.Errorf("unknown model: %q", r.Model)
 	}
 
 	request.Model = r.Model
@@ -119,11 +116,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
 	request.Transforms = append(request.Transforms, env.Settings.Transformation)
 
 	if r.Iterations < 1 || r.Iterations > 50 {
-		return nil, 0, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
+		return nil, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
 	}
 
 	if r.Temperature < 0 || r.Temperature > 2 {
-		return nil, 0, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
+		return nil, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
 	}
 
 	request.Temperature = float32(r.Temperature)
@@ -136,7 +133,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
 		request.Reasoning.Effort = &r.Reasoning.Effort
 	default:
 		if r.Reasoning.Tokens <= 0 || r.Reasoning.Tokens > 1024*1024 {
-			return nil, 0, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
+			return nil, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
 		}
 
 		request.Reasoning.MaxTokens = &r.Reasoning.Tokens
@@ -151,7 +148,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
 
 	prompt, err := BuildPrompt(r.Prompt, r.Metadata, model)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	if prompt != "" {
@@ -161,13 +158,6 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
 	if model.Tools && r.Tools.Search && env.Tokens.Exa != "" && r.Iterations > 1 {
 		request.Tools = GetSearchTools()
 		request.ToolChoice = "auto"
-
-		toolIndex = len(request.Messages)
-
-		request.Messages = append(
-			request.Messages,
-			openrouter.SystemMessage(""),
-		)
 	}
 
 	for _, message := range r.Messages {
@@ -202,9 +192,9 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
 
 		for i, file := range message.Files {
 			if len(file.Name) > 512 {
-				return nil, 0, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
+				return nil, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
 			} else if len(file.Content) > 4*1024*1024 {
-				return nil, 0, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
+				return nil, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
 			}
 
 			lines := strings.Count(file.Content, "\n") + 1
@@ -246,7 +236,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
 		}
 	}
 
-	return &request, toolIndex, nil
+	return &request, nil
 }
 
 func HandleChat(w http.ResponseWriter, r *http.Request) {
@@ -262,7 +252,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	request, toolIndex, err := raw.Parse()
+	request, err := raw.Parse()
 	if err != nil {
 		RespondJson(w, http.StatusBadRequest, map[string]any{
 			"error": err.Error(),
@@ -293,6 +283,8 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
 
 		response.WriteChunk(NewChunk(ChunkStart, nil))
 
+		var hasToolMessage bool
+
 		if len(request.Tools) > 0 {
 			if iteration == raw.Iterations-1 {
 				debug("no more tool calls")
@@ -311,7 +303,9 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
 				"remaining": total - 1,
 			})
 
-			request.Messages[toolIndex].Content.Text = tools.String()
+			request.Messages = append(request.Messages, openrouter.SystemMessage(tools.String()))
+
+			hasToolMessage = true
 		}
 
 		dump("chat.json", request)
@@ -366,6 +360,10 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
 
 			response.WriteChunk(NewChunk(ChunkTool, tool))
 
+			if hasToolMessage {
+				request.Messages = request.Messages[:len(request.Messages)-1]
+			}
+
 			request.Messages = append(request.Messages,
 				tool.AsAssistantToolCall(message),
 				tool.AsToolMessage(),
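Note (not part of the patch): the change replaces the reserved placeholder slot at toolIndex with a transient system message that is appended before each request and popped again before the assistant tool-call and tool-result messages are added, so the tool message always ends up last. A minimal, self-contained Go sketch of that append/pop pattern follows; the message type, roles, and helper names here are simplified stand-ins for openrouter.ChatCompletionMessage, not the real API.

package main

import "fmt"

// message is a simplified stand-in for openrouter.ChatCompletionMessage.
type message struct {
	Role    string
	Content string
}

// withToolHint appends a transient system message carrying the tool
// instructions for the next request.
func withToolHint(msgs []message, hint string) []message {
	return append(msgs, message{Role: "system", Content: hint})
}

// popToolHint drops the trailing system hint so the tool-call and
// tool-result messages can be appended last instead.
func popToolHint(msgs []message) []message {
	return msgs[:len(msgs)-1]
}

func main() {
	msgs := []message{{Role: "user", Content: "look this up for me"}}

	// One iteration: add the hint, send the request (omitted here),
	// then replace the hint with the actual tool exchange.
	msgs = withToolHint(msgs, "you may call the search tool; 2 calls remaining")
	// ... the model responds with a tool call ...
	msgs = popToolHint(msgs)
	msgs = append(msgs,
		message{Role: "assistant", Content: "<tool call>"},
		message{Role: "tool", Content: "<tool result>"},
	)

	for _, m := range msgs {
		fmt.Printf("%s: %s\n", m.Role, m.Content)
	}
}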