166 lines
5.0 KiB
Go
166 lines
5.0 KiB
Go
// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved.
|
|
// See LICENSE.txt for license information.
|
|
|
|
package llm
|
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"sort"

	"github.com/google/jsonschema-go/jsonschema"
)
|
|
|
|
// Tool represents a function that can be called by the language model during a conversation.
//
// Each tool has a name, description, and schema that defines its parameters. These are passed to the LLM for it to understand what capabilities it has.
// It is the Resolver function that implements the actual functionality.
//
// The Schema field should contain a JSONSchema that defines the expected structure of the tool's arguments.
// The Resolver function receives the conversation context and a way to access the parsed arguments,
// and returns either a result that will be passed to the LLM or an error.
type Tool struct {
	// Name identifies the tool. ToolStore registers tools keyed by Name, so a
	// later tool with the same Name replaces an earlier one.
	Name string
	// Description explains to the LLM what the tool does.
	Description string
	// Schema describes the expected structure of the tool's arguments.
	Schema *jsonschema.Schema
	// Resolver implements the tool; it is invoked by ToolStore.ResolveTool.
	Resolver ToolResolver
}
|
|
|
|
// ToolResolver carries out a tool invocation. It is given the conversation
// context and a getter that fills a caller-supplied value with the call's
// arguments, and it returns the text result for the LLM or an error.
type ToolResolver func(llmContext *Context, getArgs ToolArgumentGetter) (string, error)
|
|
|
|
// ToolCallStatus represents the current status of a tool call.
//
// The constants below are assigned with iota, so each value's number is fixed
// by its position in the list. NOTE(review): ToolCall.Status is a plain int
// field with a JSON tag, so these numbers are likely persisted or sent to
// clients — append new statuses at the end rather than reordering or inserting.
type ToolCallStatus int

const (
	// ToolCallStatusPending (0) indicates the tool is waiting for user approval/rejection.
	ToolCallStatusPending ToolCallStatus = iota
	// ToolCallStatusAccepted (1) indicates the user has accepted the tool call but it's not resolved yet.
	ToolCallStatusAccepted
	// ToolCallStatusRejected (2) indicates the user has rejected the tool call.
	ToolCallStatusRejected
	// ToolCallStatusError (3) indicates the tool call was accepted but errored during resolution.
	ToolCallStatusError
	// ToolCallStatusSuccess (4) indicates the tool call was accepted and resolved successfully.
	ToolCallStatusSuccess
)
|
|
|
|
// ToolCall represents a tool call. An empty result indicates that the tool has not yet been resolved.
//
// NOTE(review): a resolver may legitimately return an empty string, so an empty
// Result alone is ambiguous; Status distinguishes pending/accepted/rejected/
// error/success and is likely the more reliable indicator.
type ToolCall struct {
	// ID identifies this particular call (presumably assigned by the LLM provider — confirm).
	ID string `json:"id"`
	// Name is the name of the tool being invoked (presumably matches a Tool.Name).
	Name string `json:"name"`
	// Description is the description shown for the invoked tool.
	Description string `json:"description"`
	// Arguments holds the call's arguments as raw, undecoded JSON.
	Arguments json.RawMessage `json:"arguments"`
	// Result is the text produced by resolving the call; empty until resolved.
	Result string `json:"result"`
	// Status records the call's approval/resolution state.
	Status ToolCallStatus `json:"status"`
}
|
|
|
|
// ToolArgumentGetter fills dst with a tool call's arguments and reports any
// failure doing so. Callers pass a pointer; the store's trace helpers, for
// example, pass a *json.RawMessage to capture the raw argument JSON.
type ToolArgumentGetter func(dst any) error
|
|
|
|
// ToolAuthError represents an authentication error that occurred during tool
// creation.
type ToolAuthError struct {
	// ServerName identifies the tool server whose authentication failed.
	ServerName string `json:"server_name"`
	// AuthURL is presumably where the user can go to authorize access — confirm.
	AuthURL string `json:"auth_url"`
	// Error is the underlying failure. MarshalJSON renders it as its message
	// text (or null when nil).
	Error error `json:"error"`
}

