package main

import (
	"bufio"
	"encoding/base64"
	"errors"
	"fmt"
	"log"
	"net/http"
	"strconv"
	"strings"
)

// TikTokenSource is OpenAI's published o200k_base vocabulary: one token per
// line, base64-encoded, followed by its integer token ID.
const TikTokenSource = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"

// TreeNode is a node in a byte-level prefix tree (trie) over the token
// vocabulary. TokenID is -1 unless the path from the root to this node
// spells a complete token.
type TreeNode struct {
	TokenID  int
	Children map[byte]*TreeNode
}

// Tokenizer wraps the root of the vocabulary trie.
type Tokenizer struct {
	Root *TreeNode
}

// NewTreeNode returns an empty node with no token assigned.
func NewTreeNode() *TreeNode {
	return &TreeNode{
		TokenID:  -1,
		Children: make(map[byte]*TreeNode),
	}
}

// Insert walks the trie byte by byte, creating nodes as needed, and marks
// the final node with the token's ID.
func (n *TreeNode) Insert(token []byte, id int) {
	curr := n

	for _, b := range token {
		if _, ok := curr.Children[b]; !ok {
			curr.Children[b] = NewTreeNode()
		}

		curr = curr.Children[b]
	}

	curr.TokenID = id
}
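
// trieSketch is a minimal illustration (not used by the program itself) of
// how Insert shares prefixes: after inserting "a" and "ab", both the first
// and second node along the same root path carry a token ID.
func trieSketch() {
	root := NewTreeNode()

	root.Insert([]byte("a"), 1)
	root.Insert([]byte("ab"), 2)

	fmt.Println(root.Children['a'].TokenID)               // 1
	fmt.Println(root.Children['a'].Children['b'].TokenID) // 2
}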

// LoadTokenizer downloads a tiktoken vocabulary and builds the lookup trie.
func LoadTokenizer(url string) (*Tokenizer, error) {
	log.Println("Loading tokenizer...")

	vocabulary, err := LoadVocabulary(url)
	if err != nil {
		return nil, err
	}

	root := NewTreeNode()

	for tokenStr, id := range vocabulary {
		root.Insert([]byte(tokenStr), id)
	}

	return &Tokenizer{
		Root: root,
	}, nil
}
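
// tokenizerUsageSketch outlines the intended call pattern as a minimal
// sketch (error handling shortened; actual call sites may differ).
func tokenizerUsageSketch() {
	tokenizer, err := LoadTokenizer(TikTokenSource)
	if err != nil {
		log.Fatal(err)
	}

	tokens := tokenizer.Encode("Hello, world!")

	fmt.Println(len(tokens)) // rough token count for the input
}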

// Encode tokenizes text with greedy longest-match: at each position it walks
// the trie as deep as the input allows and emits the longest token found.
// Greedy matching closely approximates full BPE merging, so the exact token
// sequence can occasionally differ from tiktoken's, but counts stay close.
func (t *Tokenizer) Encode(text string) []int {
	var (
		index  int
		tokens []int
	)

	input := []byte(text)

	for index < len(input) {
		bestMatchLength := 0
		bestMatchID := -1

		currNode := t.Root

		for i := index; i < len(input); i++ {
			b := input[i]

			childNode, exists := currNode.Children[b]
			if !exists {
				break
			}

			currNode = childNode

			if currNode.TokenID != -1 {
				bestMatchID = currNode.TokenID
				bestMatchLength = (i - index) + 1
			}
		}

		// the vocabulary contains a token for every single byte, so this
		// should not be possible; advance one byte to guarantee progress
		if bestMatchLength == 0 {
			bestMatchLength = 1
		}

		if bestMatchID != -1 {
			tokens = append(tokens, bestMatchID)
		}

		index += bestMatchLength
	}

	return tokens
}
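
// encodeTraceSketch demonstrates the longest-match behaviour on a tiny
// hand-built vocabulary (hypothetical token IDs, not o200k_base values):
// "ab" wins over "a"+"b" because the inner loop keeps the deepest match.
func encodeTraceSketch() {
	root := NewTreeNode()

	root.Insert([]byte("a"), 1)
	root.Insert([]byte("b"), 2)
	root.Insert([]byte("ab"), 3)

	tok := &Tokenizer{Root: root}

	fmt.Println(tok.Encode("aba")) // prints [3 1]
}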

// LoadVocabulary fetches a tiktoken vocabulary file, where each line is a
// base64-encoded token followed by a space and its integer token ID, and
// returns the decoded token-to-ID map.
func LoadVocabulary(url string) (map[string]int, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}

	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, errors.New(resp.Status)
	}

	vocab := make(map[string]int)

	scanner := bufio.NewScanner(resp.Body)

	for scanner.Scan() {
		parts := strings.SplitN(scanner.Text(), " ", 2)

		if len(parts) != 2 {
			continue
		}

		decoded, err := base64.StdEncoding.DecodeString(parts[0])
		if err != nil {
			return nil, fmt.Errorf("failed to decode token '%s': %w", parts[0], err)
		}

		id, err := strconv.Atoi(parts[1])
		if err != nil {
			return nil, fmt.Errorf("failed to parse token ID '%s': %w", parts[1], err)
		}

		vocab[string(decoded)] = id
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}

	return vocab, nil
}
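
// vocabularyLineSketch decodes one representative .tiktoken line by hand
// (the token ID is made up, not the real o200k_base value) to show exactly
// what LoadVocabulary does per line.
func vocabularyLineSketch() {
	line := "aGVsbG8= 24912" // base64("hello") followed by a hypothetical ID

	parts := strings.SplitN(line, " ", 2)

	decoded, _ := base64.StdEncoding.DecodeString(parts[0])
	id, _ := strconv.Atoi(parts[1])

	fmt.Println(string(decoded), id) // prints: hello 24912
}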