diff --git a/main.go b/main.go
index 67864f8..1436979 100644
--- a/main.go
+++ b/main.go
@@ -22,6 +22,9 @@ func main() {
 	models, err := LoadModels()
 	log.MustFail(err)
 
+	tokenizer, err := LoadTokenizer(TikTokenSource)
+	log.MustFail(err)
+
 	log.Println("Preparing router...")
 
 	r := chi.NewRouter()
@@ -51,6 +54,8 @@ func main() {
 		gr.Get("/-/stats/{id}", HandleStats)
 		gr.Post("/-/title", HandleTitle)
 		gr.Post("/-/chat", HandleChat)
+
+		gr.Post("/-/tokenize", HandleTokenize(tokenizer))
 	})
 
 	log.Println("Listening at http://localhost:3443/")
@@ -73,6 +78,8 @@ func cache(next http.Handler) http.Handler {
 }
 
 func LoadIcons() ([]string, error) {
+	log.Println("Loading icons...")
+
 	var icons []string
 
 	directory := filepath.Join("static", "css", "icons")
@@ -98,5 +105,7 @@ func LoadIcons() ([]string, error) {
 		return nil, err
 	}
 
+	log.Printf("Loaded %d icons\n", len(icons))
+
 	return icons, nil
 }
diff --git a/static/css/chat.css b/static/css/chat.css
index f727eb8..6b412d1 100644
--- a/static/css/chat.css
+++ b/static/css/chat.css
@@ -833,6 +833,8 @@ body:not(.loading) #loading {
 	padding-right: 14px;
 	border-radius: 6px;
 	border: 1px solid #363a4f;
+	overflow: hidden;
+	min-width: 140px;
 }
 
 .files .file .name {
@@ -847,14 +849,32 @@ body:not(.loading) #loading {
 	flex-shrink: 0;
 }
 
+.files .file .tokens {
+	position: absolute;
+	top: 0;
+	left: 0;
+	width: 100%;
+	height: 100%;
+	display: flex;
+	align-items: center;
+	justify-content: center;
+	backdrop-filter: blur(4px);
+	font-size: 13px;
+	pointer-events: none;
+	opacity: 0;
+	transition: 150ms;
+}
+
+.files .file:hover .tokens {
+	opacity: 1;
+}
+
 .files .file button.remove {
 	content: "";
 	position: absolute;
 	background-image: url(icons/remove.svg);
-	width: 16px;
-	height: 16px;
-	top: 1px;
-	right: 1px;
+	top: 0;
+	right: 0;
 	opacity: 0;
 	transition: 150ms;
 }
@@ -1115,6 +1135,11 @@ label[for="reasoning-tokens"] {
 	background-image: url(icons/attach.svg);
 }
 
+#upload.loading {
+	animation: rotating 1.2s linear infinite;
+	background-image: url(icons/spinner.svg);
+}
+
 #json,
 #search,
 #scrolling,
diff --git a/static/js/chat.js b/static/js/chat.js
index 0f4613a..fb3d4d3 100644
--- a/static/js/chat.js
+++ b/static/js/chat.js
@@ -71,6 +71,7 @@
 		activeMessage = null,
 		isResizing = false,
 		scrollResize = false,
+		isUploading = false,
 		totalCost = 0;
 
 	function updateTotalCost() {
@@ -1476,6 +1477,31 @@
 		});
 	}
 
+	async function resolveTokenCount(str) {
+		try {
+			const response = await fetch("/-/tokenize", {
+					method: "POST",
+					headers: {
+						"Content-Type": "application/json",
+					},
+					body: JSON.stringify({
+						string: str,
+					}),
+				}),
+				data = await response.json();
+
+			if (!response.ok) {
+				throw new Error(data?.error || response.statusText);
+			}
+
+			return data.tokens;
+		} catch (err) {
+			console.error(err);
+		}
+
+		return false;
+	}
+
 	let attachments = [];
 
 	function buildFileElement(file, callback) {
@@ -1490,6 +1516,15 @@
 
 		_file.appendChild(_name);
 
+		// token count (resolveTokenCount returns false on failure, so only render numbers)
+		if (typeof file.tokens === "number") {
+			const _tokens = make("div", "tokens");
+
+			_tokens.textContent = `~${new Intl.NumberFormat("en-US").format(file.tokens)} tokens`;
+
+			_file.appendChild(_tokens);
+		}
+
 		// remove button
 		const _remove = make("button", "remove");
 
@@ -1718,6 +1753,10 @@
 	});
 
 	$upload.addEventListener("click", async () => {
+		if (isUploading) {
+			return;
+		}
+
 		const files = await selectFile(
 			// the ultimate list
 			"text/*",
@@ -1744,9 +1783,29 @@
 			return;
 		}
 
+		isUploading = true;
+
+		$upload.classList.add("loading");
+
+		const promises = [];
+
+		for (const file of files) {
+			promises.push(
+				resolveTokenCount(file.content).then(tokens => {
+					file.tokens = tokens;
+				})
+			);
+		}
+
+		await Promise.all(promises);
+
 		for (const file of files) {
 			pushAttachment(file);
 		}
+
+		$upload.classList.remove("loading");
+
+		isUploading = false;
 	});
 
 	$add.addEventListener("click", () => {
diff --git a/tiktoken.go b/tiktoken.go
new file mode 100644
index 0000000..92e4754
--- /dev/null
+++ b/tiktoken.go
@@ -0,0 +1,150 @@
+package main
+
+import (
+	"bufio"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"net/http"
+	"strconv"
+	"strings"
+)
+
+const TikTokenSource = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
+
+type TreeNode struct {
+	TokenID  int
+	Children map[byte]*TreeNode
+}
+
+type Tokenizer struct {
+	Root *TreeNode
+}
+
+func NewTreeNode() *TreeNode {
+	return &TreeNode{
+		TokenID:  -1,
+		Children: make(map[byte]*TreeNode),
+	}
+}
+
+func (n *TreeNode) Insert(token []byte, id int) {
+	curr := n
+
+	for _, b := range token {
+		if _, ok := curr.Children[b]; !ok {
+			curr.Children[b] = NewTreeNode()
+		}
+
+		curr = curr.Children[b]
+	}
+
+	curr.TokenID = id
+}
+
+func LoadTokenizer(url string) (*Tokenizer, error) {
+	log.Println("Loading tokenizer...")
+
+	vocabulary, err := LoadVocabulary(url)
+	if err != nil {
+		return nil, err
+	}
+
+	root := NewTreeNode()
+
+	for tokenStr, id := range vocabulary {
+		root.Insert([]byte(tokenStr), id)
+	}
+
+	return &Tokenizer{
+		Root: root,
+	}, nil
+}
+
+func (t *Tokenizer) Encode(text string) []int {
+	var (
+		index  int
+		tokens []int
+	)
+
+	input := []byte(text)
+
+	for index < len(input) {
+		bestMatchLength := 0
+		bestMatchID := -1
+
+		currNode := t.Root
+
+		for i := index; i < len(input); i++ {
+			b := input[i]
+
+			childNode, exists := currNode.Children[b]
+			if !exists {
+				break
+			}
+
+			currNode = childNode
+
+			if currNode.TokenID != -1 {
+				bestMatchID = currNode.TokenID
+				bestMatchLength = (i - index) + 1
+			}
+		}
+
+		// should not be possible
+		if bestMatchLength == 0 {
+			bestMatchLength = 1
+		}
+
+		if bestMatchID != -1 {
+			tokens = append(tokens, bestMatchID)
+		}
+
+		index += bestMatchLength
+	}
+
+	return tokens
+}
+
+func LoadVocabulary(url string) (map[string]int, error) {
+	resp, err := http.Get(url)
+	if err != nil {
+		return nil, err
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, errors.New(resp.Status)
+	}
+
+	vocab := make(map[string]int)
+
+	scanner := bufio.NewScanner(resp.Body)
+
+	for scanner.Scan() {
+		parts := strings.SplitN(scanner.Text(), " ", 2)
+
+		if len(parts) != 2 {
+			continue
+		}
+
+		decoded, err := base64.StdEncoding.DecodeString(parts[0])
+		if err != nil {
+			return nil, fmt.Errorf("failed to decode token '%s': %w", parts[0], err)
+		}
+
+		id, err := strconv.Atoi(parts[1])
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse token ID '%s': %w", parts[1], err)
+		}
+
+		vocab[string(decoded)] = id
+	}
+
+	if err := scanner.Err(); err != nil {
+		return nil, err
+	}
+
+	return vocab, nil
+}
diff --git a/tokenize.go b/tokenize.go
new file mode 100644
index 0000000..66d6519
--- /dev/null
+++ b/tokenize.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+	"encoding/json"
+	"net/http"
+)
+
+type TokenizeRequest struct {
+	String string `json:"string"`
+}
+
+func HandleTokenize(tokenizer *Tokenizer) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		debug("parsing tokenize")
+
+		var raw TokenizeRequest
+
+		if err := json.NewDecoder(r.Body).Decode(&raw); err != nil {
+			RespondJson(w, http.StatusBadRequest, map[string]any{
+				"error": err.Error(),
+			})
+
+			return
+		}
+
+		tokens := tokenizer.Encode(raw.String)
+
+		RespondJson(w, http.StatusOK, map[string]any{
+			"tokens": len(tokens),
+		})
+	}
+}
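
Once the server is up and logging "Listening at http://localhost:3443/", the new route can be sanity-checked with a small standalone client along the lines of the sketch below. The sample string and the printed output are arbitrary; the request and response shapes follow TokenizeRequest and HandleTokenize above.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Same payload shape HandleTokenize decodes into TokenizeRequest: {"string": "..."}.
	payload, err := json.Marshal(map[string]string{"string": "Hello, tokenizer!"})
	if err != nil {
		panic(err)
	}

	// Address taken from the "Listening at http://localhost:3443/" log line in main.go.
	resp, err := http.Post("http://localhost:3443/-/tokenize", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// On success the handler replies with {"tokens": <count>}; on a bad body it replies with {"error": "..."}.
	var out struct {
		Tokens int    `json:"tokens"`
		Error  string `json:"error"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	if resp.StatusCode != http.StatusOK {
		panic(out.Error)
	}

	fmt.Printf("~%d tokens\n", out.Tokens)
}

The reported count is approximate by design: Encode walks the o200k_base vocabulary trie with greedy longest-prefix matching rather than performing full BPE merges, which is presumably why the chat UI renders it as "~N tokens".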