// MarshalJSON encodes the auth error with Error rendered as its message string.
//
// encoding/json serializes an error interface value using the exported fields
// of its dynamic type. Common implementations (errors.New, fmt.Errorf) have
// none, so without this method the "error" key would always encode as {} and
// the message would be lost. A nil Error encodes as null.
func (e ToolAuthError) MarshalJSON() ([]byte, error) {
	var message *string
	if e.Error != nil {
		text := e.Error.Error()
		message = &text
	}

	// An anonymous struct mirrors the field names/tags above without
	// recursing back into this method.
	return json.Marshal(struct {
		ServerName string  `json:"server_name"`
		AuthURL    string  `json:"auth_url"`
		Error      *string `json:"error"`
	}{
		ServerName: e.ServerName,
		AuthURL:    e.AuthURL,
		Error:      message,
	})
}
|
|
|
|
// ToolStore holds the tools available to the LLM, keyed by name, together
// with any authentication errors collected during tool creation.
type ToolStore struct {
	// tools maps Tool.Name to the registered Tool.
	tools map[string]Tool
	// log receives trace output; it may be nil (NewNoTools leaves it nil).
	log TraceLog
	// doTrace enables TraceUnknown/TraceResolved output when log is non-nil.
	doTrace bool
	// authErrors accumulates the errors passed to AddAuthError.
	authErrors []ToolAuthError
}
|
|
|
|
// TraceLog is the logging dependency ToolStore uses for tracing: a single Info
// method taking a message followed by alternating key/value pairs.
type TraceLog interface {
	Info(msg string, kv ...any)
}
|
|
|
|
// NewJSONSchemaFromStruct creates a JSONSchema from a Go struct using generics
|
|
// It's a helper function for tool providers that currently define schemas as structs
|
|
func NewJSONSchemaFromStruct[T any]() *jsonschema.Schema {
|
|
schema, err := jsonschema.For[T](nil)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to create JSON schema from struct: %v", err))
|
|
}
|
|
|
|
return schema
|
|
}
|
|
|
|
func NewNoTools() *ToolStore {
|
|
return &ToolStore{
|
|
tools: make(map[string]Tool),
|
|
log: nil,
|
|
doTrace: false,
|
|
authErrors: []ToolAuthError{},
|
|
}
|
|
}
|
|
|
|
func NewToolStore(log TraceLog, doTrace bool) *ToolStore {
|
|
return &ToolStore{
|
|
tools: make(map[string]Tool),
|
|
log: log,
|
|
doTrace: doTrace,
|
|
authErrors: []ToolAuthError{},
|
|
}
|
|
}
|
|
|
|
func (s *ToolStore) AddTools(tools []Tool) {
|
|
for _, tool := range tools {
|
|
s.tools[tool.Name] = tool
|
|
}
|
|
}
|
|
|
|
func (s *ToolStore) ResolveTool(name string, argsGetter ToolArgumentGetter, context *Context) (string, error) {
|
|
tool, ok := s.tools[name]
|
|
if !ok {
|
|
s.TraceUnknown(name, argsGetter)
|
|
return "", errors.New("unknown tool " + name)
|
|
}
|
|
results, err := tool.Resolver(context, argsGetter)
|
|
s.TraceResolved(name, argsGetter, results, err)
|
|
return results, err
|
|
}
|
|
|
|
func (s *ToolStore) GetTools() []Tool {
|
|
result := make([]Tool, 0, len(s.tools))
|
|
for _, tool := range s.tools {
|
|
result = append(result, tool)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (s *ToolStore) TraceUnknown(name string, argsGetter ToolArgumentGetter) {
|
|
if s.log != nil && s.doTrace {
|
|
args := ""
|
|
var raw json.RawMessage
|
|
if err := argsGetter(&raw); err != nil {
|
|
args = fmt.Sprintf("failed to get tool args: %v", err)
|
|
} else {
|
|
args = string(raw)
|
|
}
|
|
s.log.Info("unknown tool called", "name", name, "args", args)
|
|
}
|
|
}
|
|
|
|
func (s *ToolStore) TraceResolved(name string, argsGetter ToolArgumentGetter, result string, err error) {
|
|
if s.log != nil && s.doTrace {
|
|
args := ""
|
|
var raw json.RawMessage
|
|
if getArgsErr := argsGetter(&raw); getArgsErr != nil {
|
|
args = fmt.Sprintf("failed to get tool args: %v", getArgsErr)
|
|
} else {
|
|
args = string(raw)
|
|
}
|
|
s.log.Info("tool resolved", "name", name, "args", args, "result", result, "error", err)
|
|
}
|
|
}
|
|
|
|
// AddAuthError adds an authentication error to the tool store
|
|
func (s *ToolStore) AddAuthError(authError ToolAuthError) {
|
|
s.authErrors = append(s.authErrors, authError)
|
|
}
|
|
|
|
// GetAuthErrors returns all authentication errors collected during tool creation
|
|
func (s *ToolStore) GetAuthErrors() []ToolAuthError {
|
|
return s.authErrors
|
|
}
|