1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00

improve tool argument parsing

This commit is contained in:
Laura
2025-11-30 21:47:59 +01:00
parent 33acc3f402
commit c15f3cf7d8
4 changed files with 54 additions and 51 deletions

35
chat.go
View File

@@ -435,25 +435,50 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
debug("got %q tool call", tool.Name)
response.WriteChunk(NewChunk(ChunkTool, tool))
switch tool.Name {
case "search_web":
err = HandleSearchWebTool(ctx, tool)
arguments, err := ParseAndUpdateArgs[SearchWebArguments](tool)
if err != nil {
response.WriteChunk(NewChunk(ChunkError, err))
return
}
response.WriteChunk(NewChunk(ChunkTool, tool))
err = HandleSearchWebTool(ctx, tool, arguments)
if err != nil {
response.WriteChunk(NewChunk(ChunkError, err))
return
}
case "fetch_contents":
err = HandleFetchContentsTool(ctx, tool)
arguments, err := ParseAndUpdateArgs[FetchContentsArguments](tool)
if err != nil {
response.WriteChunk(NewChunk(ChunkError, err))
return
}
response.WriteChunk(NewChunk(ChunkTool, tool))
err = HandleFetchContentsTool(ctx, tool, arguments)
if err != nil {
response.WriteChunk(NewChunk(ChunkError, err))
return
}
case "github_repository":
err = HandleGitHubRepositoryTool(ctx, tool)
arguments, err := ParseAndUpdateArgs[GitHubRepositoryArguments](tool)
if err != nil {
response.WriteChunk(NewChunk(ChunkError, err))
return
}
response.WriteChunk(NewChunk(ChunkTool, tool))
err = HandleGitHubRepositoryTool(ctx, tool, arguments)
if err != nil {
response.WriteChunk(NewChunk(ChunkError, err))

6
exa.go
View File

@@ -79,7 +79,7 @@ func RunExaRequest(req *http.Request) (*ExaResults, error) {
return &result, nil
}
func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, error) {
func ExaRunSearch(ctx context.Context, args *SearchWebArguments) (*ExaResults, error) {
if args.NumResults <= 0 {
args.NumResults = 6
} else if args.NumResults < 3 {
@@ -170,7 +170,7 @@ func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, er
return RunExaRequest(req)
}
func ExaRunContents(ctx context.Context, args FetchContentsArguments) (*ExaResults, error) {
func ExaRunContents(ctx context.Context, args *FetchContentsArguments) (*ExaResults, error) {
data := map[string]any{
"urls": args.URLs,
"summary": map[string]any{},
@@ -196,7 +196,7 @@ func daysAgo(days int) string {
return time.Now().Add(-time.Duration(days) * 24 * time.Hour).Format(time.DateOnly)
}
func ExaGuidanceForIntent(args SearchWebArguments) string {
func ExaGuidanceForIntent(args *SearchWebArguments) string {
var recency string
switch args.Recency {

View File

@@ -140,7 +140,7 @@ func GitHubRepositoryContentsJson(ctx context.Context, owner, repo, branch strin
return response, nil
}
func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (string, error) {
func RepoOverview(ctx context.Context, arguments *GitHubRepositoryArguments) (string, error) {
repository, err := GitHubRepositoryJson(ctx, arguments.Owner, arguments.Repo)
if err != nil {
return "", err
@@ -155,11 +155,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
)
// fetch readme
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
readme, err := GitHubRepositoryReadmeJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
if err != nil {
log.Warnf("failed to get repository readme: %v\n", err)
@@ -175,14 +171,10 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
}
readmeMarkdown = markdown
}()
})
// fetch contents
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
contents, err := GitHubRepositoryContentsJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
if err != nil {
log.Warnf("failed to get repository contents: %v\n", err)
@@ -215,7 +207,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
sort.Strings(directories)
sort.Strings(files)
}()
})
// wait and combine results
wg.Wait()

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"regexp"
"github.com/revrost/go-openrouter"
)
@@ -120,14 +121,7 @@ func GetSearchTools() []openrouter.Tool {
}
}
func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error {
var arguments SearchWebArguments
err := ParseAndUpdateArgs(tool, &arguments)
if err != nil {
return err
}
func HandleSearchWebTool(ctx context.Context, tool *ToolCall, arguments *SearchWebArguments) error {
if arguments.Query == "" {
return errors.New("no search query")
}
@@ -152,14 +146,7 @@ func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error {
return nil
}
func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error {
var arguments FetchContentsArguments
err := ParseAndUpdateArgs(tool, &arguments)
if err != nil {
return err
}
func HandleFetchContentsTool(ctx context.Context, tool *ToolCall, arguments *FetchContentsArguments) error {
if len(arguments.URLs) == 0 {
return errors.New("no urls")
}
@@ -184,14 +171,7 @@ func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error {
return nil
}
func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error {
var arguments GitHubRepositoryArguments
err := ParseAndUpdateArgs(tool, &arguments)
if err != nil {
return err
}
func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall, arguments *GitHubRepositoryArguments) error {
result, err := RepoOverview(ctx, arguments)
if err != nil {
tool.Result = fmt.Sprintf("error: %v", err)
@@ -204,10 +184,16 @@ func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error {
return nil
}
func ParseAndUpdateArgs(tool *ToolCall, arguments any) error {
err := json.Unmarshal([]byte(tool.Args), arguments)
func ParseAndUpdateArgs[T any](tool *ToolCall) (*T, error) {
var arguments T
// Some models are a bit confused by numbers so we unwrap "6" -> 6
rgx := regexp.MustCompile(`"(\d+)"`)
tool.Args = rgx.ReplaceAllString(tool.Args, "$1")
err := json.Unmarshal([]byte(tool.Args), &arguments)
if err != nil {
return fmt.Errorf("json.unmarshal: %v", err)
return nil, fmt.Errorf("json.unmarshal: %v", err)
}
buf := GetFreeBuffer()
@@ -216,12 +202,12 @@ func ParseAndUpdateArgs(tool *ToolCall, arguments any) error {
enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false)
err = enc.Encode(arguments)
err = enc.Encode(&arguments)
if err != nil {
return fmt.Errorf("json.marshal: %v", err)
return nil, fmt.Errorf("json.marshal: %v", err)
}
tool.Args = buf.String()
return nil
return &arguments, nil
}