mirror of
https://github.com/coalaura/whiskr.git
synced 2025-12-02 20:22:52 +00:00
improve tool argument parsing
This commit is contained in:
35
chat.go
35
chat.go
@@ -435,25 +435,50 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
debug("got %q tool call", tool.Name)
|
||||
|
||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||
|
||||
switch tool.Name {
|
||||
case "search_web":
|
||||
err = HandleSearchWebTool(ctx, tool)
|
||||
arguments, err := ParseAndUpdateArgs[SearchWebArguments](tool)
|
||||
if err != nil {
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||
|
||||
err = HandleSearchWebTool(ctx, tool, arguments)
|
||||
if err != nil {
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
case "fetch_contents":
|
||||
err = HandleFetchContentsTool(ctx, tool)
|
||||
arguments, err := ParseAndUpdateArgs[FetchContentsArguments](tool)
|
||||
if err != nil {
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||
|
||||
err = HandleFetchContentsTool(ctx, tool, arguments)
|
||||
if err != nil {
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
case "github_repository":
|
||||
err = HandleGitHubRepositoryTool(ctx, tool)
|
||||
arguments, err := ParseAndUpdateArgs[GitHubRepositoryArguments](tool)
|
||||
if err != nil {
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||
|
||||
err = HandleGitHubRepositoryTool(ctx, tool, arguments)
|
||||
if err != nil {
|
||||
response.WriteChunk(NewChunk(ChunkError, err))
|
||||
|
||||
|
||||
6
exa.go
6
exa.go
@@ -79,7 +79,7 @@ func RunExaRequest(req *http.Request) (*ExaResults, error) {
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, error) {
|
||||
func ExaRunSearch(ctx context.Context, args *SearchWebArguments) (*ExaResults, error) {
|
||||
if args.NumResults <= 0 {
|
||||
args.NumResults = 6
|
||||
} else if args.NumResults < 3 {
|
||||
@@ -170,7 +170,7 @@ func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, er
|
||||
return RunExaRequest(req)
|
||||
}
|
||||
|
||||
func ExaRunContents(ctx context.Context, args FetchContentsArguments) (*ExaResults, error) {
|
||||
func ExaRunContents(ctx context.Context, args *FetchContentsArguments) (*ExaResults, error) {
|
||||
data := map[string]any{
|
||||
"urls": args.URLs,
|
||||
"summary": map[string]any{},
|
||||
@@ -196,7 +196,7 @@ func daysAgo(days int) string {
|
||||
return time.Now().Add(-time.Duration(days) * 24 * time.Hour).Format(time.DateOnly)
|
||||
}
|
||||
|
||||
func ExaGuidanceForIntent(args SearchWebArguments) string {
|
||||
func ExaGuidanceForIntent(args *SearchWebArguments) string {
|
||||
var recency string
|
||||
|
||||
switch args.Recency {
|
||||
|
||||
18
github.go
18
github.go
@@ -140,7 +140,7 @@ func GitHubRepositoryContentsJson(ctx context.Context, owner, repo, branch strin
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (string, error) {
|
||||
func RepoOverview(ctx context.Context, arguments *GitHubRepositoryArguments) (string, error) {
|
||||
repository, err := GitHubRepositoryJson(ctx, arguments.Owner, arguments.Repo)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -155,11 +155,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
|
||||
)
|
||||
|
||||
// fetch readme
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
readme, err := GitHubRepositoryReadmeJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get repository readme: %v\n", err)
|
||||
@@ -175,14 +171,10 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
|
||||
}
|
||||
|
||||
readmeMarkdown = markdown
|
||||
}()
|
||||
})
|
||||
|
||||
// fetch contents
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Go(func() {
|
||||
contents, err := GitHubRepositoryContentsJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get repository contents: %v\n", err)
|
||||
@@ -215,7 +207,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
|
||||
|
||||
sort.Strings(directories)
|
||||
sort.Strings(files)
|
||||
}()
|
||||
})
|
||||
|
||||
// wait and combine results
|
||||
wg.Wait()
|
||||
|
||||
46
search.go
46
search.go
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/revrost/go-openrouter"
|
||||
)
|
||||
@@ -120,14 +121,7 @@ func GetSearchTools() []openrouter.Tool {
|
||||
}
|
||||
}
|
||||
|
||||
func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error {
|
||||
var arguments SearchWebArguments
|
||||
|
||||
err := ParseAndUpdateArgs(tool, &arguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func HandleSearchWebTool(ctx context.Context, tool *ToolCall, arguments *SearchWebArguments) error {
|
||||
if arguments.Query == "" {
|
||||
return errors.New("no search query")
|
||||
}
|
||||
@@ -152,14 +146,7 @@ func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error {
|
||||
var arguments FetchContentsArguments
|
||||
|
||||
err := ParseAndUpdateArgs(tool, &arguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func HandleFetchContentsTool(ctx context.Context, tool *ToolCall, arguments *FetchContentsArguments) error {
|
||||
if len(arguments.URLs) == 0 {
|
||||
return errors.New("no urls")
|
||||
}
|
||||
@@ -184,14 +171,7 @@ func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error {
|
||||
var arguments GitHubRepositoryArguments
|
||||
|
||||
err := ParseAndUpdateArgs(tool, &arguments)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall, arguments *GitHubRepositoryArguments) error {
|
||||
result, err := RepoOverview(ctx, arguments)
|
||||
if err != nil {
|
||||
tool.Result = fmt.Sprintf("error: %v", err)
|
||||
@@ -204,10 +184,16 @@ func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseAndUpdateArgs(tool *ToolCall, arguments any) error {
|
||||
err := json.Unmarshal([]byte(tool.Args), arguments)
|
||||
func ParseAndUpdateArgs[T any](tool *ToolCall) (*T, error) {
|
||||
var arguments T
|
||||
|
||||
// Some models are a bit confused by numbers so we unwrap "6" -> 6
|
||||
rgx := regexp.MustCompile(`"(\d+)"`)
|
||||
tool.Args = rgx.ReplaceAllString(tool.Args, "$1")
|
||||
|
||||
err := json.Unmarshal([]byte(tool.Args), &arguments)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json.unmarshal: %v", err)
|
||||
return nil, fmt.Errorf("json.unmarshal: %v", err)
|
||||
}
|
||||
|
||||
buf := GetFreeBuffer()
|
||||
@@ -216,12 +202,12 @@ func ParseAndUpdateArgs(tool *ToolCall, arguments any) error {
|
||||
enc := json.NewEncoder(buf)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
err = enc.Encode(arguments)
|
||||
err = enc.Encode(&arguments)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json.marshal: %v", err)
|
||||
return nil, fmt.Errorf("json.marshal: %v", err)
|
||||
}
|
||||
|
||||
tool.Args = buf.String()
|
||||
|
||||
return nil
|
||||
return &arguments, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user