mirror of
https://github.com/coalaura/whiskr.git
synced 2025-09-09 09:19:54 +00:00
fixes and dynamic prompts
This commit is contained in:
100
prompts.go
100
prompts.go
@@ -2,8 +2,13 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
)
|
||||
@@ -14,35 +19,84 @@ type PromptData struct {
|
||||
Date string
|
||||
}
|
||||
|
||||
type Prompt struct {
|
||||
Key string `json:"key"`
|
||||
Name string `json:"name"`
|
||||
|
||||
Text string `json:"-"`
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed prompts/normal.txt
|
||||
PromptNormal string
|
||||
|
||||
//go:embed prompts/reviewer.txt
|
||||
PromptReviewer string
|
||||
|
||||
//go:embed prompts/engineer.txt
|
||||
PromptEngineer string
|
||||
|
||||
//go:embed prompts/scripts.txt
|
||||
PromptScripts string
|
||||
|
||||
//go:embed prompts/physics.txt
|
||||
PromptPhysics string
|
||||
|
||||
Templates = map[string]*template.Template{
|
||||
"normal": NewTemplate("normal", PromptNormal),
|
||||
"reviewer": NewTemplate("reviewer", PromptReviewer),
|
||||
"engineer": NewTemplate("engineer", PromptEngineer),
|
||||
"scripts": NewTemplate("scripts", PromptScripts),
|
||||
"physics": NewTemplate("physics", PromptPhysics),
|
||||
}
|
||||
Prompts []Prompt
|
||||
Templates = make(map[string]*template.Template)
|
||||
)
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
|
||||
Prompts, err = LoadPrompts()
|
||||
log.MustPanic(err)
|
||||
}
|
||||
|
||||
func NewTemplate(name, text string) *template.Template {
|
||||
return template.Must(template.New(name).Parse(text))
|
||||
}
|
||||
|
||||
func LoadPrompts() ([]Prompt, error) {
|
||||
var prompts []Prompt
|
||||
|
||||
log.Info("Loading prompts...")
|
||||
|
||||
err := filepath.Walk("prompts", func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil || info.IsDir() {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer file.Close()
|
||||
|
||||
body, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
index := bytes.Index(body, []byte("---"))
|
||||
if index == -1 {
|
||||
log.Warningf("Invalid prompt file: %q\n", path)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt := Prompt{
|
||||
Key: strings.Replace(filepath.Base(path), ".txt", "", 1),
|
||||
Name: strings.TrimSpace(string(body[:index])),
|
||||
Text: strings.TrimSpace(string(body[:index+3])),
|
||||
}
|
||||
|
||||
prompts = append(prompts, prompt)
|
||||
|
||||
Templates[prompt.Key] = NewTemplate(prompt.Key, prompt.Text)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sort.Slice(prompts, func(i, j int) bool {
|
||||
return prompts[i].Name < prompts[j].Name
|
||||
})
|
||||
|
||||
log.Infof("Loaded %d prompts\n", len(prompts))
|
||||
|
||||
return prompts, nil
|
||||
}
|
||||
|
||||
func BuildPrompt(name string, model *Model) (string, error) {
|
||||
if name == "" {
|
||||
return "", nil
|
||||
|
Reference in New Issue
Block a user