mirror of
https://github.com/coalaura/whiskr.git
synced 2025-12-02 20:22:52 +00:00
better tool prompt, fixes
This commit is contained in:
66
chat.go
66
chat.go
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -13,13 +14,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ToolCall struct {
|
type ToolCall struct {
|
||||||
ID string `json:"id"`
|
ID string `msgpack:"id"`
|
||||||
Name string `json:"name"`
|
Name string `msgpack:"name"`
|
||||||
Args string `json:"args"`
|
Args string `msgpack:"args"`
|
||||||
Result string `json:"result,omitempty"`
|
Result string `msgpack:"result,omitempty"`
|
||||||
Done bool `json:"done,omitempty"`
|
Done bool `msgpack:"done,omitempty"`
|
||||||
Invalid bool `json:"invalid,omitempty"`
|
Invalid bool `msgpack:"invalid,omitempty"`
|
||||||
Cost float64 `json:"cost,omitempty"`
|
Cost float64 `msgpack:"cost,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextFile struct {
|
type TextFile struct {
|
||||||
@@ -94,12 +95,15 @@ func (t *ToolCall) AsToolMessage() openrouter.ChatCompletionMessage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
func (r *Request) Parse() (*openrouter.ChatCompletionRequest, int, error) {
|
||||||
var request openrouter.ChatCompletionRequest
|
var (
|
||||||
|
request openrouter.ChatCompletionRequest
|
||||||
|
toolIndex int
|
||||||
|
)
|
||||||
|
|
||||||
model, ok := ModelMap[r.Model]
|
model, ok := ModelMap[r.Model]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unknown model: %q", r.Model)
|
return nil, 0, fmt.Errorf("unknown model: %q", r.Model)
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Model = r.Model
|
request.Model = r.Model
|
||||||
@@ -113,11 +117,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.Iterations < 1 || r.Iterations > 50 {
|
if r.Iterations < 1 || r.Iterations > 50 {
|
||||||
return nil, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
|
return nil, 0, fmt.Errorf("invalid iterations (1-50): %d", r.Iterations)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Temperature < 0 || r.Temperature > 2 {
|
if r.Temperature < 0 || r.Temperature > 2 {
|
||||||
return nil, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
|
return nil, 0, fmt.Errorf("invalid temperature (0-2): %f", r.Temperature)
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Temperature = float32(r.Temperature)
|
request.Temperature = float32(r.Temperature)
|
||||||
@@ -130,7 +134,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
|||||||
request.Reasoning.Effort = &r.Reasoning.Effort
|
request.Reasoning.Effort = &r.Reasoning.Effort
|
||||||
default:
|
default:
|
||||||
if r.Reasoning.Tokens <= 0 || r.Reasoning.Tokens > 1024*1024 {
|
if r.Reasoning.Tokens <= 0 || r.Reasoning.Tokens > 1024*1024 {
|
||||||
return nil, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
|
return nil, 0, fmt.Errorf("invalid reasoning tokens (1-1048576): %d", r.Reasoning.Tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
request.Reasoning.MaxTokens = &r.Reasoning.Tokens
|
request.Reasoning.MaxTokens = &r.Reasoning.Tokens
|
||||||
@@ -145,7 +149,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
|||||||
|
|
||||||
prompt, err := BuildPrompt(r.Prompt, r.Metadata, model)
|
prompt, err := BuildPrompt(r.Prompt, r.Metadata, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if prompt != "" {
|
if prompt != "" {
|
||||||
@@ -156,9 +160,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
|||||||
request.Tools = GetSearchTools()
|
request.Tools = GetSearchTools()
|
||||||
request.ToolChoice = "auto"
|
request.ToolChoice = "auto"
|
||||||
|
|
||||||
|
toolIndex = len(request.Messages)
|
||||||
|
|
||||||
request.Messages = append(
|
request.Messages = append(
|
||||||
request.Messages,
|
request.Messages,
|
||||||
openrouter.SystemMessage(fmt.Sprintf(InternalToolsPrompt, r.Iterations-1)),
|
openrouter.SystemMessage(""),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -194,9 +200,9 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
|||||||
|
|
||||||
for i, file := range message.Files {
|
for i, file := range message.Files {
|
||||||
if len(file.Name) > 512 {
|
if len(file.Name) > 512 {
|
||||||
return nil, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
|
return nil, 0, fmt.Errorf("file %d is invalid (name too long, max 512 characters)", i)
|
||||||
} else if len(file.Content) > 4*1024*1024 {
|
} else if len(file.Content) > 4*1024*1024 {
|
||||||
return nil, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
|
return nil, 0, fmt.Errorf("file %d is invalid (too big, max 4MB)", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Count(file.Content, "\n") + 1
|
lines := strings.Count(file.Content, "\n") + 1
|
||||||
@@ -238,7 +244,7 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &request, nil
|
return &request, toolIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleChat(w http.ResponseWriter, r *http.Request) {
|
func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -254,7 +260,7 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
request, err := raw.Parse()
|
request, toolIndex, err := raw.Parse()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
RespondJson(w, http.StatusBadRequest, map[string]any{
|
RespondJson(w, http.StatusBadRequest, map[string]any{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
@@ -285,13 +291,25 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
response.WriteChunk(NewChunk(ChunkStart, nil))
|
response.WriteChunk(NewChunk(ChunkStart, nil))
|
||||||
|
|
||||||
if len(request.Tools) > 0 && iteration == raw.Iterations-1 {
|
if len(request.Tools) > 0 {
|
||||||
debug("no more tool calls")
|
if iteration == raw.Iterations-1 {
|
||||||
|
debug("no more tool calls")
|
||||||
|
|
||||||
request.Tools = nil
|
request.Tools = nil
|
||||||
request.ToolChoice = ""
|
request.ToolChoice = ""
|
||||||
|
}
|
||||||
|
|
||||||
request.Messages = append(request.Messages, openrouter.SystemMessage("You have reached the maximum number of tool calls for this conversation. Provide your final response based on the information you have gathered."))
|
// iterations - 1
|
||||||
|
total := raw.Iterations - (iteration + 1)
|
||||||
|
|
||||||
|
var tools bytes.Buffer
|
||||||
|
|
||||||
|
InternalToolsTmpl.Execute(&tools, map[string]any{
|
||||||
|
"total": total,
|
||||||
|
"remaining": total - 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
request.Messages[toolIndex].Content.Text = tools.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
dump("chat.json", request)
|
dump("chat.json", request)
|
||||||
|
|||||||
@@ -1,5 +1,13 @@
|
|||||||
# Tool Use
|
# Tool Use
|
||||||
Use at most 1 tool call per turn. You have %d turns with tool calls total. Choose your tool strategically - you may only get one chance.
|
{{if eq .total 0}}
|
||||||
|
No more tool calls available. Provide your final response now based on the information you have gathered.
|
||||||
|
{{else if eq .total 1}}
|
||||||
|
You have exactly 1 tool call available. After using it, you must provide your final response in the next turn based on the information you have gathered.
|
||||||
|
{{else}}
|
||||||
|
You have {{.total}} tool calls remaining. If you use a tool now, you will have {{.remaining}} more tool call(s) available in subsequent turns.
|
||||||
|
{{end}}
|
||||||
|
{{- if gt .total 0}}
|
||||||
|
Use at most 1 tool call per turn (not multiple tools in the same turn).{{- if gt .total 1 }} After calling a tool, you will receive its results and can then decide whether to call another tool or provide your final response.{{- end }}
|
||||||
|
|
||||||
**search_web({query, num_results?, intent?, recency?, domains?})**
|
**search_web({query, num_results?, intent?, recency?, domains?})**
|
||||||
- Search the live web via Exa AI. Craft concise, specific queries (2-8 words typically work best).
|
- Search the live web via Exa AI. Craft concise, specific queries (2-8 words typically work best).
|
||||||
@@ -28,4 +36,5 @@ Use at most 1 tool call per turn. You have %d turns with tool calls total. Choos
|
|||||||
- Get comprehensive repo overview via GitHub API: description, README content, file structure
|
- Get comprehensive repo overview via GitHub API: description, README content, file structure
|
||||||
- Returns top-level files/directories with raw content links for direct access
|
- Returns top-level files/directories with raw content links for direct access
|
||||||
- Use when you need to understand project structure, setup instructions, or codebase overview
|
- Use when you need to understand project structure, setup instructions, or codebase overview
|
||||||
- More efficient than searching for "repo_name GitHub" when you know the exact owner/repo
|
- More efficient than searching for "repo_name GitHub" when you know the exact owner/repo
|
||||||
|
{{end}}
|
||||||
@@ -32,6 +32,8 @@ var (
|
|||||||
//go:embed internal/tools.txt
|
//go:embed internal/tools.txt
|
||||||
InternalToolsPrompt string
|
InternalToolsPrompt string
|
||||||
|
|
||||||
|
InternalToolsTmpl *template.Template
|
||||||
|
|
||||||
//go:embed internal/title.txt
|
//go:embed internal/title.txt
|
||||||
InternalTitlePrompt string
|
InternalTitlePrompt string
|
||||||
|
|
||||||
@@ -42,6 +44,7 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
InternalToolsTmpl = NewTemplate("internal-tools", InternalToolsPrompt)
|
||||||
InternalTitleTmpl = NewTemplate("internal-title", InternalTitlePrompt)
|
InternalTitleTmpl = NewTemplate("internal-title", InternalTitlePrompt)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|||||||
@@ -1191,6 +1191,8 @@
|
|||||||
chunk => {
|
chunk => {
|
||||||
stopLoadingTimeout();
|
stopLoadingTimeout();
|
||||||
|
|
||||||
|
console.log("chunk", chunk);
|
||||||
|
|
||||||
if (chunk === "aborted") {
|
if (chunk === "aborted") {
|
||||||
chatController = null;
|
chatController = null;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user