1
0
mirror of https://github.com/coalaura/whiskr.git synced 2025-12-02 20:22:52 +00:00
Files
whiskr/tiktoken.go

151 lines
2.4 KiB
Go
Raw Normal View History

2025-09-17 22:52:22 +02:00
package main
import (
	"bufio"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"
)
const TikTokenSource = "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken"
// TreeNode is a single node of the byte-trie used for tokenization.
// Each edge is labelled with one byte of a token's raw representation.
type TreeNode struct {
	// TokenID is the vocabulary ID of the token that ends at this node,
	// or -1 when no token terminates here (interior node).
	TokenID int
	// Children maps the next byte to the corresponding child node.
	Children map[byte]*TreeNode
}
// Tokenizer encodes text into token IDs via greedy longest-match
// lookups against a byte-trie of the vocabulary.
type Tokenizer struct {
	// Root is the root of the token trie; it carries no token itself.
	Root *TreeNode
}
// NewTreeNode returns an empty trie node with no terminating token
// (TokenID -1) and an initialized, empty child map.
func NewTreeNode() *TreeNode {
	node := new(TreeNode)
	node.TokenID = -1
	node.Children = map[byte]*TreeNode{}

	return node
}
// Insert walks the trie from n along the bytes of token, creating any
// missing nodes along the way, and marks the final node with the
// token's vocabulary ID.
func (n *TreeNode) Insert(token []byte, id int) {
	node := n

	for _, c := range token {
		child, ok := node.Children[c]
		if !ok {
			child = NewTreeNode()

			node.Children[c] = child
		}

		node = child
	}

	node.TokenID = id
}
// LoadTokenizer downloads the vocabulary at url and builds a Tokenizer
// whose trie contains every (token bytes -> ID) entry.
func LoadTokenizer(url string) (*Tokenizer, error) {
	// NOTE(review): `log` is not imported in this file's visible import
	// block — presumably a package-level logger defined elsewhere in the
	// package; confirm.
	log.Println("Loading tokenizer...")

	vocab, err := LoadVocabulary(url)
	if err != nil {
		return nil, err
	}

	trie := NewTreeNode()

	for token, id := range vocab {
		trie.Insert([]byte(token), id)
	}

	return &Tokenizer{Root: trie}, nil
}
// Encode converts text into a sequence of token IDs using greedy
// longest-match: at each position it walks the trie as far as the input
// allows and emits the ID of the longest token seen on that walk.
func (t *Tokenizer) Encode(text string) []int {
	var tokens []int

	data := []byte(text)
	pos := 0

	for pos < len(data) {
		node := t.Root

		matchID := -1
		matchLen := 0

		// Walk the trie from pos, remembering the deepest node that
		// terminates a token.
		for offset := pos; offset < len(data); offset++ {
			next, ok := node.Children[data[offset]]
			if !ok {
				break
			}

			node = next

			if node.TokenID != -1 {
				matchID = node.TokenID
				matchLen = offset - pos + 1
			}
		}

		// should not be possible: a complete byte-level vocabulary
		// matches every single byte, but advance by one so an unknown
		// byte cannot cause an infinite loop.
		if matchLen == 0 {
			matchLen = 1
		}

		if matchID != -1 {
			tokens = append(tokens, matchID)
		}

		pos += matchLen
	}

	return tokens
}
// LoadVocabulary fetches a .tiktoken vocabulary file from url and parses
// it into a map of raw token bytes (as a string) to token ID.
//
// The file format is one entry per line: "<base64-token> <decimal-id>".
// Malformed lines with fewer than two fields are skipped; lines whose
// base64 or ID fail to parse abort the load with an error.
func LoadVocabulary(url string) (map[string]int, error) {
	// Use a client with an explicit timeout: http.Get uses the default
	// client, which has none, so a stalled remote endpoint would hang
	// the caller forever.
	client := &http.Client{
		Timeout: 30 * time.Second,
	}

	resp, err := client.Get(url)
	if err != nil {
		return nil, err
	}

	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, errors.New(resp.Status)
	}

	vocab := make(map[string]int)

	scanner := bufio.NewScanner(resp.Body)

	for scanner.Scan() {
		parts := strings.SplitN(scanner.Text(), " ", 2)
		if len(parts) != 2 {
			// Skip blank or malformed lines rather than failing the load.
			continue
		}

		decoded, err := base64.StdEncoding.DecodeString(parts[0])
		if err != nil {
			return nil, fmt.Errorf("failed to decode token '%s': %w", parts[0], err)
		}

		id, err := strconv.Atoi(parts[1])
		if err != nil {
			return nil, fmt.Errorf("failed to parse token ID '%s': %w", parts[1], err)
		}

		vocab[string(decoded)] = id
	}

	// A scanner error (e.g. truncated body) must not be mistaken for EOF.
	if err := scanner.Err(); err != nil {
		return nil, err
	}

	return vocab, nil
}