// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package llm

import (
	"encoding/json"
	"errors"
	"fmt"

	"github.com/mattermost/mattermost/server/public/shared/mlog"
)

// TokenUsageLoggingWrapper wraps a LanguageModel to log token usage
type TokenUsageLoggingWrapper struct {
	wrapped     LanguageModel
	botUsername string
	tokenLogger *mlog.Logger
}

// NewTokenUsageLoggingWrapper creates a new wrapper that logs token usage
func NewTokenUsageLoggingWrapper(wrapped LanguageModel, botUsername string, tokenLogger *mlog.Logger) *TokenUsageLoggingWrapper {
	return &TokenUsageLoggingWrapper{
		wrapped:     wrapped,
		botUsername: botUsername,
		tokenLogger: tokenLogger,
	}
}

// CreateTokenLogger creates a dedicated logger for token usage metrics
func CreateTokenLogger() (*mlog.Logger, error) {
	logger, err := mlog.NewLogger()
	if err != nil {
		return nil, fmt.Errorf("failed to create token logger: %w", err)
	}

	jsonTargetCfg := mlog.TargetCfg{
		Type:   "file",
		Format: "json",
		Levels: []mlog.Level{mlog.LvlInfo, mlog.LvlDebug},
	}

	jsonFileOptions := map[string]interface{}{
		"filename": "logs/agents/token_usage.log",
		"max_size": 100,  // MB
		"compress": true, // compress rotated files
	}
	jsonOptions, err := json.Marshal(jsonFileOptions)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal json file options: %w", err)
	}
	jsonTargetCfg.Options = json.RawMessage(jsonOptions)

	err = logger.ConfigureTargets(map[string]mlog.TargetCfg{
		"token_usage": jsonTargetCfg,
	}, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to configure token logger targets: %w", err)
	}

	return logger, nil
}

// ChatCompletion intercepts the streaming response to extract and log token usage
func (w *TokenUsageLoggingWrapper) ChatCompletion(request CompletionRequest, opts ...LanguageModelOption) (*TextStreamResult, error) {
	result, err := w.wrapped.ChatCompletion(request, opts...)
	if err != nil {
		return nil, err
	}

	if w.tokenLogger == nil {
		return nil, errors.New("token logger is nil")
	}

	interceptedStream := make(chan TextStreamEvent)
	go func() {
		defer close(interceptedStream)
		for event := range result.Stream {
			// Forward every non-usage event to the consumer unchanged.
			if event.Type != EventTypeUsage {
				interceptedStream <- event
				continue
			}

			usage, ok := event.Value.(TokenUsage)
			if !ok {
				continue
			}

			userID := "unknown"
			teamID := "unknown"
			if request.Context != nil {
				if request.Context.RequestingUser != nil {
					userID = request.Context.RequestingUser.Id
				}
				if request.Context.Team != nil {
					teamID = request.Context.Team.Id
				}
			}

			// Usage events are logged here and consumed; they are not forwarded downstream.
			w.tokenLogger.Info("Token Usage",
				mlog.String("user_id", userID),
				mlog.String("team_id", teamID),
				mlog.String("bot_username", w.botUsername),
				mlog.Int("input_tokens", usage.InputTokens),
				mlog.Int("output_tokens", usage.OutputTokens),
				mlog.Int("total_tokens", usage.InputTokens+usage.OutputTokens),
			)
		}
	}()

	return &TextStreamResult{Stream: interceptedStream}, nil
}

// ChatCompletionNoStream uses the streaming method internally, so token usage
// logging happens automatically when ReadAll() processes the intercepted stream
func (w *TokenUsageLoggingWrapper) ChatCompletionNoStream(request CompletionRequest, opts ...LanguageModelOption) (string, error) {
	result, err := w.ChatCompletion(request, opts...)
	if err != nil {
		return "", err
	}
	return result.ReadAll()
}

// CountTokens delegates to the wrapped model
func (w *TokenUsageLoggingWrapper) CountTokens(text string) int {
	return w.wrapped.CountTokens(text)
}

// InputTokenLimit delegates to the wrapped model
func (w *TokenUsageLoggingWrapper) InputTokenLimit() int {
	return w.wrapped.InputTokenLimit()
}