mirror of
https://github.com/coalaura/whiskr.git
synced 2025-09-09 09:19:54 +00:00
better search tools
This commit is contained in:
51
chat.go
51
chat.go
@@ -17,6 +17,7 @@ type ToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Args string `json:"args"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Done bool `json:"done,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
@@ -98,9 +99,11 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if model.Tools && r.Search {
|
||||
request.Tools = GetSearchTool()
|
||||
if model.Tools && r.Search && ExaToken != "" {
|
||||
request.Tools = GetSearchTools()
|
||||
request.ToolChoice = "auto"
|
||||
|
||||
request.Messages = append(request.Messages, openrouter.SystemMessage("You have access to web search tools. Use `search_web` with `query` (string) and `num_results` (1-10) to find current information and get result summaries. Use `fetch_contents` with `urls` (array) to read full page content. Always specify all parameters for each tool call. Call only one tool per response."))
|
||||
}
|
||||
|
||||
prompt, err := BuildPrompt(r.Prompt, model)
|
||||
@@ -148,6 +151,8 @@ func (r *Request) Parse() (*openrouter.ChatCompletionRequest, error) {
|
||||
}
|
||||
|
||||
func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
debug("new chat")
|
||||
|
||||
var raw Request
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&raw); err != nil {
|
||||
@@ -169,9 +174,6 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
request.Stream = true
|
||||
|
||||
// DEBUG
|
||||
dump(request)
|
||||
|
||||
response, err := NewStream(w)
|
||||
if err != nil {
|
||||
RespondJson(w, http.StatusBadRequest, map[string]any{
|
||||
@@ -181,12 +183,20 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
debug("handling request")
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
for iteration := range MaxIterations {
|
||||
debug("iteration %d of %d", iteration+1, MaxIterations)
|
||||
|
||||
if iteration == MaxIterations-1 {
|
||||
debug("no more tool calls")
|
||||
|
||||
request.Tools = nil
|
||||
request.ToolChoice = ""
|
||||
|
||||
request.Messages = append(request.Messages, openrouter.SystemMessage("You have reached the maximum number of tool calls for this conversation. Provide your final response based on the information you have gathered."))
|
||||
}
|
||||
|
||||
tool, message, err := RunCompletion(ctx, response, request)
|
||||
@@ -196,19 +206,39 @@ func HandleChat(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if tool == nil || tool.Name != "search_internet" {
|
||||
if tool == nil {
|
||||
debug("no tool call, done")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
debug("got %q tool call", tool.Name)
|
||||
|
||||
response.Send(ToolChunk(tool))
|
||||
|
||||
err = HandleSearchTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
switch tool.Name {
|
||||
case "search_web":
|
||||
err = HandleSearchWebTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
|
||||
return
|
||||
}
|
||||
case "fetch_contents":
|
||||
err = HandleFetchContentsTool(ctx, tool)
|
||||
if err != nil {
|
||||
response.Send(ErrorChunk(err))
|
||||
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
tool.Done = true
|
||||
|
||||
debug("finished tool call")
|
||||
|
||||
response.Send(ToolChunk(tool))
|
||||
|
||||
request.Messages = append(request.Messages,
|
||||
@@ -260,9 +290,6 @@ func RunCompletion(ctx context.Context, response *Stream, request *openrouter.Ch
|
||||
|
||||
choice := chunk.Choices[0]
|
||||
|
||||
// DEBUG
|
||||
debug(choice)
|
||||
|
||||
if choice.FinishReason == openrouter.FinishReasonContentFilter {
|
||||
response.Send(ErrorChunk(errors.New("stopped due to content_filter")))
|
||||
|
||||
|
Reference in New Issue
Block a user