mirror of
https://github.com/coalaura/whiskr.git
synced 2025-12-02 20:22:52 +00:00
improve tool argument parsing
This commit is contained in:
35
chat.go
35
chat.go
@@ -435,25 +435,50 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
debug("got %q tool call", tool.Name)
|
debug("got %q tool call", tool.Name)
|
||||||
|
|
||||||
response.WriteChunk(NewChunk(ChunkTool, tool))
|
|
||||||
|
|
||||||
switch tool.Name {
|
switch tool.Name {
|
||||||
case "search_web":
|
case "search_web":
|
||||||
err = HandleSearchWebTool(ctx, tool)
|
arguments, err := ParseAndUpdateArgs[SearchWebArguments](tool)
|
||||||
|
if err != nil {
|
||||||
|
response.WriteChunk(NewChunk(ChunkError, err))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||||
|
|
||||||
|
err = HandleSearchWebTool(ctx, tool, arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.WriteChunk(NewChunk(ChunkError, err))
|
response.WriteChunk(NewChunk(ChunkError, err))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "fetch_contents":
|
case "fetch_contents":
|
||||||
err = HandleFetchContentsTool(ctx, tool)
|
arguments, err := ParseAndUpdateArgs[FetchContentsArguments](tool)
|
||||||
|
if err != nil {
|
||||||
|
response.WriteChunk(NewChunk(ChunkError, err))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||||
|
|
||||||
|
err = HandleFetchContentsTool(ctx, tool, arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.WriteChunk(NewChunk(ChunkError, err))
|
response.WriteChunk(NewChunk(ChunkError, err))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case "github_repository":
|
case "github_repository":
|
||||||
err = HandleGitHubRepositoryTool(ctx, tool)
|
arguments, err := ParseAndUpdateArgs[GitHubRepositoryArguments](tool)
|
||||||
|
if err != nil {
|
||||||
|
response.WriteChunk(NewChunk(ChunkError, err))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response.WriteChunk(NewChunk(ChunkTool, tool))
|
||||||
|
|
||||||
|
err = HandleGitHubRepositoryTool(ctx, tool, arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
response.WriteChunk(NewChunk(ChunkError, err))
|
response.WriteChunk(NewChunk(ChunkError, err))
|
||||||
|
|
||||||
|
|||||||
6
exa.go
6
exa.go
@@ -79,7 +79,7 @@ func RunExaRequest(req *http.Request) (*ExaResults, error) {
|
|||||||
return &result, nil
|
return &result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, error) {
|
func ExaRunSearch(ctx context.Context, args *SearchWebArguments) (*ExaResults, error) {
|
||||||
if args.NumResults <= 0 {
|
if args.NumResults <= 0 {
|
||||||
args.NumResults = 6
|
args.NumResults = 6
|
||||||
} else if args.NumResults < 3 {
|
} else if args.NumResults < 3 {
|
||||||
@@ -170,7 +170,7 @@ func ExaRunSearch(ctx context.Context, args SearchWebArguments) (*ExaResults, er
|
|||||||
return RunExaRequest(req)
|
return RunExaRequest(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExaRunContents(ctx context.Context, args FetchContentsArguments) (*ExaResults, error) {
|
func ExaRunContents(ctx context.Context, args *FetchContentsArguments) (*ExaResults, error) {
|
||||||
data := map[string]any{
|
data := map[string]any{
|
||||||
"urls": args.URLs,
|
"urls": args.URLs,
|
||||||
"summary": map[string]any{},
|
"summary": map[string]any{},
|
||||||
@@ -196,7 +196,7 @@ func daysAgo(days int) string {
|
|||||||
return time.Now().Add(-time.Duration(days) * 24 * time.Hour).Format(time.DateOnly)
|
return time.Now().Add(-time.Duration(days) * 24 * time.Hour).Format(time.DateOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExaGuidanceForIntent(args SearchWebArguments) string {
|
func ExaGuidanceForIntent(args *SearchWebArguments) string {
|
||||||
var recency string
|
var recency string
|
||||||
|
|
||||||
switch args.Recency {
|
switch args.Recency {
|
||||||
|
|||||||
18
github.go
18
github.go
@@ -140,7 +140,7 @@ func GitHubRepositoryContentsJson(ctx context.Context, owner, repo, branch strin
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (string, error) {
|
func RepoOverview(ctx context.Context, arguments *GitHubRepositoryArguments) (string, error) {
|
||||||
repository, err := GitHubRepositoryJson(ctx, arguments.Owner, arguments.Repo)
|
repository, err := GitHubRepositoryJson(ctx, arguments.Owner, arguments.Repo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -155,11 +155,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
|
|||||||
)
|
)
|
||||||
|
|
||||||
// fetch readme
|
// fetch readme
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
readme, err := GitHubRepositoryReadmeJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
|
readme, err := GitHubRepositoryReadmeJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get repository readme: %v\n", err)
|
log.Warnf("failed to get repository readme: %v\n", err)
|
||||||
@@ -175,14 +171,10 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
|
|||||||
}
|
}
|
||||||
|
|
||||||
readmeMarkdown = markdown
|
readmeMarkdown = markdown
|
||||||
}()
|
})
|
||||||
|
|
||||||
// fetch contents
|
// fetch contents
|
||||||
wg.Add(1)
|
wg.Go(func() {
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
contents, err := GitHubRepositoryContentsJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
|
contents, err := GitHubRepositoryContentsJson(ctx, arguments.Owner, arguments.Repo, repository.DefaultBranch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get repository contents: %v\n", err)
|
log.Warnf("failed to get repository contents: %v\n", err)
|
||||||
@@ -215,7 +207,7 @@ func RepoOverview(ctx context.Context, arguments GitHubRepositoryArguments) (str
|
|||||||
|
|
||||||
sort.Strings(directories)
|
sort.Strings(directories)
|
||||||
sort.Strings(files)
|
sort.Strings(files)
|
||||||
}()
|
})
|
||||||
|
|
||||||
// wait and combine results
|
// wait and combine results
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|||||||
46
search.go
46
search.go
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
"github.com/revrost/go-openrouter"
|
"github.com/revrost/go-openrouter"
|
||||||
)
|
)
|
||||||
@@ -120,14 +121,7 @@ func GetSearchTools() []openrouter.Tool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error {
|
func HandleSearchWebTool(ctx context.Context, tool *ToolCall, arguments *SearchWebArguments) error {
|
||||||
var arguments SearchWebArguments
|
|
||||||
|
|
||||||
err := ParseAndUpdateArgs(tool, &arguments)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if arguments.Query == "" {
|
if arguments.Query == "" {
|
||||||
return errors.New("no search query")
|
return errors.New("no search query")
|
||||||
}
|
}
|
||||||
@@ -152,14 +146,7 @@ func HandleSearchWebTool(ctx context.Context, tool *ToolCall) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error {
|
func HandleFetchContentsTool(ctx context.Context, tool *ToolCall, arguments *FetchContentsArguments) error {
|
||||||
var arguments FetchContentsArguments
|
|
||||||
|
|
||||||
err := ParseAndUpdateArgs(tool, &arguments)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(arguments.URLs) == 0 {
|
if len(arguments.URLs) == 0 {
|
||||||
return errors.New("no urls")
|
return errors.New("no urls")
|
||||||
}
|
}
|
||||||
@@ -184,14 +171,7 @@ func HandleFetchContentsTool(ctx context.Context, tool *ToolCall) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error {
|
func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall, arguments *GitHubRepositoryArguments) error {
|
||||||
var arguments GitHubRepositoryArguments
|
|
||||||
|
|
||||||
err := ParseAndUpdateArgs(tool, &arguments)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := RepoOverview(ctx, arguments)
|
result, err := RepoOverview(ctx, arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tool.Result = fmt.Sprintf("error: %v", err)
|
tool.Result = fmt.Sprintf("error: %v", err)
|
||||||
@@ -204,10 +184,16 @@ func HandleGitHubRepositoryTool(ctx context.Context, tool *ToolCall) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseAndUpdateArgs(tool *ToolCall, arguments any) error {
|
func ParseAndUpdateArgs[T any](tool *ToolCall) (*T, error) {
|
||||||
err := json.Unmarshal([]byte(tool.Args), arguments)
|
var arguments T
|
||||||
|
|
||||||
|
// Some models are a bit confused by numbers so we unwrap "6" -> 6
|
||||||
|
rgx := regexp.MustCompile(`"(\d+)"`)
|
||||||
|
tool.Args = rgx.ReplaceAllString(tool.Args, "$1")
|
||||||
|
|
||||||
|
err := json.Unmarshal([]byte(tool.Args), &arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("json.unmarshal: %v", err)
|
return nil, fmt.Errorf("json.unmarshal: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
buf := GetFreeBuffer()
|
buf := GetFreeBuffer()
|
||||||
@@ -216,12 +202,12 @@ func ParseAndUpdateArgs(tool *ToolCall, arguments any) error {
|
|||||||
enc := json.NewEncoder(buf)
|
enc := json.NewEncoder(buf)
|
||||||
enc.SetEscapeHTML(false)
|
enc.SetEscapeHTML(false)
|
||||||
|
|
||||||
err = enc.Encode(arguments)
|
err = enc.Encode(&arguments)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("json.marshal: %v", err)
|
return nil, fmt.Errorf("json.marshal: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tool.Args = buf.String()
|
tool.Args = buf.String()
|
||||||
|
|
||||||
return nil
|
return &arguments, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user