diff --git a/chat.go b/chat.go index bb682a8..026255d 100644 --- a/chat.go +++ b/chat.go @@ -435,25 +435,50 @@ func HandleChat(w http.ResponseWriter, r *http.Request) { debug("got %q tool call", tool.Name) - response.WriteChunk(NewChunk(ChunkTool, tool)) - switch tool.Name { case "search_web": - err = HandleSearchWebTool(ctx, tool) + arguments, err := ParseAndUpdateArgs[SearchWebArguments](tool) + if err != nil { + response.WriteChunk(NewChunk(ChunkError, err)) + + return + } + + response.WriteChunk(NewChunk(ChunkTool, tool)) + + err = HandleSearchWebTool(ctx, tool, arguments) if err != nil { response.WriteChunk(NewChunk(ChunkError, err)) return } case "fetch_contents": - err = HandleFetchContentsTool(ctx, tool) + arguments, err := ParseAndUpdateArgs[FetchContentsArguments](tool) + if err != nil { + response.WriteChunk(NewChunk(ChunkError, err)) + + return + } + + response.WriteChunk(NewChunk(ChunkTool, tool)) + + err = HandleFetchContentsTool(ctx, tool, arguments) if err != nil { response.WriteChunk(NewChunk(ChunkError, err)) return } case "github_repository": - err = HandleGitHubRepositoryTool(ctx, tool) + arguments, err := ParseAndUpdateArgs[GitHubRepositoryArguments](tool) + if err != nil { + response.WriteChunk(NewChunk(ChunkError, err)) + + return + } + + response.WriteChunk(NewChunk(ChunkTool, tool)) + + err = HandleGitHubRepositoryTool(ctx, tool, arguments) if err != nil { response.WriteChunk(NewChunk(ChunkError, err)) diff --git a/exa.go b/exa.go index e38613e..14e4345 100644 --- a/exa.go +++ b/exa.go @@ -79,7 +79,7 @@ func RunExaRequest(req *http.Request) (*ExaResults, error) { return &result, nil } -func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, error) { +func ExaRunSearch(ctx context.Context, args *SearchWebArguments) (*ExaResults, error) { if args.NumResults <= 0 { args.NumResults = 6 } else if args.NumResults < 3 { @@ -170,7 +170,7 @@ func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, er return RunExaRequest(req) } -func ExaRunContents(ctx context.Context, args FetchContentsArguments) (*ExaResults, error) { +func ExaRunContents(ctx context.Context, args *FetchContentsArguments) (*ExaResults, error) { data := map[string]any{ "urls": args.URLs, "summary": map[string]any{}, @@ -196,7 +196,7 @@ func daysAgo(days int) string { return time.Now().Add(-time.Duration(days) * 24 * time.Hour).Format(time.DateOnly) } -func ExaGuidanceForIntent(args SearchWebArguments) string { +func ExaGuidanceForIntent(args *SearchWebArguments) string { var recency string switch args.Recency { diff --git a/github.go b/github.go index b7473e6..7b36a5a 100644 --- a/github.go +++ b/github.go @@ -140,7 +140,7 @@ func GitHubRepositoryContentsJson(ctx context.Context, owner, repo, branch strin return response, nil } -func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (string, error) { +func RepoOverview(ctx context.Context, arguments *GitHubRepositoryArguments) (string, error) { repository, err := GitHubRepositoryJson(ctx, arguments.Owner, arguments.Repo) if err != nil { return "", err @@ -155,11 +155,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str ) // fetch readme - wg.Add(1) - - go func() { - defer wg.Done() - + wg.Go(func() { readme, err := GitHubRepositoryReadmeJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch) if err != nil { log.Warnf("failed to get repository readme: %v\n", err) @@ -175,14 +171,10 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str } readmeMarkdown = markdown - }() + }) // fetch contents - wg.Add(1) - - go func() { - defer wg.Done() - + wg.Go(func() { contents, err := GitHubRepositoryContentsJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch) if err != nil { log.Warnf("failed to get repository contents: %v\n", err) @@ -215,7 +207,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str sort.Strings(directories) sort.Strings(files) - }() + }) // wait and combine results wg.Wait() diff --git a/search.go b/search.go index 623cbaf..ef04f97 100644 --- a/search.go +++ b/search.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "github.com/revrost/go-openrouter" ) @@ -120,14 +121,7 @@ func GetSearchTools() []openrouter.Tool { } } -func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error { - var arguments SearchWebArguments - - err := ParseAndUpdateArgs(tool, &arguments) - if err != nil { - return err - } - +func HandleSearchWebTool(ctx context.Context, tool *ToolCall, arguments *SearchWebArguments) error { if arguments.Query == "" { return errors.New("no search query") } @@ -152,14 +146,7 @@ func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error { return nil } -func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error { - var arguments FetchContentsArguments - - err := ParseAndUpdateArgs(tool, &arguments) - if err != nil { - return err - } - +func HandleFetchContentsTool(ctx context.Context, tool *ToolCall, arguments *FetchContentsArguments) error { if len(arguments.URLs) == 0 { return errors.New("no urls") } @@ -184,14 +171,7 @@ func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error { return nil } -func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error { - var arguments GitHubRepositoryArguments - - err := ParseAndUpdateArgs(tool, &arguments) - if err != nil { - return err - } - +func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall, arguments *GitHubRepositoryArguments) error { result, err := RepoOverview(ctx, arguments) if err != nil { tool.Result = fmt.Sprintf("error: %v", err) @@ -204,10 +184,16 @@ func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error { return nil } -func ParseAndUpdateArgs(tool *ToolCall, arguments any) error { - err := json.Unmarshal([]byte(tool.Args), arguments) +func ParseAndUpdateArgs[T any](tool *ToolCall) (*T, error) { + var arguments T + + // Some models are a bit confused by numbers so we unwrap "6" -> 6 + rgx := regexp.MustCompile(`"(\d+)"`) + tool.Args = rgx.ReplaceAllString(tool.Args, "$1") + + err := json.Unmarshal([]byte(tool.Args), &arguments) if err != nil { - return fmt.Errorf("json.unmarshal: %v", err) + return nil, fmt.Errorf("json.unmarshal: %v", err) } buf := GetFreeBuffer() @@ -216,12 +202,12 @@ func ParseAndUpdateArgs(tool *ToolCall, arguments any) error { enc := json.NewEncoder(buf) enc.SetEscapeHTML(false) - err = enc.Encode(arguments) + err = enc.Encode(&arguments) if err != nil { - return fmt.Errorf("json.marshal: %v", err) + return nil, fmt.Errorf("json.marshal: %v", err) } tool.Args = buf.String() - return nil + return &arguments, nil }