diff --git a/chat.go b/chat.go index 2f07948..e2d05e8 100644 --- a/chat.go +++ b/chat.go @@ -95,6 +95,33 @@ func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage { } } +func (r *Request) AddToolPrompt(request *openrouter.ChatCompletionRequest, iteration int64) bool { + if len(request.Tools) == 0 { + return false + } + + if iteration == r.Iterations-1 { + debug("no more tool calls") + + request.Tools = nil + request.ToolChoice = "" + } + + // iterations - 1 + total := r.Iterations - (iteration + 1) + + var tools bytes.Buffer + + InternalToolsTmpl.Execute(&tools, map[string]any{ + "total": total, + "remaining": total - 1, + }) + + request.Messages = append(request.Messages, openrouter.SystemMessage(tools.String())) + + return true +} + func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { var request openrouter.ChatCompletionRequest @@ -239,20 +266,27 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) { return &request, nil } -func HandleChat(w http.ResponseWriter, r *http.Request) { - debug("parsing chat") - +func ParseChatRequest(r *http.Request) (*Request, *openrouter.ChatCompletionRequest, error) { var raw Request if err := json.NewDecoder(r.Body).Decode(&raw); err != nil { - RespondJson(w, http.StatusBadRequest, map[string]any{ - "error": err.Error(), - }) - - return + return nil, nil, err } request, err := raw.Parse() + if err != nil { + return nil, nil, err + } + + request.Stream = true + + return &raw, request, nil +} + +func HandleDump(w http.ResponseWriter, r *http.Request) { + debug("parsing dump") + + raw, request, err := ParseChatRequest(r) if err != nil { RespondJson(w, http.StatusBadRequest, map[string]any{ "error": err.Error(), @@ -261,7 +295,24 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { return } - request.Stream = true + raw.AddToolPrompt(request, 0) + + RespondJson(w, http.StatusOK, map[string]any{ + "request": request, + }) +} + +func HandleChat(w http.ResponseWriter, r *http.Request) { + debug("parsing chat") + + raw, request, err := ParseChatRequest(r) + if err != nil { + RespondJson(w, http.StatusBadRequest, map[string]any{ + "error": err.Error(), + }) + + return + } debug("preparing stream") @@ -283,30 +334,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { response.WriteChunk(NewChunk(ChunkStart, nil)) - var hasToolMessage bool - - if len(request.Tools) > 0 { - if iteration == raw.Iterations-1 { - debug("no more tool calls") - - request.Tools = nil - request.ToolChoice = "" - } - - // iterations - 1 - total := raw.Iterations - (iteration + 1) - - var tools bytes.Buffer - - InternalToolsTmpl.Execute(&tools, map[string]any{ - "total": total, - "remaining": total - 1, - }) - - request.Messages = append(request.Messages, openrouter.SystemMessage(tools.String())) - - hasToolMessage = true - } + hasToolMessage := raw.AddToolPrompt(request, iteration) dump("chat.json", request) diff --git a/main.go b/main.go index 86c148c..5e8bd18 100644 --- a/main.go +++ b/main.go @@ -60,6 +60,7 @@ func main() { gr.Get("/-/stats/{id}", HandleStats) gr.Post("/-/title", HandleTitle) gr.Post("/-/chat", HandleChat) + gr.Post("/-/dump", HandleDump) gr.Post("/-/tokenize", HandleTokenize(tokenizer)) })