mattermost-community-enterp.../vendor/github.com/mattermost/mattermost-plugin-ai/public/bridgeclient/completion.go

201 lines
6.2 KiB
Go

// Copyright (c) 2023-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package bridgeclient
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/mattermost/mattermost-plugin-ai/llm"
)
// AgentCompletion sends a non-streaming completion request to the agent
// identified by agent and returns the completion text.
// agent should be the Mattermost Bot User ID (an immutable identifier).
func (c *Client) AgentCompletion(agent string, request CompletionRequest) (string, error) {
	endpoint := fmt.Sprintf("/%s/bridge/v1/completion/agent/%s/nostream", aiPluginID, agent)

	return c.doCompletionRequest(endpoint, request)
}
// ServiceCompletion sends a non-streaming completion request to the given
// service and returns the completion text.
// service may be either a service ID or a service name (e.g., "openai", "anthropic").
func (c *Client) ServiceCompletion(service string, request CompletionRequest) (string, error) {
	endpoint := fmt.Sprintf("/%s/bridge/v1/completion/service/%s/nostream", aiPluginID, service)

	return c.doCompletionRequest(endpoint, request)
}
// AgentCompletionStream sends a streaming completion request to the agent
// identified by agent.
// agent should be the Mattermost Bot User ID (an immutable identifier).
// On success the returned TextStreamResult exposes a Stream channel from
// which the caller reads the events.
func (c *Client) AgentCompletionStream(agent string, request CompletionRequest) (*llm.TextStreamResult, error) {
	endpoint := fmt.Sprintf("/%s/bridge/v1/completion/agent/%s", aiPluginID, agent)

	return c.doStreamingRequest(endpoint, request)
}
// ServiceCompletionStream sends a streaming completion request to the given
// service.
// service may be either a service ID or a service name (e.g., "openai", "anthropic").
// On success the returned TextStreamResult exposes a Stream channel from
// which the caller reads the events.
func (c *Client) ServiceCompletionStream(service string, request CompletionRequest) (*llm.TextStreamResult, error) {
	endpoint := fmt.Sprintf("/%s/bridge/v1/completion/service/%s", aiPluginID, service)

	return c.doStreamingRequest(endpoint, request)
}
// doCompletionRequest performs a non-streaming completion request.
//
// It JSON-encodes request, POSTs it to url through c.httpClient, reads the
// whole response body and, on HTTP 200, decodes it as a CompletionResponse
// and returns its Completion field.
//
// Any other status code yields an error that carries the status code plus the
// server-supplied error message when the body decodes as an ErrorResponse
// with a non-empty Error field, or the raw response body otherwise.
func (c *Client) doCompletionRequest(url string, request CompletionRequest) (string, error) {
	body, err := json.Marshal(request)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to execute request: %w", err)
	}
	defer resp.Body.Close()

	// The body is needed on both the success and the failure path, so read it
	// once up front.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		var errResp ErrorResponse
		// Fall back to the raw body both when it is not JSON and when it is
		// JSON without an "error" field; otherwise the message would end in
		// an empty string and the server's explanation would be lost.
		if err := json.Unmarshal(respBody, &errResp); err != nil || errResp.Error == "" {
			return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
		}
		return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, errResp.Error)
	}

	var completionResp CompletionResponse
	if err := json.Unmarshal(respBody, &completionResp); err != nil {
		return "", fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return completionResp.Completion, nil
}
// doStreamingRequest performs a streaming completion request and returns a
// TextStreamResult whose Stream channel delivers the decoded events.
//
// It JSON-encodes request and POSTs it to url through c.httpClient with an
// "Accept: text/event-stream" header. A non-200 status is returned as an
// error carrying the status code and either the server-supplied error
// message or the raw response body.
//
// On success a goroutine reads the response line by line, decodes the JSON
// payload of every "data:" line into an llm.TextStreamEvent and sends it on
// the channel. The goroutine stops, closes the channel and closes the
// response body when it has forwarded an end or error event, when a payload
// fails to decode, or when the response ends. Decode and read failures are
// reported as a final event of type llm.EventTypeError.
//
// The channel is unbuffered, so the caller must keep receiving until it is
// closed; otherwise the reader goroutine stays blocked on a send and the
// response body is never closed.
func (c *Client) doStreamingRequest(url string, request CompletionRequest) (*llm.TextStreamResult, error) {
	// Upper bound for a single SSE line. bufio.Scanner's default limit is
	// 64 KiB, which a single large event can exceed.
	const maxEventLineSize = 1 << 20

	body, err := json.Marshal(request)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to execute request: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Error path: the body is consumed here, so it is closed here. On the
		// success path ownership moves to the reader goroutine below.
		defer resp.Body.Close()

		respBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return nil, fmt.Errorf("request failed with status %d: failed to read response body: %w", resp.StatusCode, err)
		}
		var errResp ErrorResponse
		// Fall back to the raw body both when it is not JSON and when it is
		// JSON without an "error" field.
		if err := json.Unmarshal(respBody, &errResp); err != nil || errResp.Error == "" {
			return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(respBody))
		}
		return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, errResp.Error)
	}

	stream := make(chan llm.TextStreamEvent)

	go func() {
		// Deferred calls run in reverse order: the channel is closed first,
		// then the response body.
		defer resp.Body.Close()
		defer close(stream)

		scanner := bufio.NewScanner(resp.Body)
		scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxEventLineSize)

		for scanner.Scan() {
			line := scanner.Text()

			// Only "data" fields carry events; blank separator lines,
			// comments and other SSE fields are skipped.
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			// The single space after the colon is optional in the SSE format.
			data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
			if data == "" {
				continue
			}

			var event llm.TextStreamEvent
			if err := json.Unmarshal([]byte(data), &event); err != nil {
				stream <- llm.TextStreamEvent{
					Type:  llm.EventTypeError,
					Value: fmt.Errorf("error parsing stream event: %w", err),
				}
				return
			}

			stream <- event

			// End and error events are terminal: nothing after them is read.
			if event.Type == llm.EventTypeEnd || event.Type == llm.EventTypeError {
				return
			}
		}

		if err := scanner.Err(); err != nil {
			stream <- llm.TextStreamEvent{
				Type:  llm.EventTypeError,
				Value: fmt.Errorf("error reading stream: %w", err),
			}
		}
	}()

	return &llm.TextStreamResult{
		Stream: stream,
	}, nil
}