Fix: Add cmd/mattermost directory (was ignored by .gitignore)

This commit is contained in:
Claude 2025-12-18 00:11:57 +09:00
parent ec1f89217a
commit 302a074c8e
486 changed files with 128967 additions and 2 deletions

4
.gitignore vendored
View File

@ -1,5 +1,5 @@
# Binaries # Binaries (root level only)
mattermost /mattermost
*.exe *.exe
# IDE # IDE

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,456 @@
{
"__elements": [],
"__requires": [
{
"type": "grafana",
"id": "grafana",
"name": "Grafana",
"version": "8.3.4"
},
{
"type": "panel",
"id": "graph",
"name": "Graph (old)",
"version": ""
},
{
"type": "datasource",
"id": "prometheus",
"name": "Prometheus",
"version": "1.0.0"
}
],
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": "-- Grafana --",
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"target": {
"limit": 100,
"matchAny": false,
"tags": [],
"type": "dashboard"
},
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"id": null,
"iteration": 1646759397232,
"links": [],
"liveNow": false,
"panels": [
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"description": "",
"fill": 1,
"fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 0,
"y": 0
},
"hiddenSeries": false,
"id": 2,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"nullPointMode": "null",
"options": {
"alertThreshold": true
},
"percentage": false,
"pluginVersion": "8.3.4",
"pointradius": 2,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"expr": "histogram_quantile (\n 0.99,\n sum by (le,instance)(\n rate(mattermost_db_store_time_bucket{instance=~\"$server\",method=\"ThreadStore.GetThreadsForUser\"}[5m])\n )\n)",
"interval": "",
"legendFormat": "p99-{{instance}}",
"refId": "A"
},
{
"expr": "histogram_quantile (\n 0.50,\n sum by (le,instance)(\n rate(mattermost_db_store_time_bucket{instance=~\"$server\",method=\"ThreadStore.GetThreadsForUser\"}[5m])\n )\n)",
"interval": "",
"legendFormat": "p50-{{instance}}",
"refId": "B"
}
],
"thresholds": [],
"timeRegions": [],
"title": "GetThreadsForUser duration",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"type": "graph",
"xaxis": {
"mode": "time",
"show": true,
"values": []
},
"yaxes": [
{
"$$hashKey": "object:98",
"format": "s",
"logBase": 1,
"show": true
},
{
"$$hashKey": "object:99",
"format": "short",
"logBase": 1,
"show": true
}
],
"yaxis": {
"align": false
}
},
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"fill": 1,
"fillGradient": 0,
"gridPos": {
"h": 9,
"w": 12,
"x": 12,
"y": 0
},
"hiddenSeries": false,
"id": 4,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"nullPointMode": "null",
"options": {
"alertThreshold": true
},
"percentage": false,
"pluginVersion": "8.3.4",
"pointradius": 2,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"expr": "sum(rate(mattermost_db_store_time_count{instance=~\"$server\",method=\"ThreadStore.GetThreadsForUser\"}[1m])) by (instance)",
"interval": "",
"legendFormat": "count-{{instance}}",
"refId": "A"
},
{
"expr": "sum(rate(mattermost_db_store_time_count{instance=~\"$server\",method=\"ThreadStore.GetThreadsForUser\"}[1m]))",
"interval": "",
"legendFormat": "Total",
"refId": "B"
}
],
"thresholds": [],
"timeRegions": [],
"title": "GetThreadsForUser Requests Per Second",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"type": "graph",
"xaxis": {
"mode": "time",
"show": true,
"values": []
},
"yaxes": [
{
"$$hashKey": "object:396",
"format": "short",
"logBase": 1,
"show": true
},
{
"$$hashKey": "object:397",
"format": "short",
"logBase": 1,
"show": true
}
],
"yaxis": {
"align": false
}
},
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"fill": 1,
"fillGradient": 0,
"gridPos": {
"h": 8,
"w": 12,
"x": 0,
"y": 9
},
"hiddenSeries": false,
"id": 6,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"nullPointMode": "null",
"options": {
"alertThreshold": true
},
"percentage": false,
"pluginVersion": "8.3.4",
"pointradius": 2,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"exemplar": true,
"expr": "histogram_quantile (\n 0.99,\n sum by (le,instance)(\n rate(mattermost_db_store_time_bucket{instance=~\"$server\",method=\"ThreadStore.MarkAllAsReadByChannels\"}[5m])\n )\n)",
"interval": "",
"legendFormat": "p99-{{instance}}",
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"exemplar": true,
"expr": "histogram_quantile (\n 0.50,\n sum by (le,instance)(\n rate(mattermost_db_store_time_bucket{instance=~\"$server\",method=\"ThreadStore.MarkAllAsReadByChannels\"}[5m])\n )\n)",
"interval": "",
"legendFormat": "p50-{{instance}}",
"refId": "B"
}
],
"thresholds": [],
"timeRegions": [],
"title": "MarkAllAsReadByChannels duration",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"type": "graph",
"xaxis": {
"mode": "time",
"show": true,
"values": []
},
"yaxes": [
{
"$$hashKey": "object:504",
"format": "s",
"logBase": 1,
"show": true
},
{
"$$hashKey": "object:505",
"format": "short",
"logBase": 1,
"show": true
}
],
"yaxis": {
"align": false
}
},
{
"aliasColors": {},
"bars": false,
"dashLength": 10,
"dashes": false,
"fill": 1,
"fillGradient": 0,
"gridPos": {
"h": 8,
"w": 12,
"x": 12,
"y": 9
},
"hiddenSeries": false,
"id": 8,
"legend": {
"avg": false,
"current": false,
"max": false,
"min": false,
"show": true,
"total": false,
"values": false
},
"lines": true,
"linewidth": 1,
"nullPointMode": "null",
"options": {
"alertThreshold": true
},
"percentage": false,
"pluginVersion": "8.3.4",
"pointradius": 2,
"points": false,
"renderer": "flot",
"seriesOverrides": [],
"spaceLength": 10,
"stack": false,
"steppedLine": false,
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"exemplar": true,
"expr": "sum(rate(mattermost_db_store_time_count{instance=~\"$server\",method=\"ThreadStore.MarkAllAsReadByChannels\"}[1m])) by (instance)",
"interval": "",
"legendFormat": "count-{{instance}}",
"refId": "A"
},
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"exemplar": true,
"expr": "sum(rate(mattermost_db_store_time_count{instance=~\"$server\",method=\"ThreadStore.MarkAllAsReadByChannels\"}[1m]))",
"interval": "",
"legendFormat": "Total",
"refId": "B"
}
],
"thresholds": [],
"timeRegions": [],
"title": "MarkAllAsReadByChannels Requests Per Second",
"tooltip": {
"shared": true,
"sort": 0,
"value_type": "individual"
},
"type": "graph",
"xaxis": {
"mode": "time",
"show": true,
"values": []
},
"yaxes": [
{
"$$hashKey": "object:714",
"format": "short",
"logBase": 1,
"show": true
},
{
"$$hashKey": "object:715",
"format": "short",
"logBase": 1,
"show": true
}
],
"yaxis": {
"align": false
}
}
],
"refresh": "10s",
"schemaVersion": 34,
"style": "dark",
"tags": [],
"templating": {
"list": [
{
"current": {},
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"definition": "label_values(instance)",
"hide": 0,
"includeAll": true,
"label": "server",
"multi": true,
"name": "server",
"options": [],
"query": {
"query": "label_values(instance)",
"refId": "Prometheus-server-Variable-Query"
},
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"sort": 0,
"tagValuesQuery": "",
"tagsQuery": "",
"type": "query",
"useTags": false
}
]
},
"time": {
"from": "now-3h",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "Collapsed Reply Threads Performance",
"uid": "cZY9yFJ7z",
"version": 5,
"weekStart": ""
}

View File

@ -0,0 +1,412 @@
{
"__elements": {},
"__requires": [
{
"type": "grafana",
"id": "grafana",
"name": "Grafana",
"version": "10.4.2"
},
{
"type": "datasource",
"id": "prometheus",
"name": "Prometheus",
"version": "1.0.0"
},
{
"type": "panel",
"id": "timeseries",
"name": "Time series",
"version": ""
}
],
"annotations": {
"list": [
{
"builtIn": 1,
"datasource": {
"type": "grafana",
"uid": "-- Grafana --"
},
"enable": true,
"hide": true,
"iconColor": "rgba(0, 211, 255, 1)",
"name": "Annotations & Alerts",
"type": "dashboard"
}
]
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 0,
"id": null,
"links": [],
"panels": [
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"description": "The average amount of a machine's CPU used by the given Desktop App process over the measuring interval (usually 1 minute). ",
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"id": 1,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile($percentile / 100, sum by(le) (increase(mattermost_desktopapp_cpu_usage_bucket{processName=~\"$processName\",version=~\"$version\",platform=~\"$platform\"}[$decay])))",
"instant": false,
"legendFormat": "[[percentile]]th Percentile",
"range": true,
"refId": "A"
}
],
"title": "CPU Usage (%)",
"type": "timeseries"
},
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"description": "The number of megabytes used by a Desktop App process at the time of measurement.",
"fieldConfig": {
"defaults": {
"color": {
"mode": "palette-classic-by-name"
},
"custom": {
"axisBorderShow": false,
"axisCenteredZero": false,
"axisColorMode": "text",
"axisLabel": "",
"axisPlacement": "auto",
"barAlignment": 0,
"drawStyle": "line",
"fillOpacity": 0,
"gradientMode": "none",
"hideFrom": {
"legend": false,
"tooltip": false,
"viz": false
},
"insertNulls": false,
"lineInterpolation": "linear",
"lineWidth": 1,
"pointSize": 5,
"scaleDistribution": {
"type": "linear"
},
"showPoints": "auto",
"spanNulls": false,
"stacking": {
"group": "A",
"mode": "none"
},
"thresholdsStyle": {
"mode": "off"
}
},
"mappings": [],
"thresholds": {
"mode": "absolute",
"steps": [
{
"color": "green",
"value": null
},
{
"color": "red",
"value": 80
}
]
}
},
"overrides": []
},
"gridPos": {
"h": 9,
"w": 24,
"x": 0,
"y": 10
},
"id": 2,
"options": {
"legend": {
"calcs": [],
"displayMode": "list",
"placement": "bottom",
"showLegend": true
},
"tooltip": {
"mode": "single",
"sort": "none"
}
},
"targets": [
{
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"editorMode": "code",
"expr": "histogram_quantile($percentile / 100, sum by(le) (increase(mattermost_desktopapp_memory_usage_bucket{processName=~\"$processName\",version=~\"$version\",platform=~\"$platform\"}[$decay])))",
"instant": false,
"legendFormat": "[[percentile]]th Percentile",
"range": true,
"refId": "A"
}
],
"title": "Memory Usage (MBs)",
"type": "timeseries"
}
],
"refresh": "1m",
"schemaVersion": 39,
"tags": [],
"templating": {
"list": [
{
"current": {},
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"definition": "label_values(mattermost_desktopapp_cpu_usage_bucket,processName)",
"hide": 0,
"includeAll": true,
"label": "Process",
"multi": true,
"name": "processName",
"options": [],
"query": {
"qryType": 1,
"query": "label_values(mattermost_desktopapp_cpu_usage_bucket,processName)",
"refId": "PrometheusVariableQueryEditor-VariableQuery"
},
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"sort": 0,
"type": "query"
},
{
"current": {},
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"definition": "label_values(mattermost_desktopapp_cpu_usage_bucket,version)",
"hide": 0,
"includeAll": true,
"label": "App Version",
"multi": true,
"name": "version",
"options": [],
"query": {
"qryType": 1,
"query": "label_values(mattermost_desktopapp_cpu_usage_bucket,version)",
"refId": "PrometheusVariableQueryEditor-VariableQuery"
},
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"sort": 0,
"type": "query"
},
{
"current": {},
"datasource": {
"type": "prometheus",
"uid": "Prometheus"
},
"definition": "label_values(mattermost_desktopapp_cpu_usage_bucket,platform)",
"hide": 0,
"includeAll": true,
"multi": true,
"name": "platform",
"options": [],
"query": {
"qryType": 1,
"query": "label_values(mattermost_desktopapp_cpu_usage_bucket,platform)",
"refId": "PrometheusVariableQueryEditor-VariableQuery"
},
"refresh": 1,
"regex": "",
"skipUrlSync": false,
"sort": 0,
"type": "query"
},
{
"current": {
"selected": true,
"text": "50",
"value": "50"
},
"hide": 0,
"includeAll": false,
"label": "Percentile",
"multi": false,
"name": "percentile",
"options": [
{
"selected": true,
"text": "50",
"value": "50"
},
{
"selected": false,
"text": "75",
"value": "75"
},
{
"selected": false,
"text": "90",
"value": "90"
},
{
"selected": false,
"text": "99",
"value": "99"
}
],
"query": "50,75,90,99",
"queryValue": "",
"skipUrlSync": false,
"type": "custom"
},
{
"current": {
"selected": false,
"text": "30m",
"value": "30m"
},
"hide": 0,
"includeAll": false,
"label": "Decay Time",
"multi": false,
"name": "decay",
"options": [
{
"selected": true,
"text": "30m",
"value": "30m"
},
{
"selected": false,
"text": "1h",
"value": "1h"
},
{
"selected": false,
"text": "3h",
"value": "3h"
},
{
"selected": false,
"text": "6h",
"value": "6h"
},
{
"selected": false,
"text": "12h",
"value": "12h"
},
{
"selected": false,
"text": "1d",
"value": "1d"
}
],
"query": "30m,1h,3h,6h,12h,1d",
"queryValue": "",
"skipUrlSync": false,
"type": "custom"
}
]
},
"time": {
"from": "now-24h",
"to": "now"
},
"timepicker": {},
"timezone": "browser",
"title": "Desktop App Metrics",
"uid": "fe12lkd7062v4a",
"version": 4,
"weekStart": ""
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,226 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
"testing"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/v8/channels/api4"
"github.com/mattermost/mattermost/server/v8/channels/store/storetest/mocks"
"github.com/mattermost/mattermost/server/v8/channels/testlib"
)
var coverprofileCounters = make(map[string]int)
var mainHelper *testlib.MainHelper
type testHelper struct {
*api4.TestHelper
config *model.Config
tempDir string
configFilePath string
disableAutoConfig bool
}
// Setup creates an instance of testHelper.
func Setup(tb testing.TB) *testHelper {
dir, err := testlib.SetupTestResources()
if err != nil {
panic("failed to create temporary directory: " + err.Error())
}
api4TestHelper := api4.Setup(tb)
testHelper := &testHelper{
TestHelper: api4TestHelper,
tempDir: dir,
configFilePath: filepath.Join(dir, "config-helper.json"),
}
config := &model.Config{}
config.SetDefaults()
testHelper.SetConfig(config)
return testHelper
}
// Setup creates an instance of testHelper.
func SetupWithStoreMock(tb testing.TB) *testHelper {
dir, err := testlib.SetupTestResources()
if err != nil {
panic("failed to create temporary directory: " + err.Error())
}
api4TestHelper := api4.SetupWithStoreMock(tb)
systemStore := mocks.SystemStore{}
systemStore.On("Get").Return(make(model.StringMap), nil)
licenseStore := mocks.LicenseStore{}
licenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil)
api4TestHelper.App.Srv().Store().(*mocks.Store).On("System").Return(&systemStore)
api4TestHelper.App.Srv().Store().(*mocks.Store).On("License").Return(&licenseStore)
testHelper := &testHelper{
TestHelper: api4TestHelper,
tempDir: dir,
configFilePath: filepath.Join(dir, "config-helper.json"),
}
config := &model.Config{}
config.SetDefaults()
testHelper.SetConfig(config)
return testHelper
}
// InitBasic simply proxies to api4.InitBasic, while still returning a testHelper.
func (h *testHelper) InitBasic() *testHelper {
h.TestHelper.InitBasic()
return h
}
// TemporaryDirectory returns the temporary directory created for user by the test helper.
func (h *testHelper) TemporaryDirectory() string {
return h.tempDir
}
// Config returns the configuration passed to a running command.
func (h *testHelper) Config() *model.Config {
return h.config.Clone()
}
// ConfigPath returns the path to the temporary config file passed to a running command.
func (h *testHelper) ConfigPath() string {
return h.configFilePath
}
// SetConfig replaces the configuration passed to a running command.
func (h *testHelper) SetConfig(config *model.Config) {
if !testing.Short() {
config.SqlSettings = *mainHelper.GetSQLSettings()
}
// Disable strict password requirements for test
*config.PasswordSettings.MinimumLength = 5
*config.PasswordSettings.Lowercase = false
*config.PasswordSettings.Uppercase = false
*config.PasswordSettings.Symbol = false
*config.PasswordSettings.Number = false
h.config = config
buf, err := json.Marshal(config)
if err != nil {
panic("failed to marshal config: " + err.Error())
}
if err := os.WriteFile(h.configFilePath, buf, 0600); err != nil {
panic("failed to write file " + h.configFilePath + ": " + err.Error())
}
}
// SetAutoConfig configures whether the --config flag is automatically passed to a running command.
func (h *testHelper) SetAutoConfig(autoConfig bool) {
h.disableAutoConfig = !autoConfig
}
// TearDown cleans up temporary files and assets created during the life of the test helper.
func (h *testHelper) TearDown() {
h.TestHelper.TearDown()
os.RemoveAll(h.tempDir)
}
func (h *testHelper) execArgs(t *testing.T, args []string) []string {
ret := []string{"-test.v", "-test.run", "ExecCommand"}
if coverprofile := flag.Lookup("test.coverprofile").Value.String(); coverprofile != "" {
dir := filepath.Dir(coverprofile)
base := filepath.Base(coverprofile)
baseParts := strings.SplitN(base, ".", 2)
name := strings.Replace(t.Name(), "/", "_", -1)
coverprofileCounters[name] = coverprofileCounters[name] + 1
baseParts[0] = fmt.Sprintf("%v-%v-%v", baseParts[0], name, coverprofileCounters[name])
ret = append(ret, "-test.coverprofile", filepath.Join(dir, strings.Join(baseParts, ".")))
}
ret = append(ret, "--")
// Unless the test passes a `--config` of its own, create a temporary one from the default
// configuration with the current test database applied.
hasConfig := h.disableAutoConfig
if slices.Contains(args, "--config") {
hasConfig = true
}
if !hasConfig {
ret = append(ret, "--config", h.configFilePath)
}
ret = append(ret, args...)
return ret
}
func (h *testHelper) cmd(t *testing.T, args []string) *exec.Cmd {
path, err := os.Executable()
require.NoError(t, err)
cmd := exec.Command(path, h.execArgs(t, args)...)
cmd.Env = []string{}
for _, env := range os.Environ() {
// Ignore MM_SQLSETTINGS_DATASOURCE from the environment, since we override.
if strings.HasPrefix(env, "MM_SQLSETTINGS_DATASOURCE=") {
continue
}
cmd.Env = append(cmd.Env, env)
}
return cmd
}
// CheckCommand invokes the test binary, returning the output modified for assertion testing.
func (h *testHelper) CheckCommand(t *testing.T, args ...string) string {
output, err := h.cmd(t, args).CombinedOutput()
require.NoError(t, err, string(output))
return strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(string(output)), "PASS"))
}
// RunCommand invokes the test binary, returning only any error.
func (h *testHelper) RunCommand(t *testing.T, args ...string) error {
return h.cmd(t, args).Run()
}
// RunCommandWithOutput is a variant of RunCommand that returns the unmodified output and any error.
func (h *testHelper) RunCommandWithOutput(t *testing.T, args ...string) (string, error) {
cmd := h.cmd(t, args)
var buf bytes.Buffer
reader, writer := io.Pipe()
cmd.Stdout = writer
cmd.Stderr = writer
done := make(chan bool)
go func() {
io.Copy(&buf, reader)
close(done)
}()
err := cmd.Run()
writer.Close()
<-done
return buf.String(), err
}

View File

@ -0,0 +1,347 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
"github.com/mattermost/mattermost/server/v8/config"
"github.com/mattermost/mattermost/server/v8/platform/shared/filestore"
"github.com/mattermost/morph"
"github.com/mattermost/morph/models"
)
var DbCmd = &cobra.Command{
Use: "db",
Short: "Commands related to the database",
}
var InitDbCmd = &cobra.Command{
Use: "init",
Short: "Initialize the database",
Long: `Initialize the database for a given DSN, executing the migrations and loading the custom defaults if any.
This command should be run using a database configuration DSN.`,
Example: ` # you can use the config flag to pass the DSN
$ mattermost db init --config postgres://localhost/mattermost
# or you can use the MM_CONFIG environment variable
$ MM_CONFIG=postgres://localhost/mattermost mattermost db init
# and you can set a custom defaults file to be loaded into the database
$ MM_CUSTOM_DEFAULTS_PATH=custom.json MM_CONFIG=postgres://localhost/mattermost mattermost db init`,
Args: cobra.NoArgs,
RunE: initDbCmdF,
}
var ResetCmd = &cobra.Command{
Use: "reset",
Short: "Reset the database to initial state",
Long: "Completely erases the database causing the loss of all data. This will reset Mattermost to its initial state.",
RunE: resetCmdF,
}
var MigrateCmd = &cobra.Command{
Use: "migrate",
Short: "Migrate the database if there are any unapplied migrations",
Long: "Run the missing migrations from the migrations table.",
RunE: migrateCmdF,
}
var DowngradeCmd = &cobra.Command{
Use: "downgrade",
Short: "Downgrade the database with the given plan or migration numbers",
Long: "Downgrade the database with the given plan or migration numbers. " +
"The plan will be read from filestore hence the path should be relative to file store root.",
RunE: downgradeCmdF,
Args: cobra.ExactArgs(1),
}
var DBVersionCmd = &cobra.Command{
Use: "version",
Short: "Returns the recent applied version number",
RunE: dbVersionCmdF,
}
func init() {
ResetCmd.Flags().Bool("confirm", false, "Confirm you really want to delete everything and a DB backup has been performed.")
DBVersionCmd.Flags().Bool("all", false, "Returns all applied migrations")
MigrateCmd.Flags().Bool("auto-recover", false, "Recover the database to it's existing state after a failed migration.")
MigrateCmd.Flags().Bool("save-plan", false, "Saves the migration plan into file store so that it can be used in the future.")
MigrateCmd.Flags().Bool("dry-run", false, "Runs the migration plan without applying it.")
DowngradeCmd.Flags().Bool("auto-recover", false, "Recover the database to it's existing state after a failed migration.")
DowngradeCmd.Flags().Bool("dry-run", false, "Runs the migration plan without applying it.")
DbCmd.AddCommand(
InitDbCmd,
ResetCmd,
MigrateCmd,
DowngradeCmd,
DBVersionCmd,
)
RootCmd.AddCommand(
DbCmd,
)
}
func initDbCmdF(command *cobra.Command, _ []string) error {
logger := mlog.CreateConsoleLogger()
dsn := getConfigDSN(command, config.GetEnvironment())
if !config.IsDatabaseDSN(dsn) {
return errors.New("this command should be run using a database configuration DSN")
}
customDefaults, err := loadCustomDefaults()
if err != nil {
return errors.Wrap(err, "error loading custom configuration defaults")
}
configStore, err := config.NewStoreFromDSN(getConfigDSN(command, config.GetEnvironment()), false, customDefaults, true)
if err != nil {
return errors.Wrap(err, "failed to load configuration")
}
defer configStore.Close()
sqlStore, err := sqlstore.New(configStore.Get().SqlSettings, logger, nil)
if err != nil {
return errors.Wrap(err, "failed to initialize store")
}
defer sqlStore.Close()
CommandPrettyPrintln("Database store correctly initialised")
return nil
}
func resetCmdF(command *cobra.Command, args []string) error {
logger := mlog.CreateConsoleLogger()
ss, err := initStoreCommandContextCobra(logger, command)
if err != nil {
return errors.Wrap(err, "could not initialize store")
}
defer ss.Close()
confirmFlag, _ := command.Flags().GetBool("confirm")
if !confirmFlag {
var confirm string
CommandPrettyPrintln("Have you performed a database backup? (YES/NO): ")
fmt.Scanln(&confirm)
if confirm != "YES" {
return errors.New("ABORTED: You did not answer YES exactly, in all capitals.")
}
CommandPrettyPrintln("Are you sure you want to delete everything? All data will be permanently deleted? (YES/NO): ")
fmt.Scanln(&confirm)
if confirm != "YES" {
return errors.New("ABORTED: You did not answer YES exactly, in all capitals.")
}
}
ss.DropAllTables()
CommandPrettyPrintln("Database successfully reset")
return nil
}
func migrateCmdF(command *cobra.Command, args []string) error {
logger := mlog.CreateConsoleLogger()
defer logger.Shutdown()
cfgDSN := getConfigDSN(command, config.GetEnvironment())
recoverFlag, _ := command.Flags().GetBool("auto-recover")
savePlan, _ := command.Flags().GetBool("save-plan")
dryRun, _ := command.Flags().GetBool("dry-run")
cfgStore, err := config.NewStoreFromDSN(cfgDSN, true, nil, true)
if err != nil {
return errors.Wrap(err, "failed to load configuration")
}
config := cfgStore.Get()
migrator, err := sqlstore.NewMigrator(config.SqlSettings, logger, dryRun)
if err != nil {
return errors.Wrap(err, "failed to create migrator")
}
defer migrator.Close()
plan, err := migrator.GeneratePlan(recoverFlag)
if err != nil {
return errors.Wrap(err, "failed to generate migration plan")
}
if len(plan.Migrations) == 0 {
CommandPrettyPrintln("No migrations to apply.")
return nil
}
if savePlan || recoverFlag {
backend, err2 := filestore.NewFileBackend(ConfigToFileBackendSettings(&config.FileSettings, false, true))
if err2 != nil {
return fmt.Errorf("failed to initialize filebackend: %w", err2)
}
b, mErr := json.MarshalIndent(plan, "", " ")
if mErr != nil {
return fmt.Errorf("failed to marshal plan: %w", mErr)
}
fileName, err2 := migrator.GetFileName(plan)
if err2 != nil {
return fmt.Errorf("failed to generate plan file: %w", err2)
}
_, err = backend.WriteFile(bytes.NewReader(b), fileName+".json")
if err != nil {
return fmt.Errorf("failed to write migration plan: %w", err)
}
CommandPrettyPrintln(
fmt.Sprintf("%s\nThe migration plan has been saved. File: %q.\nNote that "+
" migration plan is saved into file store, so the filepath will be relative to root of file store\n%s",
strings.Repeat("*", 80), fileName+".json", strings.Repeat("*", 80)))
}
err = migrator.MigrateWithPlan(plan, dryRun)
if err != nil {
return errors.Wrap(err, "failed to migrate with the plan")
}
CommandPrettyPrintln("Database successfully migrated")
return nil
}
func downgradeCmdF(command *cobra.Command, args []string) error {
logger := mlog.CreateConsoleLogger()
defer logger.Shutdown()
cfgDSN := getConfigDSN(command, config.GetEnvironment())
cfgStore, err := config.NewStoreFromDSN(cfgDSN, true, nil, true)
if err != nil {
return errors.Wrap(err, "failed to load configuration")
}
config := cfgStore.Get()
dryRun, _ := command.Flags().GetBool("dry-run")
recoverFlag, _ := command.Flags().GetBool("auto-recover")
backend, err2 := filestore.NewFileBackend(ConfigToFileBackendSettings(&config.FileSettings, false, true))
if err2 != nil {
return fmt.Errorf("failed to initialize filebackend: %w", err2)
}
migrator, err := sqlstore.NewMigrator(config.SqlSettings, logger, dryRun)
if err != nil {
return errors.Wrap(err, "failed to create migrator")
}
defer migrator.Close()
// check if the input is version numbers or a file
// if the input is given as a file, we assume it's a migration plan
versions := strings.Split(args[0], ",")
if _, sErr := strconv.Atoi(versions[0]); sErr == nil {
CommandPrettyPrintln("Database will be downgraded with the following versions: ", versions)
err = migrator.DowngradeMigrations(dryRun, versions...)
if err != nil {
return errors.Wrap(err, "failed to downgrade migrations")
}
CommandPrettyPrintln("Database successfully downgraded")
return nil
}
b, err := backend.ReadFile(args[0])
if err != nil {
return fmt.Errorf("failed to read plan: %w", err)
}
var plan models.Plan
err = json.Unmarshal(b, &plan)
if err != nil {
return fmt.Errorf("failed to unmarshal plan: %w", err)
}
morph.SwapPlanDirection(&plan)
plan.Auto = recoverFlag
err = migrator.MigrateWithPlan(&plan, dryRun)
if err != nil {
return errors.Wrap(err, "failed to migrate with the plan")
}
CommandPrettyPrintln("Database successfully downgraded")
return nil
}
func dbVersionCmdF(command *cobra.Command, args []string) error {
logger := mlog.CreateConsoleLogger()
defer logger.Shutdown()
ss, err := initStoreCommandContextCobra(logger, command)
if err != nil {
return errors.Wrap(err, "could not initialize store")
}
defer ss.Close()
allFlag, _ := command.Flags().GetBool("all")
if allFlag {
applied, err2 := ss.GetAppliedMigrations()
if err2 != nil {
return errors.Wrap(err2, "failed to get applied migrations")
}
for _, migration := range applied {
CommandPrettyPrintln(fmt.Sprintf("Varsion: %d, Name: %s", migration.Version, migration.Name))
}
return nil
}
v, err := ss.GetDBSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get schema version")
}
CommandPrettyPrintln("Current database schema version is: " + strconv.Itoa(v))
return nil
}
func ConfigToFileBackendSettings(s *model.FileSettings, enableComplianceFeature bool, skipVerify bool) filestore.FileBackendSettings {
if *s.DriverName == model.ImageDriverLocal {
return filestore.FileBackendSettings{
DriverName: *s.DriverName,
Directory: *s.Directory,
}
}
return filestore.FileBackendSettings{
DriverName: *s.DriverName,
AmazonS3AccessKeyId: *s.AmazonS3AccessKeyId,
AmazonS3SecretAccessKey: *s.AmazonS3SecretAccessKey,
AmazonS3Bucket: *s.AmazonS3Bucket,
AmazonS3PathPrefix: *s.AmazonS3PathPrefix,
AmazonS3Region: *s.AmazonS3Region,
AmazonS3Endpoint: *s.AmazonS3Endpoint,
AmazonS3SSL: s.AmazonS3SSL == nil || *s.AmazonS3SSL,
AmazonS3SignV2: s.AmazonS3SignV2 != nil && *s.AmazonS3SignV2,
AmazonS3SSE: s.AmazonS3SSE != nil && *s.AmazonS3SSE && enableComplianceFeature,
AmazonS3Trace: s.AmazonS3Trace != nil && *s.AmazonS3Trace,
AmazonS3RequestTimeoutMilliseconds: *s.AmazonS3RequestTimeoutMilliseconds,
SkipVerify: skipVerify,
}
}

View File

@ -0,0 +1,19 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"flag"
"testing"
"github.com/stretchr/testify/require"
)
func TestExecCommand(t *testing.T) {
if filter := flag.Lookup("test.run").Value.String(); filter != "ExecCommand" {
t.Skip("use -run ExecCommand to execute a command via the test executable")
}
RootCmd.SetArgs(flag.Args())
require.NoError(t, RootCmd.Execute())
}

View File

@ -0,0 +1,187 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"context"
"os"
"path/filepath"
"time"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/app"
"github.com/pkg/errors"
"github.com/spf13/cobra"
)
var ExportCmd = &cobra.Command{
Use: "export",
Short: "Export data from Mattermost",
Long: "Export data from Mattermost in a format suitable for import into a third-party application or another Mattermost instance",
}
var ScheduleExportCmd = &cobra.Command{
Use: "schedule",
Short: "Schedule an export data job in Mattermost",
Long: "Schedule an export data job in Mattermost (this will run asynchronously via a background worker)",
Example: "export schedule --format=actiance --exportFrom=12345 --timeoutSeconds=12345",
RunE: scheduleExportCmdF,
}
var BulkExportCmd = &cobra.Command{
Use: "bulk [file]",
Short: "Export bulk data.",
Long: "Export data to a file compatible with the Mattermost Bulk Import format.",
Example: "export bulk bulk_data.json",
RunE: bulkExportCmdF,
Args: cobra.ExactArgs(1),
}
func init() {
ScheduleExportCmd.Flags().String("format", "actiance", "The format to export data")
ScheduleExportCmd.Flags().Int64("exportFrom", -1, "The timestamp of the earliest post to export, expressed in seconds since the unix epoch.")
ScheduleExportCmd.Flags().Int("timeoutSeconds", -1, "The maximum number of seconds to wait for the job to complete before timing out.")
BulkExportCmd.Flags().Bool("all-teams", true, "Export all teams from the server.")
BulkExportCmd.Flags().Bool("with-archived-channels", false, "Also exports archived channels.")
BulkExportCmd.Flags().Bool("with-profile-pictures", false, "Also exports profile pictures.")
BulkExportCmd.Flags().Bool("attachments", false, "Also export file attachments.")
BulkExportCmd.Flags().Bool("archive", false, "Outputs a single archive file.")
ExportCmd.AddCommand(ScheduleExportCmd)
ExportCmd.AddCommand(BulkExportCmd)
RootCmd.AddCommand(ExportCmd)
}
func scheduleExportCmdF(command *cobra.Command, args []string) error {
a, err := InitDBCommandContextCobra(command, app.SkipPostInitialization())
if err != nil {
return err
}
defer a.Srv().Shutdown()
if !*a.Config().MessageExportSettings.EnableExport {
return errors.New("ERROR: The message export feature is not enabled")
}
var rctx request.CTX = request.EmptyContext(a.Log())
// for now, format is hard-coded to actiance. In time, we'll have to support other formats and inject them into job data
format, err := command.Flags().GetString("format")
if err != nil {
return errors.New("format flag error")
}
if format != "actiance" {
return errors.New("unsupported export format")
}
startTime, err := command.Flags().GetInt64("exportFrom")
if err != nil {
return errors.New("exportFrom flag error")
}
if startTime < 0 {
return errors.New("exportFrom must be a positive integer")
}
timeoutSeconds, err := command.Flags().GetInt("timeoutSeconds")
if err != nil {
return errors.New("timeoutSeconds error")
}
if timeoutSeconds < 0 {
return errors.New("timeoutSeconds must be a positive integer")
}
if messageExportI := a.MessageExport(); messageExportI != nil {
ctx := context.Background()
if timeoutSeconds > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Second*time.Duration(timeoutSeconds))
defer cancel()
}
rctx = rctx.WithContext(ctx)
job, err := messageExportI.StartSynchronizeJob(rctx, startTime)
if err != nil || job.Status == model.JobStatusError || job.Status == model.JobStatusCanceled {
CommandPrintErrorln("ERROR: Message export job failed. Please check the server logs")
} else {
CommandPrettyPrintln("SUCCESS: Message export job complete")
auditRec := a.MakeAuditRecord(rctx, model.AuditEventScheduleExport, model.AuditStatusSuccess)
auditRec.AddMeta("format", format)
auditRec.AddMeta("start", startTime)
a.LogAuditRec(rctx, auditRec, nil)
}
}
return nil
}
func bulkExportCmdF(command *cobra.Command, args []string) error {
a, err := InitDBCommandContextCobra(command, app.SkipPostInitialization())
if err != nil {
return err
}
defer a.Srv().Shutdown()
rctx := request.EmptyContext(a.Log())
allTeams, err := command.Flags().GetBool("all-teams")
if err != nil {
return errors.Wrap(err, "all-teams flag error")
}
if !allTeams {
return errors.New("Nothing to export. Please specify the --all-teams flag to export all teams.")
}
attachments, err := command.Flags().GetBool("attachments")
if err != nil {
return errors.Wrap(err, "attachments flag error")
}
archive, err := command.Flags().GetBool("archive")
if err != nil {
return errors.Wrap(err, "archive flag error")
}
withArchivedChannels, err := command.Flags().GetBool("with-archived-channels")
if err != nil {
return errors.Wrap(err, "with-archived-channels flag error")
}
includeProfilePictures, err := command.Flags().GetBool("with-profile-pictures")
if err != nil {
return errors.Wrap(err, "with-profile-pictures flag error")
}
fileWriter, err := os.Create(args[0])
if err != nil {
return err
}
defer fileWriter.Close()
outPath, err := filepath.Abs(args[0])
if err != nil {
return err
}
var opts model.BulkExportOpts
opts.IncludeAttachments = attachments
opts.CreateArchive = archive
opts.IncludeArchivedChannels = withArchivedChannels
opts.IncludeProfilePictures = includeProfilePictures
if err := a.BulkExport(rctx, fileWriter, filepath.Dir(outPath), nil /* nil job since it's spawned from CLI */, opts); err != nil {
CommandPrintErrorln(err.Error())
return err
}
auditRec := a.MakeAuditRecord(rctx, model.AuditEventBulkExport, model.AuditStatusSuccess)
auditRec.AddMeta("all_teams", allTeams)
auditRec.AddMeta("file", args[0])
a.LogAuditRec(rctx, auditRec, nil)
return nil
}

View File

@ -0,0 +1,189 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"errors"
"fmt"
"os"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/app"
)
var ImportCmd = &cobra.Command{
Use: "import",
Short: "Import data.",
}
var SlackImportCmd = &cobra.Command{
Use: "slack [team] [file]",
Short: "Import a team from Slack.",
Long: "Import a team from a Slack export zip file.",
Example: " import slack myteam slack_export.zip",
RunE: slackImportCmdF,
}
var BulkImportCmd = &cobra.Command{
Use: "bulk [file]",
Short: "Import bulk data.",
Long: "Import data from a Mattermost Bulk Import File.",
Example: " import bulk bulk_data.json",
RunE: bulkImportCmdF,
}
func init() {
BulkImportCmd.Flags().Bool("apply", false, "Save the import data to the database. Use with caution - this cannot be reverted.")
BulkImportCmd.Flags().Bool("validate", false, "Validate the import data without making any changes to the system.")
BulkImportCmd.Flags().Int("workers", 2, "How many workers to run whilst doing the import.")
BulkImportCmd.Flags().String("import-path", "", "A path to the data directory to import files from.")
ImportCmd.AddCommand(
BulkImportCmd,
SlackImportCmd,
)
RootCmd.AddCommand(ImportCmd)
}
func slackImportCmdF(command *cobra.Command, args []string) error {
a, err := InitDBCommandContextCobra(command)
if err != nil {
return err
}
defer a.Srv().Shutdown()
rctx := request.EmptyContext(a.Log())
if len(args) != 2 {
return errors.New("Incorrect number of arguments.")
}
team := getTeamFromTeamArg(a, args[0])
if team == nil {
return errors.New("Unable to find team '" + args[0] + "'")
}
fileReader, err := os.Open(args[1])
if err != nil {
return err
}
defer fileReader.Close()
fileInfo, err := fileReader.Stat()
if err != nil {
return err
}
CommandPrettyPrintln("Running Slack Import. This may take a long time for large teams or teams with many messages.")
importErr, log := a.SlackImport(rctx, fileReader, fileInfo.Size(), team.Id)
if importErr != nil {
return err
}
CommandPrettyPrintln("")
CommandPrintln(log.String())
CommandPrettyPrintln("")
CommandPrettyPrintln("Finished Slack Import.")
CommandPrettyPrintln("")
auditRec := a.MakeAuditRecord(rctx, model.AuditEventSlackImport, model.AuditStatusSuccess)
auditRec.AddMeta("team", team)
auditRec.AddMeta("file", args[1])
a.LogAuditRec(rctx, auditRec, nil)
return nil
}
func bulkImportCmdF(command *cobra.Command, args []string) error {
a, err := InitDBCommandContextCobra(command)
if err != nil {
return err
}
defer a.Srv().Shutdown()
rctx := request.EmptyContext(a.Log())
apply, err := command.Flags().GetBool("apply")
if err != nil {
return errors.New("Apply flag error")
}
validate, err := command.Flags().GetBool("validate")
if err != nil {
return errors.New("Validate flag error")
}
workers, err := command.Flags().GetInt("workers")
if err != nil {
return errors.New("Workers flag error")
}
importPath, err := command.Flags().GetString("import-path")
if err != nil {
return errors.New("import-path flag error")
}
if len(args) != 1 {
return errors.New("Incorrect number of arguments.")
}
fileReader, err := os.Open(args[0])
if err != nil {
return err
}
defer fileReader.Close()
if apply && validate {
CommandPrettyPrintln("Use only one of --apply or --validate.")
return nil
}
if apply && !validate {
CommandPrettyPrintln("Running Bulk Import. This may take a long time.")
} else {
CommandPrettyPrintln("Running Bulk Import Data Validation.")
CommandPrettyPrintln("** This checks the validity of the entities in the data file, but does not persist any changes **")
CommandPrettyPrintln("Use the --apply flag to perform the actual data import.")
}
CommandPrettyPrintln("")
if lineNumber, err := a.BulkImportWithPath(rctx, fileReader, nil, true, !apply, workers, importPath); err != nil {
CommandPrintErrorln(err.Error())
if lineNumber != 0 {
CommandPrintErrorln(fmt.Sprintf("Error occurred on data file line %v", lineNumber))
}
return err
}
if apply {
CommandPrettyPrintln("Finished Bulk Import.")
auditRec := a.MakeAuditRecord(rctx, model.AuditEventBulkImport, model.AuditStatusSuccess)
auditRec.AddMeta("file", args[0])
a.LogAuditRec(rctx, auditRec, nil)
} else {
CommandPrettyPrintln("Validation complete. You can now perform the import by rerunning this command with the --apply flag.")
}
return nil
}
func getTeamFromTeamArg(a *app.App, teamArg string) *model.Team {
var team *model.Team
team, err := a.Srv().Store().Team().GetByName(teamArg)
if err != nil {
var t *model.Team
if t, err = a.Srv().Store().Team().Get(teamArg); err == nil {
team = t
}
}
return team
}

View File

@ -0,0 +1,69 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/i18n"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/app"
"github.com/mattermost/mattermost/server/v8/channels/store"
"github.com/mattermost/mattermost/server/v8/channels/store/sqlstore"
"github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/config"
)
func initDBCommandContextCobra(command *cobra.Command, readOnlyConfigStore bool, options ...app.Option) (*app.App, error) {
a, err := initDBCommandContext(getConfigDSN(command, config.GetEnvironment()), readOnlyConfigStore, options...)
if err != nil {
// Returning an error just prints the usage message, so actually panic
panic(err)
}
a.InitPlugins(request.EmptyContext(a.Log()), *a.Config().PluginSettings.Directory, *a.Config().PluginSettings.ClientDirectory)
a.DoAppMigrations()
return a, nil
}
func InitDBCommandContextCobra(command *cobra.Command, options ...app.Option) (*app.App, error) {
return initDBCommandContextCobra(command, true, options...)
}
func initDBCommandContext(configDSN string, readOnlyConfigStore bool, options ...app.Option) (*app.App, error) {
if err := utils.TranslationsPreInit(); err != nil {
return nil, err
}
model.AppErrorInit(i18n.T)
// The option order is important as app.Config option reads app.StartMetrics option.
options = append(options, app.Config(configDSN, readOnlyConfigStore, nil))
s, err := app.NewServer(options...)
if err != nil {
return nil, err
}
a := app.New(app.ServerConnector(s.Channels()))
if model.BuildEnterpriseReady == "true" {
a.Srv().LoadLicense()
}
return a, nil
}
func initStoreCommandContextCobra(logger mlog.LoggerIFace, command *cobra.Command) (store.Store, error) {
cfgDSN := getConfigDSN(command, config.GetEnvironment())
cfgStore, err := config.NewStoreFromDSN(cfgDSN, true, nil, true)
if err != nil {
return nil, errors.Wrap(err, "failed to load configuration")
}
config := cfgStore.Get()
return sqlstore.New(config.SqlSettings, logger, nil)
}

View File

@ -0,0 +1,74 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"os"
"os/signal"
"syscall"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/request"
"github.com/mattermost/mattermost/server/v8/channels/app"
"github.com/mattermost/mattermost/server/v8/config"
)
var JobserverCmd = &cobra.Command{
Use: "jobserver",
Short: "Start the Mattermost job server",
RunE: jobserverCmdF,
}
func init() {
JobserverCmd.Flags().Bool("nojobs", false, "Do not run jobs on this jobserver.")
JobserverCmd.Flags().Bool("noschedule", false, "Do not schedule jobs from this jobserver.")
RootCmd.AddCommand(JobserverCmd)
}
func jobserverCmdF(command *cobra.Command, args []string) error {
// Options
noJobs, _ := command.Flags().GetBool("nojobs")
noSchedule, _ := command.Flags().GetBool("noschedule")
// Initialize
a, err := initDBCommandContext(getConfigDSN(command, config.GetEnvironment()), false, app.StartMetrics)
if err != nil {
return err
}
defer a.Srv().Shutdown()
a.Srv().LoadLicense()
rctx := request.EmptyContext(a.Log())
// Run jobs
rctx.Logger().Info("Starting Mattermost job server")
defer rctx.Logger().Info("Stopped Mattermost job server")
if !noJobs {
a.Srv().Jobs.StartWorkers()
defer a.Srv().Jobs.StopWorkers()
}
if !noSchedule {
a.Srv().Jobs.StartSchedulers()
defer a.Srv().Jobs.StopSchedulers()
}
if !noJobs || !noSchedule {
auditRec := a.MakeAuditRecord(rctx, model.AuditEventJobServer, model.AuditStatusSuccess)
a.LogAuditRec(rctx, auditRec, nil)
}
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
<-signalChan
// Cleanup anything that isn't handled by a defer statement
rctx.Logger().Info("Stopping Mattermost job server")
return nil
}

View File

@ -0,0 +1,100 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"flag"
"os"
"testing"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/v8/channels/api4"
"github.com/mattermost/mattermost/server/v8/channels/testlib"
)
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
type TestConfig struct {
TestServiceSettings TestServiceSettings
TestTeamSettings TestTeamSettings
TestClientRequirements TestClientRequirements
TestMessageExportSettings TestMessageExportSettings
}
type TestMessageExportSettings struct {
Enableexport bool
Exportformat string
TestGlobalRelaySettings TestGlobalRelaySettings
}
type TestGlobalRelaySettings struct {
Customertype string
Smtpusername string
Smtppassword string
}
type TestServiceSettings struct {
Siteurl string
Websocketurl string
Licensedfieldlocation string
}
type TestTeamSettings struct {
Sitename string
Maxuserperteam int
}
type TestClientRequirements struct {
Androidlatestversion string
Androidminversion string
Desktoplatestversion string
}
type TestNewConfig struct {
TestNewServiceSettings TestNewServiceSettings
TestNewTeamSettings TestNewTeamSettings
}
type TestNewServiceSettings struct {
SiteUrl *string
UseLetsEncrypt *bool
TLSStrictTransportMaxAge *int64
AllowedThemes []string
}
type TestNewTeamSettings struct {
SiteName *string
MaxUserPerTeam *int
}
type TestPluginSettings struct {
Enable *bool
Directory *string `restricted:"true"`
Plugins map[string]map[string]any
PluginStates map[string]*model.PluginState
SignaturePublicKeyFiles []string
}
func TestMain(m *testing.M) {
// Command tests are run by re-invoking the test binary in question, so avoid creating
// another container when we detect same.
flag.Parse()
if filter := flag.Lookup("test.run").Value.String(); filter == "ExecCommand" {
status := m.Run()
os.Exit(status)
return
}
var options = testlib.HelperOptions{
EnableStore: true,
EnableResources: true,
}
mainHelper = testlib.NewMainHelperWithOptions(&options)
defer mainHelper.Close()
api4.SetMainHelper(mainHelper)
mainHelper.Main(m)
}

View File

@ -0,0 +1,21 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"fmt"
"os"
)
func CommandPrintln(a ...any) (int, error) {
return fmt.Println(a...)
}
func CommandPrintErrorln(a ...any) (int, error) {
return fmt.Fprintln(os.Stderr, a...)
}
func CommandPrettyPrintln(a ...any) (int, error) {
return fmt.Fprintln(os.Stdout, a...)
}

View File

@ -0,0 +1,38 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"os"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/spf13/cobra"
)
type Command = cobra.Command
func Run(args []string) error {
RootCmd.SetArgs(args)
return RootCmd.Execute()
}
var RootCmd = &cobra.Command{
Use: "mattermost",
Short: "Open source, self-hosted Slack-alternative",
Long: `Mattermost offers workplace messaging across web, PC and phones with archiving, search and integration with your existing systems. Documentation available at https://docs.mattermost.com`,
PersistentPreRun: func(cmd *cobra.Command, args []string) {
checkForRootUser()
},
}
func init() {
RootCmd.PersistentFlags().StringP("config", "c", "", "Configuration file to use.")
}
// checkForRootUser logs a warning if the process is running as root
func checkForRootUser() {
if os.Geteuid() == 0 {
mlog.Warn("Running Mattermost as root is not recommended. Please use a non-root user.")
}
}

View File

@ -0,0 +1,151 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"bytes"
"net"
"os"
"os/signal"
"runtime/debug"
"runtime/pprof"
"syscall"
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/shared/mlog"
"github.com/mattermost/mattermost/server/v8/channels/api4"
"github.com/mattermost/mattermost/server/v8/channels/app"
"github.com/mattermost/mattermost/server/v8/channels/utils"
"github.com/mattermost/mattermost/server/v8/channels/web"
"github.com/mattermost/mattermost/server/v8/channels/wsapi"
"github.com/mattermost/mattermost/server/v8/config"
)
var serverCmd = &cobra.Command{
Use: "server",
Short: "Run the Mattermost server",
RunE: serverCmdF,
SilenceUsage: true,
}
func init() {
RootCmd.AddCommand(serverCmd)
RootCmd.RunE = serverCmdF
}
func serverCmdF(command *cobra.Command, args []string) error {
interruptChan := make(chan os.Signal, 1)
if err := utils.TranslationsPreInit(); err != nil {
return errors.Wrap(err, "unable to load Mattermost translation files")
}
customDefaults, err := loadCustomDefaults()
if err != nil {
mlog.Warn("Error loading custom configuration defaults: " + err.Error())
}
configStore, err := config.NewStoreFromDSN(getConfigDSN(command, config.GetEnvironment()), false, customDefaults, true)
if err != nil {
return errors.Wrap(err, "failed to load configuration")
}
defer configStore.Close()
return runServer(configStore, interruptChan)
}
func runServer(configStore *config.Store, interruptChan chan os.Signal) error {
// Setting the highest traceback level from the code.
// This is done to print goroutines from all threads (see golang.org/issue/13161)
// and also preserve a crash dump for later investigation.
debug.SetTraceback("crash")
options := []app.Option{
// The option order is important as app.Config option reads app.StartMetrics option.
app.StartMetrics,
app.ConfigStore(configStore),
app.RunEssentialJobs,
app.JoinCluster,
}
server, err := app.NewServer(options...)
if err != nil {
mlog.Error(err.Error())
return err
}
defer server.Shutdown()
// We add this after shutdown so that it can be called
// before server shutdown happens as it can close
// the advanced logger and prevent the mlog call from working properly.
defer func() {
// A panic pass-through layer which just logs it
// and sends it upwards.
if x := recover(); x != nil {
var buf bytes.Buffer
pprof.Lookup("goroutine").WriteTo(&buf, 2)
mlog.Error("A panic occurred",
mlog.Any("error", x),
mlog.String("stack", buf.String()))
panic(x)
}
}()
_, err = api4.Init(server)
if err != nil {
mlog.Error(err.Error())
return err
}
wsapi.Init(server)
web.New(server)
err = server.Start()
if err != nil {
mlog.Error(err.Error())
return err
}
notifyReady()
// Wiping off any signal handlers set before.
// This may come from intermediary signal handlers requiring to clean
// up resources before server.Start can finish.
signal.Reset(syscall.SIGINT, syscall.SIGTERM)
// wait for kill signal before attempting to gracefully shutdown
// the running service
signal.Notify(interruptChan, syscall.SIGINT, syscall.SIGTERM)
<-interruptChan
return nil
}
func notifyReady() {
// If the environment vars provide a systemd notification socket,
// notify systemd that the server is ready.
systemdSocket := os.Getenv("NOTIFY_SOCKET")
if systemdSocket != "" {
mlog.Info("Sending systemd READY notification.")
err := sendSystemdReadyNotification(systemdSocket)
if err != nil {
mlog.Error(err.Error())
}
}
}
func sendSystemdReadyNotification(socketPath string) error {
msg := "READY=1"
addr := &net.UnixAddr{
Name: socketPath,
Net: "unixgram",
}
conn, err := net.DialUnix(addr.Net, nil, addr)
if err != nil {
return err
}
defer conn.Close()
_, err = conn.Write([]byte(msg))
return err
}

View File

@ -0,0 +1,152 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"net"
"os"
"syscall"
"testing"
"github.com/stretchr/testify/require"
"github.com/mattermost/mattermost/server/v8/channels/jobs"
"github.com/mattermost/mattermost/server/v8/config"
)
const (
unitTestListeningPort = "localhost:0"
)
//nolint:golint,unused
type ServerTestHelper struct {
disableConfigWatch bool
interruptChan chan os.Signal
originalInterval int
}
//nolint:golint,unused
func SetupServerTest(tb testing.TB) *ServerTestHelper {
if testing.Short() {
tb.SkipNow()
}
// Build a channel that will be used by the server to receive system signals...
interruptChan := make(chan os.Signal, 1)
// ...and sent it immediately a SIGINT value.
// This will make the server loop stop as soon as it started successfully.
interruptChan <- syscall.SIGINT
// Let jobs poll for termination every 0.2s (instead of every 15s by default)
// Otherwise we would have to wait the whole polling duration before the test
// terminates.
originalInterval := jobs.DefaultWatcherPollingInterval
jobs.DefaultWatcherPollingInterval = 200
th := &ServerTestHelper{
disableConfigWatch: true,
interruptChan: interruptChan,
originalInterval: originalInterval,
}
return th
}
//nolint:golint,unused
func (th *ServerTestHelper) TearDownServerTest() {
jobs.DefaultWatcherPollingInterval = th.originalInterval
}
func TestRunServerSuccess(t *testing.T) {
th := SetupServerTest(t)
defer th.TearDownServerTest()
configStore := config.NewTestMemoryStore()
// Use non-default listening port in case another server instance is already running.
cfg := configStore.Get()
*cfg.ServiceSettings.ListenAddress = unitTestListeningPort
cfg.SqlSettings = *mainHelper.GetSQLSettings()
configStore.Set(cfg)
err := runServer(configStore, th.interruptChan)
require.NoError(t, err)
}
func TestRunServerSystemdNotification(t *testing.T) {
th := SetupServerTest(t)
defer th.TearDownServerTest()
// Get a random temporary filename for using as a mock systemd socket
socketFile, err := os.CreateTemp("", "mattermost-systemd-mock-socket-")
if err != nil {
panic(err)
}
socketPath := socketFile.Name()
os.Remove(socketPath)
// Set the socket path in the process environment
originalSocket := os.Getenv("NOTIFY_SOCKET")
os.Setenv("NOTIFY_SOCKET", socketPath)
defer os.Setenv("NOTIFY_SOCKET", originalSocket)
// Open the socket connection
addr := &net.UnixAddr{
Name: socketPath,
Net: "unixgram",
}
connection, err := net.ListenUnixgram("unixgram", addr)
if err != nil {
panic(err)
}
defer connection.Close()
defer os.Remove(socketPath)
// Listen for socket data
socketReader := make(chan string)
go func(ch chan string) {
buffer := make([]byte, 512)
count, readErr := connection.Read(buffer)
if readErr != nil {
panic(readErr)
}
data := buffer[0:count]
ch <- string(data)
}(socketReader)
configStore := config.NewTestMemoryStore()
// Use non-default listening port in case another server instance is already running.
cfg := configStore.Get()
*cfg.ServiceSettings.ListenAddress = unitTestListeningPort
cfg.SqlSettings = *mainHelper.GetSQLSettings()
configStore.Set(cfg)
// Start and stop the server
err = runServer(configStore, th.interruptChan)
require.NoError(t, err)
// Ensure the notification has been sent on the socket and is correct
notification := <-socketReader
require.Equal(t, notification, "READY=1")
}
func TestRunServerNoSystemd(t *testing.T) {
th := SetupServerTest(t)
defer th.TearDownServerTest()
// Temporarily remove any Systemd socket defined in the environment
originalSocket := os.Getenv("NOTIFY_SOCKET")
os.Unsetenv("NOTIFY_SOCKET")
defer os.Setenv("NOTIFY_SOCKET", originalSocket)
configStore := config.NewTestMemoryStore()
// Use non-default listening port in case another server instance is already running.
cfg := configStore.Get()
*cfg.ServiceSettings.ListenAddress = unitTestListeningPort
cfg.SqlSettings = *mainHelper.GetSQLSettings()
configStore.Set(cfg)
err := runServer(configStore, th.interruptChan)
require.NoError(t, err)
}

View File

@ -0,0 +1,154 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"bufio"
"fmt"
"os"
"os/exec"
"os/signal"
"syscall"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
"github.com/mattermost/mattermost/server/public/shared/i18n"
"github.com/mattermost/mattermost/server/v8/channels/api4"
"github.com/mattermost/mattermost/server/v8/channels/app"
"github.com/mattermost/mattermost/server/v8/channels/wsapi"
)
var TestCmd = &cobra.Command{
Use: "test",
Short: "Testing Commands",
Hidden: true,
}
var RunWebClientTestsCmd = &cobra.Command{
Use: "web_client_tests",
Short: "Run the web client tests",
RunE: webClientTestsCmdF,
}
var RunServerForWebClientTestsCmd = &cobra.Command{
Use: "web_client_tests_server",
Short: "Run the server configured for running the web client tests against it",
RunE: serverForWebClientTestsCmdF,
}
func init() {
TestCmd.AddCommand(
RunWebClientTestsCmd,
RunServerForWebClientTestsCmd,
)
RootCmd.AddCommand(TestCmd)
}
func webClientTestsCmdF(command *cobra.Command, args []string) error {
a, err := InitDBCommandContextCobra(command, app.StartMetrics)
if err != nil {
return err
}
defer a.Srv().Shutdown()
i18n.InitTranslations(*a.Config().LocalizationSettings.DefaultServerLocale, *a.Config().LocalizationSettings.DefaultClientLocale)
serverErr := a.Srv().Start()
if serverErr != nil {
return serverErr
}
_, err = api4.Init(a.Srv())
if err != nil {
return err
}
wsapi.Init(a.Srv())
a.UpdateConfig(setupClientTests)
runWebClientTests()
return nil
}
func serverForWebClientTestsCmdF(command *cobra.Command, args []string) error {
a, err := InitDBCommandContextCobra(command, app.StartMetrics)
if err != nil {
return err
}
defer a.Srv().Shutdown()
i18n.InitTranslations(*a.Config().LocalizationSettings.DefaultServerLocale, *a.Config().LocalizationSettings.DefaultClientLocale)
serverErr := a.Srv().Start()
if serverErr != nil {
return serverErr
}
_, err = api4.Init(a.Srv())
if err != nil {
return err
}
wsapi.Init(a.Srv())
a.UpdateConfig(setupClientTests)
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
<-c
return nil
}
func setupClientTests(cfg *model.Config) {
*cfg.TeamSettings.EnableOpenServer = true
*cfg.ServiceSettings.EnableCommands = false
*cfg.ServiceSettings.EnableCustomEmoji = true
*cfg.ServiceSettings.EnableIncomingWebhooks = false
*cfg.ServiceSettings.EnableOutgoingWebhooks = false
*cfg.ServiceSettings.EnableOutgoingOAuthConnections = false
}
func executeTestCommand(command *exec.Cmd) {
cmdOutPipe, err := command.StdoutPipe()
if err != nil {
CommandPrintErrorln("Failed to run tests")
os.Exit(1)
return
}
cmdErrOutPipe, err := command.StderrPipe()
if err != nil {
CommandPrintErrorln("Failed to run tests")
os.Exit(1)
return
}
cmdOutReader := bufio.NewScanner(cmdOutPipe)
cmdErrOutReader := bufio.NewScanner(cmdErrOutPipe)
go func() {
for cmdOutReader.Scan() {
fmt.Println(cmdOutReader.Text())
}
}()
go func() {
for cmdErrOutReader.Scan() {
fmt.Println(cmdErrOutReader.Text())
}
}()
if err := command.Run(); err != nil {
CommandPrintErrorln("Client Tests failed")
os.Exit(1)
return
}
}
func runWebClientTests() {
if webappDir := os.Getenv("WEBAPP_DIR"); webappDir != "" {
os.Chdir(webappDir)
} else {
os.Chdir("../mattermost-webapp")
}
cmd := exec.Command("npm", "test")
executeTestCommand(cmd)
}

View File

@ -0,0 +1,91 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"bytes"
"encoding/json"
"fmt"
"os"
"reflect"
"sort"
"strings"
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
)
const CustomDefaultsEnvVar = "MM_CUSTOM_DEFAULTS_PATH"
// printStringMap takes a reflect.Value and prints it out alphabetically based on key values, which must be strings.
// This is done recursively if it's a map, and uses the given tab settings.
func printStringMap(value reflect.Value, tabVal int) string {
out := &bytes.Buffer{}
var sortedKeys []string
stringToKeyMap := make(map[string]reflect.Value)
for _, k := range value.MapKeys() {
sortedKeys = append(sortedKeys, k.String())
stringToKeyMap[k.String()] = k
}
sort.Strings(sortedKeys)
for _, keyString := range sortedKeys {
key := stringToKeyMap[keyString]
val := value.MapIndex(key)
if newVal, ok := val.Interface().(map[string]any); !ok {
fmt.Fprintf(out, "%s", strings.Repeat("\t", tabVal))
fmt.Fprintf(out, "%v: \"%v\"\n", key.Interface(), val.Interface())
} else {
fmt.Fprintf(out, "%s", strings.Repeat("\t", tabVal))
fmt.Fprintf(out, "%v:\n", key.Interface())
// going one level in, increase the tab
tabVal++
fmt.Fprintf(out, "%s", printStringMap(reflect.ValueOf(newVal), tabVal))
// coming back one level, decrease the tab
tabVal--
}
}
return out.String()
}
func getConfigDSN(command *cobra.Command, env map[string]string) string {
configDSN, _ := command.Flags().GetString("config")
// Config not supplied in flag, check env
if configDSN == "" {
configDSN = env["MM_CONFIG"]
}
// Config not supplied in env or flag use default
if configDSN == "" {
configDSN = "config.json"
}
return configDSN
}
func loadCustomDefaults() (*model.Config, error) {
customDefaultsPath := os.Getenv(CustomDefaultsEnvVar)
if customDefaultsPath == "" {
return nil, nil
}
file, err := os.Open(customDefaultsPath)
if err != nil {
return nil, fmt.Errorf("unable to open custom defaults file at %q: %w", customDefaultsPath, err)
}
defer file.Close()
var customDefaults *model.Config
err = json.NewDecoder(file).Decode(&customDefaults)
if err != nil {
return nil, fmt.Errorf("unable to decode custom defaults configuration: %w", err)
}
return customDefaults, nil
}

View File

@ -0,0 +1,72 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"reflect"
"sort"
"strings"
"testing"
)
func TestPrintMap(t *testing.T) {
inputCases := []any{
map[string]any{
"CustomerType": "A9",
"SmtpUsername": "",
"SmtpPassword": "",
"EmailAddress": "",
},
map[string]any{
"EnableExport": false,
"ExportFormat": "actiance",
"DailyRunTime": "01:00",
"GlobalRelaySettings": map[string]any{
"CustomerType": "A9",
"SmtpUsername": "",
"SmtpPassword": "",
"EmailAddress": "",
},
},
}
outputCases := []string{
"CustomerType: \"A9\"\nSmtpUsername: \"\"\nSmtpPassword: \"\"\nEmailAddress: \"\"\n",
"EnableExport: \"false\"\nExportFormat: \"actiance\"\nDailyRunTime: \"01:00\"\nGlobalRelaySettings:\n\t CustomerType: \"A9\"\n\tSmtpUsername: \"\"\n\tSmtpPassword: \"\"\n\tEmailAddress: \"\"\n",
}
cases := []struct {
Name string
Input reflect.Value
Expected string
}{
{
Name: "Basic print",
Input: reflect.ValueOf(inputCases[0]),
Expected: outputCases[0],
},
{
Name: "Complex print",
Input: reflect.ValueOf(inputCases[1]),
Expected: outputCases[1],
},
}
for _, test := range cases {
t.Run(test.Name, func(t *testing.T) {
res := printStringMap(test.Input, 0)
// create two slice of string formed by splitting our strings on \n
slice1 := strings.Split(res, "\n")
slice2 := strings.Split(res, "\n")
sort.Strings(slice1)
sort.Strings(slice2)
if !reflect.DeepEqual(slice1, slice2) {
t.Errorf("got '%#v' want '%#v", slice1, slice2)
}
})
}
}

View File

@ -0,0 +1,30 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"github.com/spf13/cobra"
"github.com/mattermost/mattermost/server/public/model"
)
var VersionCmd = &cobra.Command{
Use: "version",
Short: "Display version information",
RunE: versionCmdF,
}
func init() {
RootCmd.AddCommand(VersionCmd)
}
func versionCmdF(command *cobra.Command, args []string) error {
CommandPrintln("Version: " + model.CurrentVersion)
CommandPrintln("Build Number: " + model.BuildNumber)
CommandPrintln("Build Date: " + model.BuildDate)
CommandPrintln("Build Hash: " + model.BuildHash)
CommandPrintln("Build Enterprise Ready: " + model.BuildEnterpriseReady)
return nil
}

View File

@ -0,0 +1,19 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package commands
import (
"testing"
)
func TestVersion(t *testing.T) {
if testing.Short() {
t.Skip("skipping version test in short mode")
}
th := SetupWithStoreMock(t)
defer th.TearDown()
th.CheckCommand(t, "version")
}

23
cmd/mattermost/main.go Normal file
View File

@ -0,0 +1,23 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
package main
import (
"os"
"github.com/mattermost/mattermost/server/v8/cmd/mattermost/commands"
// Import and register app layer slash commands
_ "github.com/mattermost/mattermost/server/v8/channels/app/slashcommands"
// Plugins
_ "github.com/mattermost/mattermost/server/v8/channels/app/oauthproviders/gitlab"
// Enterprise Imports
_ "github.com/mattermost/mattermost/server/v8/enterprise"
)
func main() {
if err := commands.Run(os.Args[1:]); err != nil {
os.Exit(1)
}
}

View File

@ -0,0 +1,19 @@
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.
//go:build maincoverage
package main
import (
"testing"
)
// TestRunMain can be used to track code coverage in integration tests.
// To run this:
// go test -coverpkg="<>" -ldflags '<>' -tags maincoverage -c ./cmd/mattermost/
// ./mattermost.test -test.run="^TestRunMain$" -test.coverprofile=coverage.out
// And then run your integration tests.
func TestRunMain(t *testing.T) {
main()
}

19
vendor/github.com/mattermost/go-i18n/LICENSE generated vendored Normal file
View File

@ -0,0 +1,19 @@
Copyright (c) 2014 Nick Snyder https://github.com/nicksnyder
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -0,0 +1,453 @@
// Package bundle manages translations for multiple languages.
package bundle
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"sync"
"unicode"
"github.com/mattermost/go-i18n/i18n/language"
"github.com/mattermost/go-i18n/i18n/translation"
toml "github.com/pelletier/go-toml"
"gopkg.in/yaml.v2"
)
// TranslateFunc is a copy of i18n.TranslateFunc to avoid a circular dependency.
type TranslateFunc func(translationID string, args ...interface{}) string
// Bundle stores the translations for multiple languages.
type Bundle struct {
// The primary translations for a language tag and translation id.
translations map[string]map[string]translation.Translation
// Translations that can be used when an exact language match is not possible.
fallbackTranslations map[string]map[string]translation.Translation
sync.RWMutex
}
// New returns an empty bundle.
func New() *Bundle {
return &Bundle{
translations: make(map[string]map[string]translation.Translation),
fallbackTranslations: make(map[string]map[string]translation.Translation),
}
}
// MustLoadTranslationFile is similar to LoadTranslationFile
// except it panics if an error happens.
func (b *Bundle) MustLoadTranslationFile(filename string) {
if err := b.LoadTranslationFile(filename); err != nil {
panic(err)
}
}
// LoadTranslationFile loads the translations from filename into memory.
//
// The language that the translations are associated with is parsed from the filename (e.g. en-US.json).
//
// Generally you should load translation files once during your program's initialization.
func (b *Bundle) LoadTranslationFile(filename string) error {
buf, err := ioutil.ReadFile(filename)
if err != nil {
return err
}
return b.ParseTranslationFileBytes(filename, buf)
}
// ParseTranslationFileBytes is similar to LoadTranslationFile except it parses the bytes in buf.
//
// It is useful for parsing translation files embedded with go-bindata.
func (b *Bundle) ParseTranslationFileBytes(filename string, buf []byte) error {
basename := filepath.Base(filename)
langs := language.Parse(basename)
switch l := len(langs); {
case l == 0:
return fmt.Errorf("no language found in %q", basename)
case l > 1:
return fmt.Errorf("multiple languages found in filename %q: %v; expected one", basename, langs)
}
translations, err := parseTranslations(filename, buf)
if err != nil {
return err
}
b.AddTranslation(langs[0], translations...)
return nil
}
func parseTranslations(filename string, buf []byte) ([]translation.Translation, error) {
if len(buf) == 0 {
return []translation.Translation{}, nil
}
ext := filepath.Ext(filename)
// `github.com/pelletier/go-toml` lacks an Unmarshal function,
// so we should parse TOML separately.
if ext == ".toml" {
tree, err := toml.LoadReader(bytes.NewReader(buf))
if err != nil {
return nil, err
}
m := make(map[string]map[string]interface{})
for k, v := range tree.ToMap() {
m[k] = v.(map[string]interface{})
}
return parseFlatFormat(m)
}
// Then parse other formats.
if isStandardFormat(ext, buf) {
var standardFormat []map[string]interface{}
if err := unmarshal(ext, buf, &standardFormat); err != nil {
return nil, fmt.Errorf("failed to unmarshal %v: %v", filename, err)
}
return parseStandardFormat(standardFormat)
}
var flatFormat map[string]map[string]interface{}
if err := unmarshal(ext, buf, &flatFormat); err != nil {
return nil, fmt.Errorf("failed to unmarshal %v: %v", filename, err)
}
return parseFlatFormat(flatFormat)
}
func isStandardFormat(ext string, buf []byte) bool {
buf = deleteLeadingComments(ext, buf)
firstRune := rune(buf[0])
return (ext == ".json" && firstRune == '[') || (ext == ".yaml" && firstRune == '-')
}
// deleteLeadingComments deletes leading newlines and comments in buf.
// It only works for ext == ".yaml".
func deleteLeadingComments(ext string, buf []byte) []byte {
if ext != ".yaml" {
return buf
}
for {
buf = bytes.TrimLeftFunc(buf, unicode.IsSpace)
if buf[0] == '#' {
buf = deleteLine(buf)
} else {
break
}
}
return buf
}
func deleteLine(buf []byte) []byte {
index := bytes.IndexRune(buf, '\n')
if index == -1 { // If there is only one line without newline ...
return nil // ... delete it and return nothing.
}
if index == len(buf)-1 { // If there is only one line with newline ...
return nil // ... do the same as above.
}
return buf[index+1:]
}
// unmarshal finds an appropriate unmarshal function for ext
// (extension of filename) and unmarshals buf to out. out must be a pointer.
func unmarshal(ext string, buf []byte, out interface{}) error {
switch ext {
case ".json":
return json.Unmarshal(buf, out)
case ".yaml":
return yaml.Unmarshal(buf, out)
}
return fmt.Errorf("unsupported file extension %v", ext)
}
func parseStandardFormat(data []map[string]interface{}) ([]translation.Translation, error) {
translations := make([]translation.Translation, 0, len(data))
for i, translationData := range data {
t, err := translation.NewTranslation(translationData)
if err != nil {
return nil, fmt.Errorf("unable to parse translation #%d because %s\n%v", i, err, translationData)
}
translations = append(translations, t)
}
return translations, nil
}
// parseFlatFormat just converts data from flat format to standard format
// and passes it to parseStandardFormat.
//
// Flat format logic:
// key of data must be a string and data[key] must be always map[string]interface{},
// but if there is only "other" key in it then it is non-plural, else plural.
func parseFlatFormat(data map[string]map[string]interface{}) ([]translation.Translation, error) {
var standardFormatData []map[string]interface{}
for id, translationData := range data {
dataObject := make(map[string]interface{})
dataObject["id"] = id
if len(translationData) == 1 { // non-plural form
_, otherExists := translationData["other"]
if otherExists {
dataObject["translation"] = translationData["other"]
}
} else { // plural form
dataObject["translation"] = translationData
}
standardFormatData = append(standardFormatData, dataObject)
}
return parseStandardFormat(standardFormatData)
}
// AddTranslation adds translations for a language.
//
// It is useful if your translations are in a format not supported by LoadTranslationFile.
func (b *Bundle) AddTranslation(lang *language.Language, translations ...translation.Translation) {
b.Lock()
defer b.Unlock()
if b.translations[lang.Tag] == nil {
b.translations[lang.Tag] = make(map[string]translation.Translation, len(translations))
}
currentTranslations := b.translations[lang.Tag]
for _, newTranslation := range translations {
if currentTranslation := currentTranslations[newTranslation.ID()]; currentTranslation != nil {
currentTranslations[newTranslation.ID()] = currentTranslation.Merge(newTranslation)
} else {
currentTranslations[newTranslation.ID()] = newTranslation
}
}
// lang can provide translations for less specific language tags.
for _, tag := range lang.MatchingTags() {
b.fallbackTranslations[tag] = currentTranslations
}
}
// Translations returns all translations in the bundle.
func (b *Bundle) Translations() map[string]map[string]translation.Translation {
t := make(map[string]map[string]translation.Translation)
b.RLock()
for tag, translations := range b.translations {
t[tag] = make(map[string]translation.Translation)
for id, translation := range translations {
t[tag][id] = translation
}
}
b.RUnlock()
return t
}
// LanguageTags returns the tags of all languages that that have been added.
func (b *Bundle) LanguageTags() []string {
var tags []string
b.RLock()
for k := range b.translations {
tags = append(tags, k)
}
b.RUnlock()
return tags
}
// LanguageTranslationIDs returns the ids of all translations that have been added for a given language.
func (b *Bundle) LanguageTranslationIDs(languageTag string) []string {
var ids []string
b.RLock()
for id := range b.translations[languageTag] {
ids = append(ids, id)
}
b.RUnlock()
return ids
}
// MustTfunc is similar to Tfunc except it panics if an error happens.
func (b *Bundle) MustTfunc(pref string, prefs ...string) TranslateFunc {
tfunc, err := b.Tfunc(pref, prefs...)
if err != nil {
panic(err)
}
return tfunc
}
// MustTfuncAndLanguage is similar to TfuncAndLanguage except it panics if an error happens.
func (b *Bundle) MustTfuncAndLanguage(pref string, prefs ...string) (TranslateFunc, *language.Language) {
tfunc, language, err := b.TfuncAndLanguage(pref, prefs...)
if err != nil {
panic(err)
}
return tfunc, language
}
// Tfunc is similar to TfuncAndLanguage except is doesn't return the Language.
func (b *Bundle) Tfunc(pref string, prefs ...string) (TranslateFunc, error) {
tfunc, _, err := b.TfuncAndLanguage(pref, prefs...)
return tfunc, err
}
// TfuncAndLanguage returns a TranslateFunc for the first Language that
// has a non-zero number of translations in the bundle.
//
// The returned Language matches the the first language preference that could be satisfied,
// but this may not strictly match the language of the translations used to satisfy that preference.
//
// For example, the user may request "zh". If there are no translations for "zh" but there are translations
// for "zh-cn", then the translations for "zh-cn" will be used but the returned Language will be "zh".
//
// It can parse languages from Accept-Language headers (RFC 2616),
// but it assumes weights are monotonically decreasing.
func (b *Bundle) TfuncAndLanguage(pref string, prefs ...string) (TranslateFunc, *language.Language, error) {
lang := b.supportedLanguage(pref, prefs...)
var err error
if lang == nil {
err = fmt.Errorf("no supported languages found %#v", append(prefs, pref))
}
return func(translationID string, args ...interface{}) string {
return b.translate(lang, translationID, args...)
}, lang, err
}
// supportedLanguage returns the first language which
// has a non-zero number of translations in the bundle.
func (b *Bundle) supportedLanguage(pref string, prefs ...string) *language.Language {
lang := b.translatedLanguage(pref)
if lang == nil {
for _, pref := range prefs {
lang = b.translatedLanguage(pref)
if lang != nil {
break
}
}
}
return lang
}
func (b *Bundle) translatedLanguage(src string) *language.Language {
langs := language.Parse(src)
b.RLock()
defer b.RUnlock()
for _, lang := range langs {
if len(b.translations[lang.Tag]) > 0 ||
len(b.fallbackTranslations[lang.Tag]) > 0 {
return lang
}
}
return nil
}
func (b *Bundle) translate(lang *language.Language, translationID string, args ...interface{}) string {
if lang == nil {
return translationID
}
translation := b.translation(lang, translationID)
if translation == nil {
return translationID
}
var data interface{}
var count interface{}
if argc := len(args); argc > 0 {
if isNumber(args[0]) {
count = args[0]
if argc > 1 {
data = args[1]
}
} else {
data = args[0]
}
}
if count != nil {
if data == nil {
data = map[string]interface{}{"Count": count}
} else {
dataMap := toMap(data)
dataMap["Count"] = count
data = dataMap
}
} else {
dataMap := toMap(data)
if c, ok := dataMap["Count"]; ok {
count = c
}
}
p, _ := lang.Plural(count)
template := translation.Template(p)
if template == nil {
if p == language.Other {
return translationID
}
countInt, ok := count.(int)
if ok && countInt > 1 {
template = translation.Template(language.Other)
}
}
if template == nil {
return translationID
}
s := template.Execute(data)
if s == "" {
return translationID
}
return s
}
func (b *Bundle) translation(lang *language.Language, translationID string) translation.Translation {
b.RLock()
defer b.RUnlock()
translations := b.translations[lang.Tag]
if translations == nil {
translations = b.fallbackTranslations[lang.Tag]
if translations == nil {
return nil
}
}
return translations[translationID]
}
func isNumber(n interface{}) bool {
switch n.(type) {
case int, int8, int16, int32, int64, string:
return true
}
return false
}
func toMap(input interface{}) map[string]interface{} {
if data, ok := input.(map[string]interface{}); ok {
return data
}
v := reflect.ValueOf(input)
switch v.Kind() {
case reflect.Ptr:
return toMap(v.Elem().Interface())
case reflect.Struct:
return structToMap(v)
default:
return nil
}
}
// Converts the top level of a struct to a map[string]interface{}.
// Code inspired by github.com/fatih/structs.
func structToMap(v reflect.Value) map[string]interface{} {
out := make(map[string]interface{})
t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.PkgPath != "" {
// unexported field. skip.
continue
}
out[field.Name] = v.FieldByName(field.Name).Interface()
}
return out
}

158
vendor/github.com/mattermost/go-i18n/i18n/i18n.go generated vendored Normal file
View File

@ -0,0 +1,158 @@
// Package i18n supports string translations with variable substitution and CLDR pluralization.
// It is intended to be used in conjunction with the goi18n command, although that is not strictly required.
//
// Initialization
//
// Your Go program should load translations during its initialization.
// i18n.MustLoadTranslationFile("path/to/fr-FR.all.json")
// If your translations are in a file format not supported by (Must)?LoadTranslationFile,
// then you can use the AddTranslation function to manually add translations.
//
// Fetching a translation
//
// Use Tfunc or MustTfunc to fetch a TranslateFunc that will return the translated string for a specific language.
// func handleRequest(w http.ResponseWriter, r *http.Request) {
// cookieLang := r.Cookie("lang")
// acceptLang := r.Header.Get("Accept-Language")
// defaultLang = "en-US" // known valid language
// T, err := i18n.Tfunc(cookieLang, acceptLang, defaultLang)
// fmt.Println(T("Hello world"))
// }
//
// Usually it is a good idea to identify strings by a generic id rather than the English translation,
// but the rest of this documentation will continue to use the English translation for readability.
// T("Hello world") // ok
// T("programGreeting") // better!
//
// Variables
//
// TranslateFunc supports strings that have variables using the text/template syntax.
// T("Hello {{.Person}}", map[string]interface{}{
// "Person": "Bob",
// })
//
// Pluralization
//
// TranslateFunc supports the pluralization of strings using the CLDR pluralization rules defined here:
// http://www.unicode.org/cldr/charts/latest/supplemental/language_plural_rules.html
// T("You have {{.Count}} unread emails.", 2)
// T("I am {{.Count}} meters tall.", "1.7")
//
// Plural strings may also have variables.
// T("{{.Person}} has {{.Count}} unread emails", 2, map[string]interface{}{
// "Person": "Bob",
// })
//
// Sentences with multiple plural components can be supported with nesting.
// T("{{.Person}} has {{.Count}} unread emails in the past {{.Timeframe}}.", 3, map[string]interface{}{
// "Person": "Bob",
// "Timeframe": T("{{.Count}} days", 2),
// })
//
// Templates
//
// You can use the .Funcs() method of a text/template or html/template to register a TranslateFunc
// for usage inside of that template.
package i18n
import (
"github.com/mattermost/go-i18n/i18n/bundle"
"github.com/mattermost/go-i18n/i18n/language"
"github.com/mattermost/go-i18n/i18n/translation"
)
// TranslateFunc returns the translation of the string identified by translationID.
//
// If there is no translation for translationID, then the translationID itself is returned.
// This makes it easy to identify missing translations in your app.
//
// If translationID is a non-plural form, then the first variadic argument may be a map[string]interface{}
// or struct that contains template data.
//
// If translationID is a plural form, the function accepts two parameter signatures
// 1. T(count int, data struct{})
// The first variadic argument must be an integer type
// (int, int8, int16, int32, int64) or a float formatted as a string (e.g. "123.45").
// The second variadic argument may be a map[string]interface{} or struct{} that contains template data.
// 2. T(data struct{})
// data must be a struct{} or map[string]interface{} that contains a Count field and the template data,
// Count field must be an integer type (int, int8, int16, int32, int64)
// or a float formatted as a string (e.g. "123.45").
type TranslateFunc func(translationID string, args ...interface{}) string
// IdentityTfunc returns a TranslateFunc that always returns the translationID passed to it.
//
// It is a useful placeholder when parsing a text/template or html/template
// before the actual Tfunc is available.
func IdentityTfunc() TranslateFunc {
return func(translationID string, args ...interface{}) string {
return translationID
}
}
var defaultBundle = bundle.New()
// MustLoadTranslationFile is similar to LoadTranslationFile
// except it panics if an error happens.
func MustLoadTranslationFile(filename string) {
defaultBundle.MustLoadTranslationFile(filename)
}
// LoadTranslationFile loads the translations from filename into memory.
//
// The language that the translations are associated with is parsed from the filename (e.g. en-US.json).
//
// Generally you should load translation files once during your program's initialization.
func LoadTranslationFile(filename string) error {
return defaultBundle.LoadTranslationFile(filename)
}
// ParseTranslationFileBytes is similar to LoadTranslationFile except it parses the bytes in buf.
//
// It is useful for parsing translation files embedded with go-bindata.
func ParseTranslationFileBytes(filename string, buf []byte) error {
return defaultBundle.ParseTranslationFileBytes(filename, buf)
}
// AddTranslation adds translations for a language.
//
// It is useful if your translations are in a format not supported by LoadTranslationFile.
func AddTranslation(lang *language.Language, translations ...translation.Translation) {
defaultBundle.AddTranslation(lang, translations...)
}
// LanguageTags returns the tags of all languages that have been added.
func LanguageTags() []string {
return defaultBundle.LanguageTags()
}
// LanguageTranslationIDs returns the ids of all translations that have been added for a given language.
func LanguageTranslationIDs(languageTag string) []string {
return defaultBundle.LanguageTranslationIDs(languageTag)
}
// MustTfunc is similar to Tfunc except it panics if an error happens.
func MustTfunc(languageSource string, languageSources ...string) TranslateFunc {
return TranslateFunc(defaultBundle.MustTfunc(languageSource, languageSources...))
}
// Tfunc returns a TranslateFunc that will be bound to the first language which
// has a non-zero number of translations.
//
// It can parse languages from Accept-Language headers (RFC 2616).
func Tfunc(languageSource string, languageSources ...string) (TranslateFunc, error) {
tfunc, err := defaultBundle.Tfunc(languageSource, languageSources...)
return TranslateFunc(tfunc), err
}
// MustTfuncAndLanguage is similar to TfuncAndLanguage except it panics if an error happens.
func MustTfuncAndLanguage(languageSource string, languageSources ...string) (TranslateFunc, *language.Language) {
tfunc, lang := defaultBundle.MustTfuncAndLanguage(languageSource, languageSources...)
return TranslateFunc(tfunc), lang
}
// TfuncAndLanguage is similar to Tfunc except it also returns the language which TranslateFunc is bound to.
func TfuncAndLanguage(languageSource string, languageSources ...string) (TranslateFunc, *language.Language, error) {
tfunc, lang, err := defaultBundle.TfuncAndLanguage(languageSource, languageSources...)
return TranslateFunc(tfunc), lang, err
}

View File

@ -0,0 +1,99 @@
// Package language defines languages that implement CLDR pluralization.
package language
import (
"fmt"
"strings"
)
// Language is a written human language.
type Language struct {
// Tag uniquely identifies the language as defined by RFC 5646.
//
// Most language tags are a two character language code (ISO 639-1)
// optionally followed by a dash and a two character country code (ISO 3166-1).
// (e.g. en, pt-br)
Tag string
*PluralSpec
}
func (l *Language) String() string {
return l.Tag
}
// MatchingTags returns the set of language tags that map to this Language.
// e.g. "zh-hans-cn" yields {"zh", "zh-hans", "zh-hans-cn"}
// BUG: This should be computed once and stored as a field on Language for efficiency,
// but this would require changing how Languages are constructed.
func (l *Language) MatchingTags() []string {
parts := strings.Split(l.Tag, "-")
var prefix, matches []string
for _, part := range parts {
prefix = append(prefix, part)
match := strings.Join(prefix, "-")
matches = append(matches, match)
}
return matches
}
// Parse returns a slice of supported languages found in src or nil if none are found.
// It can parse language tags and Accept-Language headers.
func Parse(src string) []*Language {
var langs []*Language
start := 0
for end, chr := range src {
switch chr {
case ',', ';', '.':
tag := strings.TrimSpace(src[start:end])
if spec := GetPluralSpec(tag); spec != nil {
langs = append(langs, &Language{NormalizeTag(tag), spec})
}
start = end + 1
}
}
if start > 0 {
tag := strings.TrimSpace(src[start:])
if spec := GetPluralSpec(tag); spec != nil {
langs = append(langs, &Language{NormalizeTag(tag), spec})
}
return dedupe(langs)
}
if spec := GetPluralSpec(src); spec != nil {
langs = append(langs, &Language{NormalizeTag(src), spec})
}
return langs
}
func dedupe(langs []*Language) []*Language {
found := make(map[string]struct{}, len(langs))
deduped := make([]*Language, 0, len(langs))
for _, lang := range langs {
if _, ok := found[lang.Tag]; !ok {
found[lang.Tag] = struct{}{}
deduped = append(deduped, lang)
}
}
return deduped
}
// MustParse is similar to Parse except it panics instead of retuning a nil Language.
func MustParse(src string) []*Language {
langs := Parse(src)
if len(langs) == 0 {
panic(fmt.Errorf("unable to parse language from %q", src))
}
return langs
}
// Add adds support for a new language.
func Add(l *Language) {
tag := NormalizeTag(l.Tag)
pluralSpecs[tag] = l.PluralSpec
}
// NormalizeTag returns a language tag with all lower-case characters
// and dashes "-" instead of underscores "_"
func NormalizeTag(tag string) string {
tag = strings.ToLower(tag)
return strings.Replace(tag, "_", "-", -1)
}

View File

@ -0,0 +1,119 @@
package language
import (
"fmt"
"strconv"
"strings"
)
// Operands is a representation of http://unicode.org/reports/tr35/tr35-numbers.html#Operands
type Operands struct {
N float64 // absolute value of the source number (integer and decimals)
I int64 // integer digits of n
V int64 // number of visible fraction digits in n, with trailing zeros
W int64 // number of visible fraction digits in n, without trailing zeros
F int64 // visible fractional digits in n, with trailing zeros
T int64 // visible fractional digits in n, without trailing zeros
}
// NequalsAny returns true if o represents an integer equal to any of the arguments.
func (o *Operands) NequalsAny(any ...int64) bool {
for _, i := range any {
if o.I == i && o.T == 0 {
return true
}
}
return false
}
// NmodEqualsAny returns true if o represents an integer equal to any of the arguments modulo mod.
func (o *Operands) NmodEqualsAny(mod int64, any ...int64) bool {
modI := o.I % mod
for _, i := range any {
if modI == i && o.T == 0 {
return true
}
}
return false
}
// NinRange returns true if o represents an integer in the closed interval [from, to].
func (o *Operands) NinRange(from, to int64) bool {
return o.T == 0 && from <= o.I && o.I <= to
}
// NmodInRange returns true if o represents an integer in the closed interval [from, to] modulo mod.
func (o *Operands) NmodInRange(mod, from, to int64) bool {
modI := o.I % mod
return o.T == 0 && from <= modI && modI <= to
}
func newOperands(v interface{}) (*Operands, error) {
switch v := v.(type) {
case int:
return newOperandsInt64(int64(v)), nil
case int8:
return newOperandsInt64(int64(v)), nil
case int16:
return newOperandsInt64(int64(v)), nil
case int32:
return newOperandsInt64(int64(v)), nil
case int64:
return newOperandsInt64(v), nil
case string:
return newOperandsString(v)
case float32, float64:
return nil, fmt.Errorf("floats should be formatted into a string")
default:
return nil, fmt.Errorf("invalid type %T; expected integer or string", v)
}
}
func newOperandsInt64(i int64) *Operands {
if i < 0 {
i = -i
}
return &Operands{float64(i), i, 0, 0, 0, 0}
}
func newOperandsString(s string) (*Operands, error) {
if s[0] == '-' {
s = s[1:]
}
n, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, err
}
ops := &Operands{N: n}
parts := strings.SplitN(s, ".", 2)
ops.I, err = strconv.ParseInt(parts[0], 10, 64)
if err != nil {
return nil, err
}
if len(parts) == 1 {
return ops, nil
}
fraction := parts[1]
ops.V = int64(len(fraction))
for i := ops.V - 1; i >= 0; i-- {
if fraction[i] != '0' {
ops.W = i + 1
break
}
}
if ops.V > 0 {
f, err := strconv.ParseInt(fraction, 10, 0)
if err != nil {
return nil, err
}
ops.F = f
}
if ops.W > 0 {
t, err := strconv.ParseInt(fraction[:ops.W], 10, 0)
if err != nil {
return nil, err
}
ops.T = t
}
return ops, nil
}

View File

@ -0,0 +1,40 @@
package language
import (
"fmt"
)
// Plural represents a language pluralization form as defined here:
// http://cldr.unicode.org/index/cldr-spec/plural-rules
type Plural string
// All defined plural categories.
const (
Invalid Plural = "invalid"
Zero = "zero"
One = "one"
Two = "two"
Few = "few"
Many = "many"
Other = "other"
)
// NewPlural returns src as a Plural
// or Invalid and a non-nil error if src is not a valid Plural.
func NewPlural(src string) (Plural, error) {
switch src {
case "zero":
return Zero, nil
case "one":
return One, nil
case "two":
return Two, nil
case "few":
return Few, nil
case "many":
return Many, nil
case "other":
return Other, nil
}
return Invalid, fmt.Errorf("invalid plural category %s", src)
}

View File

@ -0,0 +1,75 @@
package language
import "strings"
// PluralSpec defines the CLDR plural rules for a language.
// http://www.unicode.org/cldr/charts/latest/supplemental/language_plural_rules.html
// http://unicode.org/reports/tr35/tr35-numbers.html#Operands
type PluralSpec struct {
Plurals map[Plural]struct{}
PluralFunc func(*Operands) Plural
}
var pluralSpecs = make(map[string]*PluralSpec)
func normalizePluralSpecID(id string) string {
id = strings.Replace(id, "_", "-", -1)
id = strings.ToLower(id)
return id
}
// RegisterPluralSpec registers a new plural spec for the language ids.
func RegisterPluralSpec(ids []string, ps *PluralSpec) {
for _, id := range ids {
id = normalizePluralSpecID(id)
pluralSpecs[id] = ps
}
}
// Plural returns the plural category for number as defined by
// the language's CLDR plural rules.
func (ps *PluralSpec) Plural(number interface{}) (Plural, error) {
ops, err := newOperands(number)
if err != nil {
return Invalid, err
}
return ps.PluralFunc(ops), nil
}
// GetPluralSpec returns the PluralSpec that matches the longest prefix of tag.
// It returns nil if no PluralSpec matches tag.
func GetPluralSpec(tag string) *PluralSpec {
tag = NormalizeTag(tag)
subtag := tag
for {
if spec := pluralSpecs[subtag]; spec != nil {
return spec
}
end := strings.LastIndex(subtag, "-")
if end == -1 {
return nil
}
subtag = subtag[:end]
}
}
func newPluralSet(plurals ...Plural) map[Plural]struct{} {
set := make(map[Plural]struct{}, len(plurals))
for _, plural := range plurals {
set[plural] = struct{}{}
}
return set
}
func intInRange(i, from, to int64) bool {
return from <= i && i <= to
}
func intEqualsAny(i int64, any ...int64) bool {
for _, a := range any {
if i == a {
return true
}
}
return false
}

View File

@ -0,0 +1,557 @@
package language
// This file is generated by i18n/language/codegen/generate.sh
func init() {
RegisterPluralSpec([]string{"bm", "bo", "dz", "id", "ig", "ii", "in", "ja", "jbo", "jv", "jw", "kde", "kea", "km", "ko", "lkt", "lo", "ms", "my", "nqo", "root", "sah", "ses", "sg", "th", "to", "vi", "wo", "yo", "yue", "zh"}, &PluralSpec{
Plurals: newPluralSet(Other),
PluralFunc: func(ops *Operands) Plural {
return Other
},
})
RegisterPluralSpec([]string{"am", "as", "bn", "fa", "gu", "hi", "kn", "mr", "zu"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 0 or n = 1
if intEqualsAny(ops.I, 0) ||
ops.NequalsAny(1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"ff", "fr", "hy", "kab"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 0,1
if intEqualsAny(ops.I, 0, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"pt"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 0..1
if intInRange(ops.I, 0, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"ast", "ca", "de", "en", "et", "fi", "fy", "gl", "it", "ji", "nl", "sv", "sw", "ur", "yi"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 1 and v = 0
if intEqualsAny(ops.I, 1) && intEqualsAny(ops.V, 0) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"si"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0,1 or i = 0 and f = 1
if ops.NequalsAny(0, 1) ||
intEqualsAny(ops.I, 0) && intEqualsAny(ops.F, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"ak", "bh", "guw", "ln", "mg", "nso", "pa", "ti", "wa"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0..1
if ops.NinRange(0, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"tzm"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0..1 or n = 11..99
if ops.NinRange(0, 1) ||
ops.NinRange(11, 99) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"af", "asa", "az", "bem", "bez", "bg", "brx", "ce", "cgg", "chr", "ckb", "dv", "ee", "el", "eo", "es", "eu", "fo", "fur", "gsw", "ha", "haw", "hu", "jgo", "jmc", "ka", "kaj", "kcg", "kk", "kkj", "kl", "ks", "ksb", "ku", "ky", "lb", "lg", "mas", "mgo", "ml", "mn", "nah", "nb", "nd", "ne", "nn", "nnh", "no", "nr", "ny", "nyn", "om", "or", "os", "pap", "ps", "rm", "rof", "rwk", "saq", "sdh", "seh", "sn", "so", "sq", "ss", "ssy", "st", "syr", "ta", "te", "teo", "tig", "tk", "tn", "tr", "ts", "ug", "uz", "ve", "vo", "vun", "wae", "xh", "xog"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 1
if ops.NequalsAny(1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"da"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 1 or t != 0 and i = 0,1
if ops.NequalsAny(1) ||
!intEqualsAny(ops.T, 0) && intEqualsAny(ops.I, 0, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"is"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// t = 0 and i % 10 = 1 and i % 100 != 11 or t != 0
if intEqualsAny(ops.T, 0) && intEqualsAny(ops.I%10, 1) && !intEqualsAny(ops.I%100, 11) ||
!intEqualsAny(ops.T, 0) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"mk"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i % 10 = 1 or f % 10 = 1
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%10, 1) ||
intEqualsAny(ops.F%10, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"fil", "tl"}, &PluralSpec{
Plurals: newPluralSet(One, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i = 1,2,3 or v = 0 and i % 10 != 4,6,9 or v != 0 and f % 10 != 4,6,9
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I, 1, 2, 3) ||
intEqualsAny(ops.V, 0) && !intEqualsAny(ops.I%10, 4, 6, 9) ||
!intEqualsAny(ops.V, 0) && !intEqualsAny(ops.F%10, 4, 6, 9) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"lv", "prg"}, &PluralSpec{
Plurals: newPluralSet(Zero, One, Other),
PluralFunc: func(ops *Operands) Plural {
// n % 10 = 0 or n % 100 = 11..19 or v = 2 and f % 100 = 11..19
if ops.NmodEqualsAny(10, 0) ||
ops.NmodInRange(100, 11, 19) ||
intEqualsAny(ops.V, 2) && intInRange(ops.F%100, 11, 19) {
return Zero
}
// n % 10 = 1 and n % 100 != 11 or v = 2 and f % 10 = 1 and f % 100 != 11 or v != 2 and f % 10 = 1
if ops.NmodEqualsAny(10, 1) && !ops.NmodEqualsAny(100, 11) ||
intEqualsAny(ops.V, 2) && intEqualsAny(ops.F%10, 1) && !intEqualsAny(ops.F%100, 11) ||
!intEqualsAny(ops.V, 2) && intEqualsAny(ops.F%10, 1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"lag"}, &PluralSpec{
Plurals: newPluralSet(Zero, One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0
if ops.NequalsAny(0) {
return Zero
}
// i = 0,1 and n != 0
if intEqualsAny(ops.I, 0, 1) && !ops.NequalsAny(0) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"ksh"}, &PluralSpec{
Plurals: newPluralSet(Zero, One, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0
if ops.NequalsAny(0) {
return Zero
}
// n = 1
if ops.NequalsAny(1) {
return One
}
return Other
},
})
RegisterPluralSpec([]string{"iu", "kw", "naq", "se", "sma", "smi", "smj", "smn", "sms"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 1
if ops.NequalsAny(1) {
return One
}
// n = 2
if ops.NequalsAny(2) {
return Two
}
return Other
},
})
RegisterPluralSpec([]string{"shi"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 0 or n = 1
if intEqualsAny(ops.I, 0) ||
ops.NequalsAny(1) {
return One
}
// n = 2..10
if ops.NinRange(2, 10) {
return Few
}
return Other
},
})
RegisterPluralSpec([]string{"mo", "ro"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 1 and v = 0
if intEqualsAny(ops.I, 1) && intEqualsAny(ops.V, 0) {
return One
}
// v != 0 or n = 0 or n != 1 and n % 100 = 1..19
if !intEqualsAny(ops.V, 0) ||
ops.NequalsAny(0) ||
!ops.NequalsAny(1) && ops.NmodInRange(100, 1, 19) {
return Few
}
return Other
},
})
RegisterPluralSpec([]string{"bs", "hr", "sh", "sr"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i % 10 = 1 and i % 100 != 11 or f % 10 = 1 and f % 100 != 11
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%10, 1) && !intEqualsAny(ops.I%100, 11) ||
intEqualsAny(ops.F%10, 1) && !intEqualsAny(ops.F%100, 11) {
return One
}
// v = 0 and i % 10 = 2..4 and i % 100 != 12..14 or f % 10 = 2..4 and f % 100 != 12..14
if intEqualsAny(ops.V, 0) && intInRange(ops.I%10, 2, 4) && !intInRange(ops.I%100, 12, 14) ||
intInRange(ops.F%10, 2, 4) && !intInRange(ops.F%100, 12, 14) {
return Few
}
return Other
},
})
RegisterPluralSpec([]string{"gd"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Few, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 1,11
if ops.NequalsAny(1, 11) {
return One
}
// n = 2,12
if ops.NequalsAny(2, 12) {
return Two
}
// n = 3..10,13..19
if ops.NinRange(3, 10) || ops.NinRange(13, 19) {
return Few
}
return Other
},
})
RegisterPluralSpec([]string{"sl"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Few, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i % 100 = 1
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%100, 1) {
return One
}
// v = 0 and i % 100 = 2
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%100, 2) {
return Two
}
// v = 0 and i % 100 = 3..4 or v != 0
if intEqualsAny(ops.V, 0) && intInRange(ops.I%100, 3, 4) ||
!intEqualsAny(ops.V, 0) {
return Few
}
return Other
},
})
RegisterPluralSpec([]string{"dsb", "hsb"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Few, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i % 100 = 1 or f % 100 = 1
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%100, 1) ||
intEqualsAny(ops.F%100, 1) {
return One
}
// v = 0 and i % 100 = 2 or f % 100 = 2
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%100, 2) ||
intEqualsAny(ops.F%100, 2) {
return Two
}
// v = 0 and i % 100 = 3..4 or f % 100 = 3..4
if intEqualsAny(ops.V, 0) && intInRange(ops.I%100, 3, 4) ||
intInRange(ops.F%100, 3, 4) {
return Few
}
return Other
},
})
RegisterPluralSpec([]string{"he", "iw"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 1 and v = 0
if intEqualsAny(ops.I, 1) && intEqualsAny(ops.V, 0) {
return One
}
// i = 2 and v = 0
if intEqualsAny(ops.I, 2) && intEqualsAny(ops.V, 0) {
return Two
}
// v = 0 and n != 0..10 and n % 10 = 0
if intEqualsAny(ops.V, 0) && !ops.NinRange(0, 10) && ops.NmodEqualsAny(10, 0) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"cs", "sk"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 1 and v = 0
if intEqualsAny(ops.I, 1) && intEqualsAny(ops.V, 0) {
return One
}
// i = 2..4 and v = 0
if intInRange(ops.I, 2, 4) && intEqualsAny(ops.V, 0) {
return Few
}
// v != 0
if !intEqualsAny(ops.V, 0) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"pl"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// i = 1 and v = 0
if intEqualsAny(ops.I, 1) && intEqualsAny(ops.V, 0) {
return One
}
// v = 0 and i % 10 = 2..4 and i % 100 != 12..14
if intEqualsAny(ops.V, 0) && intInRange(ops.I%10, 2, 4) && !intInRange(ops.I%100, 12, 14) {
return Few
}
// v = 0 and i != 1 and i % 10 = 0..1 or v = 0 and i % 10 = 5..9 or v = 0 and i % 100 = 12..14
if intEqualsAny(ops.V, 0) && !intEqualsAny(ops.I, 1) && intInRange(ops.I%10, 0, 1) ||
intEqualsAny(ops.V, 0) && intInRange(ops.I%10, 5, 9) ||
intEqualsAny(ops.V, 0) && intInRange(ops.I%100, 12, 14) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"be"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n % 10 = 1 and n % 100 != 11
if ops.NmodEqualsAny(10, 1) && !ops.NmodEqualsAny(100, 11) {
return One
}
// n % 10 = 2..4 and n % 100 != 12..14
if ops.NmodInRange(10, 2, 4) && !ops.NmodInRange(100, 12, 14) {
return Few
}
// n % 10 = 0 or n % 10 = 5..9 or n % 100 = 11..14
if ops.NmodEqualsAny(10, 0) ||
ops.NmodInRange(10, 5, 9) ||
ops.NmodInRange(100, 11, 14) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"lt"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n % 10 = 1 and n % 100 != 11..19
if ops.NmodEqualsAny(10, 1) && !ops.NmodInRange(100, 11, 19) {
return One
}
// n % 10 = 2..9 and n % 100 != 11..19
if ops.NmodInRange(10, 2, 9) && !ops.NmodInRange(100, 11, 19) {
return Few
}
// f != 0
if !intEqualsAny(ops.F, 0) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"mt"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 1
if ops.NequalsAny(1) {
return One
}
// n = 0 or n % 100 = 2..10
if ops.NequalsAny(0) ||
ops.NmodInRange(100, 2, 10) {
return Few
}
// n % 100 = 11..19
if ops.NmodInRange(100, 11, 19) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"ru", "uk"}, &PluralSpec{
Plurals: newPluralSet(One, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i % 10 = 1 and i % 100 != 11
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%10, 1) && !intEqualsAny(ops.I%100, 11) {
return One
}
// v = 0 and i % 10 = 2..4 and i % 100 != 12..14
if intEqualsAny(ops.V, 0) && intInRange(ops.I%10, 2, 4) && !intInRange(ops.I%100, 12, 14) {
return Few
}
// v = 0 and i % 10 = 0 or v = 0 and i % 10 = 5..9 or v = 0 and i % 100 = 11..14
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%10, 0) ||
intEqualsAny(ops.V, 0) && intInRange(ops.I%10, 5, 9) ||
intEqualsAny(ops.V, 0) && intInRange(ops.I%100, 11, 14) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"br"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n % 10 = 1 and n % 100 != 11,71,91
if ops.NmodEqualsAny(10, 1) && !ops.NmodEqualsAny(100, 11, 71, 91) {
return One
}
// n % 10 = 2 and n % 100 != 12,72,92
if ops.NmodEqualsAny(10, 2) && !ops.NmodEqualsAny(100, 12, 72, 92) {
return Two
}
// n % 10 = 3..4,9 and n % 100 != 10..19,70..79,90..99
if (ops.NmodInRange(10, 3, 4) || ops.NmodEqualsAny(10, 9)) && !(ops.NmodInRange(100, 10, 19) || ops.NmodInRange(100, 70, 79) || ops.NmodInRange(100, 90, 99)) {
return Few
}
// n != 0 and n % 1000000 = 0
if !ops.NequalsAny(0) && ops.NmodEqualsAny(1000000, 0) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"ga"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 1
if ops.NequalsAny(1) {
return One
}
// n = 2
if ops.NequalsAny(2) {
return Two
}
// n = 3..6
if ops.NinRange(3, 6) {
return Few
}
// n = 7..10
if ops.NinRange(7, 10) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"gv"}, &PluralSpec{
Plurals: newPluralSet(One, Two, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// v = 0 and i % 10 = 1
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%10, 1) {
return One
}
// v = 0 and i % 10 = 2
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%10, 2) {
return Two
}
// v = 0 and i % 100 = 0,20,40,60,80
if intEqualsAny(ops.V, 0) && intEqualsAny(ops.I%100, 0, 20, 40, 60, 80) {
return Few
}
// v != 0
if !intEqualsAny(ops.V, 0) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"ar", "ars"}, &PluralSpec{
Plurals: newPluralSet(Zero, One, Two, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0
if ops.NequalsAny(0) {
return Zero
}
// n = 1
if ops.NequalsAny(1) {
return One
}
// n = 2
if ops.NequalsAny(2) {
return Two
}
// n % 100 = 3..10
if ops.NmodInRange(100, 3, 10) {
return Few
}
// n % 100 = 11..99
if ops.NmodInRange(100, 11, 99) {
return Many
}
return Other
},
})
RegisterPluralSpec([]string{"cy"}, &PluralSpec{
Plurals: newPluralSet(Zero, One, Two, Few, Many, Other),
PluralFunc: func(ops *Operands) Plural {
// n = 0
if ops.NequalsAny(0) {
return Zero
}
// n = 1
if ops.NequalsAny(1) {
return One
}
// n = 2
if ops.NequalsAny(2) {
return Two
}
// n = 3
if ops.NequalsAny(3) {
return Few
}
// n = 6
if ops.NequalsAny(6) {
return Many
}
return Other
},
})
}

View File

@ -0,0 +1,82 @@
package translation
import (
"github.com/mattermost/go-i18n/i18n/language"
)
type pluralTranslation struct {
id string
templates map[language.Plural]*template
}
func (pt *pluralTranslation) MarshalInterface() interface{} {
return map[string]interface{}{
"id": pt.id,
"translation": pt.templates,
}
}
func (pt *pluralTranslation) MarshalFlatInterface() interface{} {
return pt.templates
}
func (pt *pluralTranslation) ID() string {
return pt.id
}
func (pt *pluralTranslation) Template(pc language.Plural) *template {
return pt.templates[pc]
}
func (pt *pluralTranslation) UntranslatedCopy() Translation {
return &pluralTranslation{pt.id, make(map[language.Plural]*template)}
}
func (pt *pluralTranslation) Normalize(l *language.Language) Translation {
// Delete plural categories that don't belong to this language.
for pc := range pt.templates {
if _, ok := l.Plurals[pc]; !ok {
delete(pt.templates, pc)
}
}
// Create map entries for missing valid categories.
for pc := range l.Plurals {
if _, ok := pt.templates[pc]; !ok {
pt.templates[pc] = mustNewTemplate("")
}
}
return pt
}
func (pt *pluralTranslation) Backfill(src Translation) Translation {
for pc, t := range pt.templates {
if (t == nil || t.src == "") && src != nil {
pt.templates[pc] = src.Template(language.Other)
}
}
return pt
}
func (pt *pluralTranslation) Merge(t Translation) Translation {
other, ok := t.(*pluralTranslation)
if !ok || pt.ID() != t.ID() {
return t
}
for pluralCategory, template := range other.templates {
if template != nil && template.src != "" {
pt.templates[pluralCategory] = template
}
}
return pt
}
func (pt *pluralTranslation) Incomplete(l *language.Language) bool {
for pc := range l.Plurals {
if t := pt.templates[pc]; t == nil || t.src == "" {
return true
}
}
return false
}
var _ = Translation(&pluralTranslation{})

View File

@ -0,0 +1,61 @@
package translation
import (
"github.com/mattermost/go-i18n/i18n/language"
)
type singleTranslation struct {
id string
template *template
}
func (st *singleTranslation) MarshalInterface() interface{} {
return map[string]interface{}{
"id": st.id,
"translation": st.template,
}
}
func (st *singleTranslation) MarshalFlatInterface() interface{} {
return map[string]interface{}{"other": st.template}
}
func (st *singleTranslation) ID() string {
return st.id
}
func (st *singleTranslation) Template(pc language.Plural) *template {
return st.template
}
func (st *singleTranslation) UntranslatedCopy() Translation {
return &singleTranslation{st.id, mustNewTemplate("")}
}
func (st *singleTranslation) Normalize(language *language.Language) Translation {
return st
}
func (st *singleTranslation) Backfill(src Translation) Translation {
if (st.template == nil || st.template.src == "") && src != nil {
st.template = src.Template(language.Other)
}
return st
}
func (st *singleTranslation) Merge(t Translation) Translation {
other, ok := t.(*singleTranslation)
if !ok || st.ID() != t.ID() {
return t
}
if other.template != nil && other.template.src != "" {
st.template = other.template
}
return st
}
func (st *singleTranslation) Incomplete(l *language.Language) bool {
return st.template == nil || st.template.src == ""
}
var _ = Translation(&singleTranslation{})

View File

@ -0,0 +1,65 @@
package translation
import (
"bytes"
"encoding"
"strings"
gotemplate "text/template"
)
type template struct {
tmpl *gotemplate.Template
src string
}
func newTemplate(src string) (*template, error) {
if src == "" {
return new(template), nil
}
var tmpl template
err := tmpl.parseTemplate(src)
return &tmpl, err
}
func mustNewTemplate(src string) *template {
t, err := newTemplate(src)
if err != nil {
panic(err)
}
return t
}
func (t *template) String() string {
return t.src
}
func (t *template) Execute(args interface{}) string {
if t.tmpl == nil {
return t.src
}
var buf bytes.Buffer
if err := t.tmpl.Execute(&buf, args); err != nil {
return err.Error()
}
return buf.String()
}
func (t *template) MarshalText() ([]byte, error) {
return []byte(t.src), nil
}
func (t *template) UnmarshalText(src []byte) error {
return t.parseTemplate(string(src))
}
func (t *template) parseTemplate(src string) (err error) {
t.src = src
if strings.Contains(src, "{{") {
t.tmpl, err = gotemplate.New(src).Parse(src)
}
return
}
var _ = encoding.TextMarshaler(&template{})
var _ = encoding.TextUnmarshaler(&template{})

View File

@ -0,0 +1,84 @@
// Package translation defines the interface for a translation.
package translation
import (
"fmt"
"github.com/mattermost/go-i18n/i18n/language"
)
// Translation is the interface that represents a translated string.
type Translation interface {
// MarshalInterface returns the object that should be used
// to serialize the translation.
MarshalInterface() interface{}
MarshalFlatInterface() interface{}
ID() string
Template(language.Plural) *template
UntranslatedCopy() Translation
Normalize(language *language.Language) Translation
Backfill(src Translation) Translation
Merge(Translation) Translation
Incomplete(l *language.Language) bool
}
// SortableByID implements sort.Interface for a slice of translations.
type SortableByID []Translation
func (a SortableByID) Len() int { return len(a) }
func (a SortableByID) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a SortableByID) Less(i, j int) bool { return a[i].ID() < a[j].ID() }
// NewTranslation reflects on data to create a new Translation.
//
// data["id"] must be a string and data["translation"] must be either a string
// for a non-plural translation or a map[string]interface{} for a plural translation.
func NewTranslation(data map[string]interface{}) (Translation, error) {
id, ok := data["id"].(string)
if !ok {
return nil, fmt.Errorf(`missing "id" key`)
}
var pluralObject map[string]interface{}
switch translation := data["translation"].(type) {
case string:
tmpl, err := newTemplate(translation)
if err != nil {
return nil, err
}
return &singleTranslation{id, tmpl}, nil
case map[interface{}]interface{}:
// The YAML parser uses interface{} keys so we first convert them to string keys.
pluralObject = make(map[string]interface{})
for k, v := range translation {
kstr, ok := k.(string)
if !ok {
return nil, fmt.Errorf(`invalid plural category type %T; expected string`, k)
}
pluralObject[kstr] = v
}
case map[string]interface{}:
pluralObject = translation
case nil:
return nil, fmt.Errorf(`missing "translation" key`)
default:
return nil, fmt.Errorf(`unsupported type for "translation" key %T`, translation)
}
templates := make(map[language.Plural]*template, len(pluralObject))
for k, v := range pluralObject {
pc, err := language.NewPlural(k)
if err != nil {
return nil, err
}
str, ok := v.(string)
if !ok {
return nil, fmt.Errorf(`plural category "%s" has value of type %T; expected string`, pc, v)
}
tmpl, err := newTemplate(str)
if err != nil {
return nil, err
}
templates[pc] = tmpl
}
return &pluralTranslation{id, templates}, nil
}

1
vendor/github.com/mattermost/gosaml2/.gitignore generated vendored Normal file
View File

@ -0,0 +1 @@
*.test

12
vendor/github.com/mattermost/gosaml2/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,12 @@
language: go
go:
- 1.17
- 1.16
- 1.15
- 1.14
- tip
matrix:
allow_failures:
- go: tip

175
vendor/github.com/mattermost/gosaml2/LICENSE generated vendored Normal file
View File

@ -0,0 +1,175 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

34
vendor/github.com/mattermost/gosaml2/README.md generated vendored Normal file
View File

@ -0,0 +1,34 @@
# gosaml2
[![Build Status](https://github.com/mattermost/gosaml2/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/mattermost/gosaml2/actions/workflows/test.yml?query=branch%3Amain)
[![GoDoc](https://godoc.org/github.com/mattermost/gosaml2?status.svg)](https://godoc.org/github.com/mattermost/gosaml2)
SAML 2.0 implemementation for Service Providers based on [etree](https://github.com/beevik/etree)
and [goxmldsig](https://github.com/russellhaering/goxmldsig), a pure Go
implementation of XML digital signatures.
## Installation
Install `gosaml2` into your `$GOPATH` using `go get`:
```
go get github.com/mattermost/gosaml2
```
## Example
See [demo.go](s2example/demo.go).
## Supported Identity Providers
This library is meant to be a generic SAML implementation. If you find a
standards compliant identity provider that it doesn't work with please
submit a bug or pull request.
The following identity providers have been tested:
* Okta
* Auth0
* Shibboleth
* Ipsilon
* OneLogin

66
vendor/github.com/mattermost/gosaml2/attribute.go generated vendored Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import "github.com/mattermost/gosaml2/types"
// Values is a convenience wrapper for a map of strings to Attributes, which
// can be used for easy access to the string values of Attribute lists.
type Values map[string]types.Attribute
// Get is a safe method (nil maps will not panic) for returning the first value
// for an attribute at a key, or the empty string if none exists.
func (vals Values) Get(k string) string {
if vals == nil {
return ""
}
if v, ok := vals[k]; ok && len(v.Values) > 0 {
return string(v.Values[0].Value)
}
return ""
}
//GetSize returns the number of values for an attribute at a key.
//Returns '0' in case of error or if key is not found.
func (vals Values) GetSize(k string) int {
if vals == nil {
return 0
}
v, ok := vals[k]
if ok {
return len(v.Values)
}
return 0
}
//GetAll returns all the values for an attribute at a key.
//Returns an empty slice in case of error of if key is not found.
func (vals Values) GetAll(k string) []string {
var av []string
if vals == nil {
return av
}
if v, ok := vals[k]; ok && len(v.Values) > 0 {
for i := 0; i < len(v.Values); i++ {
av = append(av, string(v.Values[i].Value))
}
}
return av
}

30
vendor/github.com/mattermost/gosaml2/authn_request.go generated vendored Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import "time"
// AuthNRequest is the go struct representation of an authentication request
type AuthNRequest struct {
ID string `xml:",attr"`
Version string `xml:",attr"`
ProtocolBinding string `xml:",attr"`
AssertionConsumerServiceURL string `xml:",attr"`
IssueInstant time.Time `xml:",attr"`
Destination string `xml:",attr"`
Issuer string
}

View File

@ -0,0 +1,158 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"bytes"
"encoding/base64"
"html/template"
"github.com/beevik/etree"
"github.com/mattermost/gosaml2/uuid"
)
func (sp *SAMLServiceProvider) buildLogoutResponse(statusCodeValue string, reqID string, includeSig bool) (*etree.Document, error) {
logoutResponse := &etree.Element{
Space: "samlp",
Tag: "LogoutResponse",
}
logoutResponse.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:2.0:protocol")
logoutResponse.CreateAttr("xmlns:saml", "urn:oasis:names:tc:SAML:2.0:assertion")
arId := uuid.NewV4()
logoutResponse.CreateAttr("ID", "_"+arId.String())
logoutResponse.CreateAttr("Version", "2.0")
logoutResponse.CreateAttr("IssueInstant", sp.Clock.Now().UTC().Format(issueInstantFormat))
logoutResponse.CreateAttr("Destination", sp.IdentityProviderSLOURL)
logoutResponse.CreateAttr("InResponseTo", reqID)
// NOTE(russell_h): In earlier versions we mistakenly sent the IdentityProviderIssuer
// in the AuthnRequest. For backwards compatibility we will fall back to that
// behavior when ServiceProviderIssuer isn't set.
if sp.ServiceProviderIssuer != "" {
logoutResponse.CreateElement("saml:Issuer").SetText(sp.ServiceProviderIssuer)
} else {
logoutResponse.CreateElement("saml:Issuer").SetText(sp.IdentityProviderIssuer)
}
status := logoutResponse.CreateElement("samlp:Status")
statusCode := status.CreateElement("samlp:StatusCode")
statusCode.CreateAttr("Value", statusCodeValue)
doc := etree.NewDocument()
// Only POST binding includes <Signature> in <AuthnRequest> (includeSig)
if includeSig {
signed, err := sp.SignLogoutResponse(logoutResponse)
if err != nil {
return nil, err
}
doc.SetRoot(signed)
} else {
doc.SetRoot(logoutResponse)
}
return doc, nil
}
func (sp *SAMLServiceProvider) BuildLogoutResponseDocument(status string, reqID string) (*etree.Document, error) {
return sp.buildLogoutResponse(status, reqID, true)
}
func (sp *SAMLServiceProvider) BuildLogoutResponseDocumentNoSig(status string, reqID string) (*etree.Document, error) {
return sp.buildLogoutResponse(status, reqID, false)
}
func (sp *SAMLServiceProvider) SignLogoutResponse(el *etree.Element) (*etree.Element, error) {
ctx := sp.SigningContext()
sig, err := ctx.ConstructSignature(el, true)
if err != nil {
return nil, err
}
ret := el.Copy()
var children []etree.Token
children = append(children, ret.Child[0]) // issuer is always first
children = append(children, sig) // next is the signature
children = append(children, ret.Child[1:]...) // then all other children
ret.Child = children
return ret, nil
}
func (sp *SAMLServiceProvider) buildLogoutResponseBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) {
respBuf, err := doc.WriteToBytes()
if err != nil {
return nil, err
}
encodedRespBuf := base64.StdEncoding.EncodeToString(respBuf)
var tmpl *template.Template
var rv bytes.Buffer
if relayState != "" {
tmpl = template.Must(template.New("saml-post-form").Parse(`<html>` +
`<form method="post" action="{{.URL}}" id="SAMLResponseForm">` +
`<input type="hidden" name="SAMLResponse" value="{{.SAMLResponse}}" />` +
`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Continue" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility='hidden';</script>` +
`<script>document.getElementById('SAMLResponseForm').submit();</script>` +
`</html>`))
data := struct {
URL string
SAMLResponse string
RelayState string
}{
URL: sp.IdentityProviderSLOURL,
SAMLResponse: encodedRespBuf,
RelayState: relayState,
}
if err = tmpl.Execute(&rv, data); err != nil {
return nil, err
}
} else {
tmpl = template.Must(template.New("saml-post-form").Parse(`<html>` +
`<form method="post" action="{{.URL}}" id="SAMLResponseForm">` +
`<input type="hidden" name="SAMLResponse" value="{{.SAMLResponse}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Continue" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility='hidden';</script>` +
`<script>document.getElementById('SAMLResponseForm').submit();</script>` +
`</html>`))
data := struct {
URL string
SAMLResponse string
}{
URL: sp.IdentityProviderSLOURL,
SAMLResponse: encodedRespBuf,
}
if err = tmpl.Execute(&rv, data); err != nil {
return nil, err
}
}
return rv.Bytes(), nil
}
func (sp *SAMLServiceProvider) BuildLogoutResponseBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) {
return sp.buildLogoutResponseBodyPostFromDocument(relayState, doc)
}

559
vendor/github.com/mattermost/gosaml2/build_request.go generated vendored Normal file
View File

@ -0,0 +1,559 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"bytes"
"compress/flate"
"encoding/base64"
"fmt"
"html/template"
"net/http"
"net/url"
"github.com/beevik/etree"
"github.com/mattermost/gosaml2/uuid"
)
const issueInstantFormat = "2006-01-02T15:04:05Z"
func (sp *SAMLServiceProvider) buildAuthnRequest(includeSig bool) (*etree.Document, error) {
authnRequest := &etree.Element{
Space: "samlp",
Tag: "AuthnRequest",
}
authnRequest.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:2.0:protocol")
authnRequest.CreateAttr("xmlns:saml", "urn:oasis:names:tc:SAML:2.0:assertion")
arId := uuid.NewV4()
authnRequest.CreateAttr("ID", "_"+arId.String())
authnRequest.CreateAttr("Version", "2.0")
authnRequest.CreateAttr("ProtocolBinding", "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST")
authnRequest.CreateAttr("AssertionConsumerServiceURL", sp.AssertionConsumerServiceURL)
authnRequest.CreateAttr("IssueInstant", sp.Clock.Now().UTC().Format(issueInstantFormat))
authnRequest.CreateAttr("Destination", sp.IdentityProviderSSOURL)
// NOTE(russell_h): In earlier versions we mistakenly sent the IdentityProviderIssuer
// in the AuthnRequest. For backwards compatibility we will fall back to that
// behavior when ServiceProviderIssuer isn't set.
if sp.ServiceProviderIssuer != "" {
authnRequest.CreateElement("saml:Issuer").SetText(sp.ServiceProviderIssuer)
} else {
authnRequest.CreateElement("saml:Issuer").SetText(sp.IdentityProviderIssuer)
}
nameIdPolicy := authnRequest.CreateElement("samlp:NameIDPolicy")
nameIdPolicy.CreateAttr("AllowCreate", "true")
if sp.NameIdFormat != "" {
nameIdPolicy.CreateAttr("Format", sp.NameIdFormat)
}
if sp.RequestedAuthnContext != nil {
requestedAuthnContext := authnRequest.CreateElement("samlp:RequestedAuthnContext")
requestedAuthnContext.CreateAttr("Comparison", sp.RequestedAuthnContext.Comparison)
for _, context := range sp.RequestedAuthnContext.Contexts {
authnContextClassRef := requestedAuthnContext.CreateElement("saml:AuthnContextClassRef")
authnContextClassRef.SetText(context)
}
}
if sp.ScopingIDPProviderId != "" && sp.ScopingIDPProviderName != "" {
scoping := authnRequest.CreateElement("samlp:Scoping")
idpList := scoping.CreateElement("samlp:IDPList")
idpEntry := idpList.CreateElement("samlp:IDPEntry")
idpEntry.CreateAttr("ProviderID", sp.ScopingIDPProviderId)
idpEntry.CreateAttr("Name", sp.ScopingIDPProviderName)
}
doc := etree.NewDocument()
// Only POST binding includes <Signature> in <AuthnRequest> (includeSig)
if sp.SignAuthnRequests && includeSig {
signed, err := sp.SignAuthnRequest(authnRequest)
if err != nil {
return nil, err
}
doc.SetRoot(signed)
} else {
doc.SetRoot(authnRequest)
}
return doc, nil
}
func (sp *SAMLServiceProvider) BuildAuthRequestDocument() (*etree.Document, error) {
return sp.buildAuthnRequest(true)
}
func (sp *SAMLServiceProvider) BuildAuthRequestDocumentNoSig() (*etree.Document, error) {
return sp.buildAuthnRequest(false)
}
// SignAuthnRequest takes a document, builds a signature, creates another document
// and inserts the signature in it. According to the schema, the position of the
// signature is right after the Issuer [1] then all other children.
//
// [1] https://docs.oasis-open.org/security/saml/v2.0/saml-schema-protocol-2.0.xsd
func (sp *SAMLServiceProvider) SignAuthnRequest(el *etree.Element) (*etree.Element, error) {
ctx := sp.SigningContext()
sig, err := ctx.ConstructSignature(el, true)
if err != nil {
return nil, err
}
ret := el.Copy()
var children []etree.Token
children = append(children, ret.Child[0]) // issuer is always first
children = append(children, sig) // next is the signature
children = append(children, ret.Child[1:]...) // then all other children
ret.Child = children
return ret, nil
}
// BuildAuthRequest builds <AuthnRequest> for identity provider
func (sp *SAMLServiceProvider) BuildAuthRequest() (string, error) {
doc, err := sp.BuildAuthRequestDocument()
if err != nil {
return "", err
}
return doc.WriteToString()
}
func (sp *SAMLServiceProvider) buildAuthURLFromDocument(relayState, binding string, doc *etree.Document) (string, error) {
parsedUrl, err := url.Parse(sp.IdentityProviderSSOURL)
if err != nil {
return "", err
}
authnRequest, err := doc.WriteToString()
if err != nil {
return "", err
}
buf := &bytes.Buffer{}
fw, err := flate.NewWriter(buf, flate.DefaultCompression)
if err != nil {
return "", fmt.Errorf("flate NewWriter error: %v", err)
}
_, err = fw.Write([]byte(authnRequest))
if err != nil {
return "", fmt.Errorf("flate.Writer Write error: %v", err)
}
err = fw.Close()
if err != nil {
return "", fmt.Errorf("flate.Writer Close error: %v", err)
}
qs := parsedUrl.Query()
qs.Add("SAMLRequest", base64.StdEncoding.EncodeToString(buf.Bytes()))
if relayState != "" {
qs.Add("RelayState", relayState)
}
if sp.SignAuthnRequests && binding == BindingHttpRedirect {
// Sign URL encoded query (see Section 3.4.4.1 DEFLATE Encoding of saml-bindings-2.0-os.pdf)
ctx := sp.SigningContext()
qs.Add("SigAlg", ctx.GetSignatureMethodIdentifier())
var rawSignature []byte
if rawSignature, err = ctx.SignString(signatureInputString(qs.Get("SAMLRequest"), qs.Get("RelayState"), qs.Get("SigAlg"))); err != nil {
return "", fmt.Errorf("unable to sign query string of redirect URL: %v", err)
}
// Now add base64 encoded Signature
qs.Add("Signature", base64.StdEncoding.EncodeToString(rawSignature))
}
//Here the parameters may appear in any order.
parsedUrl.RawQuery = qs.Encode()
return parsedUrl.String(), nil
}
func (sp *SAMLServiceProvider) BuildAuthURLFromDocument(relayState string, doc *etree.Document) (string, error) {
return sp.buildAuthURLFromDocument(relayState, BindingHttpPost, doc)
}
func (sp *SAMLServiceProvider) BuildAuthURLRedirect(relayState string, doc *etree.Document) (string, error) {
return sp.buildAuthURLFromDocument(relayState, BindingHttpRedirect, doc)
}
func (sp *SAMLServiceProvider) buildAuthBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) {
reqBuf, err := doc.WriteToBytes()
if err != nil {
return nil, err
}
encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
var tmpl *template.Template
var rv bytes.Buffer
if relayState != "" {
tmpl = template.Must(template.New("saml-post-form").Parse(`` +
`<form method="POST" action="{{.URL}}" id="SAMLRequestForm">` +
`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
`document.getElementById('SAMLRequestForm').submit();</script>`))
data := struct {
URL string
SAMLRequest string
RelayState string
}{
URL: sp.IdentityProviderSSOURL,
SAMLRequest: encodedReqBuf,
RelayState: relayState,
}
if err = tmpl.Execute(&rv, data); err != nil {
return nil, err
}
} else {
tmpl = template.Must(template.New("saml-post-form").Parse(`` +
`<form method="POST" action="{{.URL}}" id="SAMLRequestForm">` +
`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
`document.getElementById('SAMLRequestForm').submit();</script>`))
data := struct {
URL string
SAMLRequest string
}{
URL: sp.IdentityProviderSSOURL,
SAMLRequest: encodedReqBuf,
}
if err = tmpl.Execute(&rv, data); err != nil {
return nil, err
}
}
return rv.Bytes(), nil
}
//BuildAuthBodyPost builds the POST body to be sent to IDP.
func (sp *SAMLServiceProvider) BuildAuthBodyPost(relayState string) ([]byte, error) {
var doc *etree.Document
var err error
if sp.SignAuthnRequests {
doc, err = sp.BuildAuthRequestDocument()
} else {
doc, err = sp.BuildAuthRequestDocumentNoSig()
}
if err != nil {
return nil, err
}
return sp.buildAuthBodyPostFromDocument(relayState, doc)
}
//BuildAuthBodyPostFromDocument builds the POST body to be sent to IDP.
//It takes the AuthnRequest xml as input.
func (sp *SAMLServiceProvider) BuildAuthBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) {
return sp.buildAuthBodyPostFromDocument(relayState, doc)
}
// BuildAuthURL builds redirect URL to be sent to principal
func (sp *SAMLServiceProvider) BuildAuthURL(relayState string) (string, error) {
doc, err := sp.BuildAuthRequestDocument()
if err != nil {
return "", err
}
return sp.BuildAuthURLFromDocument(relayState, doc)
}
// AuthRedirect takes a ResponseWriter and Request from an http interaction and
// redirects to the SAMLServiceProvider's configured IdP, including the
// relayState provided, if any.
func (sp *SAMLServiceProvider) AuthRedirect(w http.ResponseWriter, r *http.Request, relayState string) (err error) {
url, err := sp.BuildAuthURL(relayState)
if err != nil {
return err
}
http.Redirect(w, r, url, http.StatusFound)
return nil
}
func (sp *SAMLServiceProvider) buildLogoutRequest(includeSig bool, nameID string, sessionIndex string) (*etree.Document, error) {
logoutRequest := &etree.Element{
Space: "samlp",
Tag: "LogoutRequest",
}
logoutRequest.CreateAttr("xmlns:samlp", "urn:oasis:names:tc:SAML:2.0:protocol")
logoutRequest.CreateAttr("xmlns:saml", "urn:oasis:names:tc:SAML:2.0:assertion")
arId := uuid.NewV4()
logoutRequest.CreateAttr("ID", "_"+arId.String())
logoutRequest.CreateAttr("Version", "2.0")
logoutRequest.CreateAttr("IssueInstant", sp.Clock.Now().UTC().Format(issueInstantFormat))
logoutRequest.CreateAttr("Destination", sp.IdentityProviderSLOURL)
// NOTE(russell_h): In earlier versions we mistakenly sent the IdentityProviderIssuer
// in the AuthnRequest. For backwards compatibility we will fall back to that
// behavior when ServiceProviderIssuer isn't set.
// TODO: Throw error in case Issuer is empty.
if sp.ServiceProviderIssuer != "" {
logoutRequest.CreateElement("saml:Issuer").SetText(sp.ServiceProviderIssuer)
} else {
logoutRequest.CreateElement("saml:Issuer").SetText(sp.IdentityProviderIssuer)
}
nameId := logoutRequest.CreateElement("saml:NameID")
nameId.SetText(nameID)
nameId.CreateAttr("Format", sp.NameIdFormat)
//Section 3.7.1 - http://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf says
//SessionIndex is optional. If the IDP supports SLO then it must send SessionIndex as per
//Section 4.1.4.2 of https://docs.oasis-open.org/security/saml/v2.0/saml-profiles-2.0-os.pdf.
//As per section 4.4.3.1 of //docs.oasis-open.org/security/saml/v2.0/saml-profiles-2.0-os.pdf,
//a LogoutRequest issued by Session Participant to Identity Provider, must contain
//at least one SessionIndex element needs to be included.
nameId = logoutRequest.CreateElement("samlp:SessionIndex")
nameId.SetText(sessionIndex)
doc := etree.NewDocument()
if includeSig {
signed, err := sp.SignLogoutRequest(logoutRequest)
if err != nil {
return nil, err
}
doc.SetRoot(signed)
} else {
doc.SetRoot(logoutRequest)
}
return doc, nil
}
func (sp *SAMLServiceProvider) SignLogoutRequest(el *etree.Element) (*etree.Element, error) {
ctx := sp.SigningContext()
sig, err := ctx.ConstructSignature(el, true)
if err != nil {
return nil, err
}
ret := el.Copy()
var children []etree.Token
children = append(children, ret.Child[0]) // issuer is always first
children = append(children, sig) // next is the signature
children = append(children, ret.Child[1:]...) // then all other children
ret.Child = children
return ret, nil
}
func (sp *SAMLServiceProvider) BuildLogoutRequestDocumentNoSig(nameID string, sessionIndex string) (*etree.Document, error) {
return sp.buildLogoutRequest(false, nameID, sessionIndex)
}
func (sp *SAMLServiceProvider) BuildLogoutRequestDocument(nameID string, sessionIndex string) (*etree.Document, error) {
return sp.buildLogoutRequest(true, nameID, sessionIndex)
}
//BuildLogoutBodyPostFromDocument builds the POST body to be sent to IDP.
//It takes the LogoutRequest xml as input.
func (sp *SAMLServiceProvider) BuildLogoutBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) {
return sp.buildLogoutBodyPostFromDocument(relayState, doc)
}
func (sp *SAMLServiceProvider) buildLogoutBodyPostFromDocument(relayState string, doc *etree.Document) ([]byte, error) {
reqBuf, err := doc.WriteToBytes()
if err != nil {
return nil, err
}
encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
var tmpl *template.Template
var rv bytes.Buffer
if relayState != "" {
tmpl = template.Must(template.New("saml-post-form").Parse(`` +
`<form method="POST" action="{{.URL}}" id="SAMLRequestForm">` +
`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
`document.getElementById('SAMLRequestForm').submit();</script>`))
data := struct {
URL string
SAMLRequest string
RelayState string
}{
URL: sp.IdentityProviderSLOURL,
SAMLRequest: encodedReqBuf,
RelayState: relayState,
}
if err = tmpl.Execute(&rv, data); err != nil {
return nil, err
}
} else {
tmpl = template.Must(template.New("saml-post-form").Parse(`` +
`<form method="POST" action="{{.URL}}" id="SAMLRequestForm">` +
`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
`document.getElementById('SAMLRequestForm').submit();</script>`))
data := struct {
URL string
SAMLRequest string
}{
URL: sp.IdentityProviderSLOURL,
SAMLRequest: encodedReqBuf,
}
if err = tmpl.Execute(&rv, data); err != nil {
return nil, err
}
}
return rv.Bytes(), nil
}
func (sp *SAMLServiceProvider) BuildLogoutURLRedirect(relayState string, doc *etree.Document) (string, error) {
return sp.buildLogoutURLFromDocument(relayState, BindingHttpRedirect, doc)
}
func (sp *SAMLServiceProvider) buildLogoutURLFromDocument(relayState, binding string, doc *etree.Document) (string, error) {
parsedUrl, err := url.Parse(sp.IdentityProviderSLOURL)
if err != nil {
return "", err
}
logoutRequest, err := doc.WriteToString()
if err != nil {
return "", err
}
buf := &bytes.Buffer{}
fw, err := flate.NewWriter(buf, flate.DefaultCompression)
if err != nil {
return "", fmt.Errorf("flate NewWriter error: %v", err)
}
_, err = fw.Write([]byte(logoutRequest))
if err != nil {
return "", fmt.Errorf("flate.Writer Write error: %v", err)
}
err = fw.Close()
if err != nil {
return "", fmt.Errorf("flate.Writer Close error: %v", err)
}
qs := parsedUrl.Query()
qs.Add("SAMLRequest", base64.StdEncoding.EncodeToString(buf.Bytes()))
if relayState != "" {
qs.Add("RelayState", relayState)
}
if binding == BindingHttpRedirect {
// Sign URL encoded query (see Section 3.4.4.1 DEFLATE Encoding of saml-bindings-2.0-os.pdf)
ctx := sp.SigningContext()
qs.Add("SigAlg", ctx.GetSignatureMethodIdentifier())
var rawSignature []byte
//qs.Encode() sorts the keys (See https://golang.org/pkg/net/url/#Values.Encode).
//If RelayState parameter is present then RelayState parameter
//will be put first by Encode(). Hence encode them separately and concatenate.
//Signature string has to have parameters in the order - SAMLRequest=value&RelayState=value&SigAlg=value.
//(See Section 3.4.4.1 saml-bindings-2.0-os.pdf).
var orderedParams = []string{"SAMLRequest", "RelayState", "SigAlg"}
var paramValueMap = make(map[string]string)
paramValueMap["SAMLRequest"] = base64.StdEncoding.EncodeToString(buf.Bytes())
if relayState != "" {
paramValueMap["RelayState"] = relayState
}
paramValueMap["SigAlg"] = ctx.GetSignatureMethodIdentifier()
ss := ""
for _, k := range orderedParams {
v, ok := paramValueMap[k]
if ok {
//Add the value after URL encoding.
u := url.Values{}
u.Add(k, v)
e := u.Encode()
if ss != "" {
ss += "&" + e
} else {
ss = e
}
}
}
//Now generate the signature on the string of ordered parameters.
if rawSignature, err = ctx.SignString(ss); err != nil {
return "", fmt.Errorf("unable to sign query string of redirect URL: %v", err)
}
// Now add base64 encoded Signature
qs.Add("Signature", base64.StdEncoding.EncodeToString(rawSignature))
}
//Here the parameters may appear in any order.
parsedUrl.RawQuery = qs.Encode()
return parsedUrl.String(), nil
}
// signatureInputString constructs the string to be fed into the signature algorithm, as described
// in section 3.4.4.1 of
// https://www.oasis-open.org/committees/download.php/56779/sstc-saml-bindings-errata-2.0-wd-06.pdf
func signatureInputString(samlRequest, relayState, sigAlg string) string {
var params [][2]string
if relayState == "" {
params = [][2]string{{"SAMLRequest", samlRequest}, {"SigAlg", sigAlg}}
} else {
params = [][2]string{{"SAMLRequest", samlRequest}, {"RelayState", relayState}, {"SigAlg", sigAlg}}
}
var buf bytes.Buffer
for _, kv := range params {
k, v := kv[0], kv[1]
if buf.Len() > 0 {
buf.WriteByte('&')
}
buf.WriteString(url.QueryEscape(k) + "=" + url.QueryEscape(v))
}
return buf.String()
}

View File

@ -0,0 +1,85 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"encoding/base64"
"fmt"
dsig "github.com/russellhaering/goxmldsig"
)
func (sp *SAMLServiceProvider) validateLogoutRequestAttributes(request *LogoutRequest) error {
if request.Destination != "" && request.Destination != sp.ServiceProviderSLOURL {
return ErrInvalidValue{
Key: DestinationAttr,
Expected: sp.ServiceProviderSLOURL,
Actual: request.Destination,
}
}
if request.Version != "2.0" {
return ErrInvalidValue{
Reason: ReasonUnsupported,
Key: "SAML version",
Expected: "2.0",
Actual: request.Version,
}
}
return nil
}
func (sp *SAMLServiceProvider) ValidateEncodedLogoutRequestPOST(encodedRequest string) (*LogoutRequest, error) {
raw, err := base64.StdEncoding.DecodeString(encodedRequest)
if err != nil {
return nil, err
}
// Parse the raw request - parseResponse is generic
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
var requestSignatureValidated bool
if !sp.SkipSignatureValidation {
el, err = sp.validateElementSignature(el)
if err == dsig.ErrMissingSignature {
// Unfortunately we just blew away our Response
el = doc.Root()
} else if err != nil {
return nil, err
} else if el == nil {
return nil, fmt.Errorf("missing transformed logout request")
} else {
requestSignatureValidated = true
}
}
decodedRequest := &LogoutRequest{}
err = xmlUnmarshalElement(el, decodedRequest)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal logout request: %v", err)
}
decodedRequest.SignatureValidated = requestSignatureValidated
err = sp.ValidateDecodedLogoutRequest(decodedRequest)
if err != nil {
return nil, err
}
return decodedRequest, nil
}

478
vendor/github.com/mattermost/gosaml2/decode_response.go generated vendored Normal file
View File

@ -0,0 +1,478 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"bytes"
"compress/flate"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io"
"encoding/xml"
"github.com/beevik/etree"
"github.com/mattermost/gosaml2/types"
rtvalidator "github.com/mattermost/xml-roundtrip-validator"
dsig "github.com/russellhaering/goxmldsig"
"github.com/russellhaering/goxmldsig/etreeutils"
)
const (
defaultMaxDecompressedResponseSize = 5 * 1024 * 1024
)
func (sp *SAMLServiceProvider) validationContext() *dsig.ValidationContext {
ctx := dsig.NewDefaultValidationContext(sp.IDPCertificateStore)
ctx.Clock = sp.Clock
return ctx
}
// validateResponseAttributes validates a SAML Response's tag and attributes. It does
// not inspect child elements of the Response at all.
func (sp *SAMLServiceProvider) validateResponseAttributes(response *types.Response) error {
if response.Destination != "" && response.Destination != sp.AssertionConsumerServiceURL {
return ErrInvalidValue{
Key: DestinationAttr,
Expected: sp.AssertionConsumerServiceURL,
Actual: response.Destination,
}
}
if response.Version != "2.0" {
return ErrInvalidValue{
Reason: ReasonUnsupported,
Key: "SAML version",
Expected: "2.0",
Actual: response.Version,
}
}
return nil
}
// validateLogoutResponseAttributes validates a SAML Response's tag and attributes. It does
// not inspect child elements of the Response at all.
func (sp *SAMLServiceProvider) validateLogoutResponseAttributes(response *types.LogoutResponse) error {
if response.Destination != "" && response.Destination != sp.ServiceProviderSLOURL {
return ErrInvalidValue{
Key: DestinationAttr,
Expected: sp.ServiceProviderSLOURL,
Actual: response.Destination,
}
}
if response.Version != "2.0" {
return ErrInvalidValue{
Reason: ReasonUnsupported,
Key: "SAML version",
Expected: "2.0",
Actual: response.Version,
}
}
return nil
}
func xmlUnmarshalElement(el *etree.Element, obj interface{}) error {
doc := etree.NewDocument()
doc.SetRoot(el)
data, err := doc.WriteToBytes()
if err != nil {
return err
}
err = xml.Unmarshal(data, obj)
if err != nil {
return err
}
return nil
}
func (sp *SAMLServiceProvider) getDecryptCert() (*tls.Certificate, error) {
if sp.SPKeyStore == nil {
return nil, fmt.Errorf("no decryption certs available")
}
//This is the tls.Certificate we'll use to decrypt any encrypted assertions
var decryptCert tls.Certificate
switch crt := sp.SPKeyStore.(type) {
case dsig.TLSCertKeyStore:
// Get the tls.Certificate directly if possible
decryptCert = tls.Certificate(crt)
default:
//Otherwise, construct one from the results of GetKeyPair
pk, cert, err := sp.SPKeyStore.GetKeyPair()
if err != nil {
return nil, fmt.Errorf("error getting keypair: %v", err)
}
decryptCert = tls.Certificate{
Certificate: [][]byte{cert},
PrivateKey: pk,
}
}
if sp.ValidateEncryptionCert {
// Check Validity period of certificate
if len(decryptCert.Certificate) < 1 || len(decryptCert.Certificate[0]) < 1 {
return nil, fmt.Errorf("empty decryption cert")
} else if cert, err := x509.ParseCertificate(decryptCert.Certificate[0]); err != nil {
return nil, fmt.Errorf("invalid x509 decryption cert: %v", err)
} else {
now := sp.Clock.Now()
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return nil, fmt.Errorf("decryption cert is not valid at this time")
}
}
}
return &decryptCert, nil
}
func (sp *SAMLServiceProvider) decryptAssertions(el *etree.Element) error {
var decryptCert *tls.Certificate
decryptAssertion := func(ctx etreeutils.NSContext, encryptedElement *etree.Element) error {
if encryptedElement.Parent() != el {
return fmt.Errorf("found encrypted assertion with unexpected parent element: %s", encryptedElement.Parent().Tag)
}
detached, err := etreeutils.NSDetatch(ctx, encryptedElement) // make a detached copy
if err != nil {
return fmt.Errorf("unable to detach encrypted assertion: %v", err)
}
encryptedAssertion := &types.EncryptedAssertion{}
err = xmlUnmarshalElement(detached, encryptedAssertion)
if err != nil {
return fmt.Errorf("unable to unmarshal encrypted assertion: %v", err)
}
if decryptCert == nil {
decryptCert, err = sp.getDecryptCert()
if err != nil {
return fmt.Errorf("unable to get decryption certificate: %v", err)
}
}
raw, derr := encryptedAssertion.DecryptBytes(decryptCert)
if derr != nil {
return fmt.Errorf("unable to decrypt encrypted assertion: %v", derr)
}
doc, _, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return fmt.Errorf("unable to create element from decrypted assertion bytes: %v", derr)
}
// Replace the original encrypted assertion with the decrypted one.
if el.RemoveChild(encryptedElement) == nil {
// Out of an abundance of caution, make sure removed worked
panic("unable to remove encrypted assertion")
}
el.AddChild(doc.Root())
return nil
}
return etreeutils.NSFindIterate(el, SAMLAssertionNamespace, EncryptedAssertionTag, decryptAssertion)
}
func (sp *SAMLServiceProvider) validateElementSignature(el *etree.Element) (*etree.Element, error) {
return sp.validationContext().Validate(el)
}
func (sp *SAMLServiceProvider) validateAssertionSignatures(el *etree.Element) error {
signedAssertions := 0
unsignedAssertions := 0
validateAssertion := func(ctx etreeutils.NSContext, unverifiedAssertion *etree.Element) error {
parent := unverifiedAssertion.Parent()
if parent == nil {
return fmt.Errorf("parent is nil")
}
if parent != el {
return fmt.Errorf("found assertion with unexpected parent element: %s", unverifiedAssertion.Parent().Tag)
}
detached, err := etreeutils.NSDetatch(ctx, unverifiedAssertion) // make a detached copy
if err != nil {
return fmt.Errorf("unable to detach unverified assertion: %v", err)
}
assertion, err := sp.validationContext().Validate(detached)
if err == dsig.ErrMissingSignature {
unsignedAssertions++
return nil
} else if err != nil {
return err
}
// Replace the original unverified Assertion with the verified one. Note that
// if the Response is not signed, only signed Assertions (and not the parent Response) can be trusted.
if el.RemoveChild(unverifiedAssertion) == nil {
// Out of an abundance of caution, check to make sure an Assertion was actually
// removed. If it wasn't a programming error has occurred.
panic("unable to remove assertion")
}
el.AddChild(assertion)
signedAssertions++
return nil
}
if err := etreeutils.NSFindIterate(el, SAMLAssertionNamespace, AssertionTag, validateAssertion); err != nil {
return err
} else if signedAssertions > 0 && unsignedAssertions > 0 {
return fmt.Errorf("invalid to have both signed and unsigned assertions")
} else if signedAssertions < 1 {
return dsig.ErrMissingSignature
} else {
return nil
}
}
// ValidateEncodedResponse both decodes and validates, based on SP
// configuration, an encoded, signed response. It will also appropriately
// decrypt a response if the assertion was encrypted
func (sp *SAMLServiceProvider) ValidateEncodedResponse(encodedResponse string) (*types.Response, error) {
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
if err != nil {
return nil, err
}
// Parse the raw response
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
elAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, AssertionTag)
if err != nil {
return nil, err
}
elEncAssertion, err := etreeutils.NSFindOne(el, SAMLAssertionNamespace, EncryptedAssertionTag)
if err != nil {
return nil, err
}
// We verify that either one of assertion or encrypted assertion elements are present,
// but not both.
if (elAssertion == nil) == (elEncAssertion == nil) {
return nil, fmt.Errorf("found both or no assertion and encrypted assertion elements")
}
// And if a decryptCert is present, then it's only encrypted assertion elements.
if sp.SPKeyStore != nil && elAssertion != nil {
return nil, fmt.Errorf("all assertions are not encrypted")
}
var responseSignatureValidated bool
if !sp.SkipSignatureValidation {
el, err = sp.validateElementSignature(el)
if err == dsig.ErrMissingSignature {
// Unfortunately we just blew away our Response
el = doc.Root()
} else if err != nil {
return nil, err
} else if el == nil {
return nil, fmt.Errorf("missing transformed response")
} else {
responseSignatureValidated = true
}
}
err = sp.decryptAssertions(el)
if err != nil {
return nil, err
}
var assertionSignaturesValidated bool
if !sp.SkipSignatureValidation {
err = sp.validateAssertionSignatures(el)
if err == dsig.ErrMissingSignature {
if !responseSignatureValidated {
return nil, fmt.Errorf("response and/or assertions must be signed")
}
} else if err != nil {
return nil, err
} else {
assertionSignaturesValidated = true
}
}
decodedResponse := &types.Response{}
err = xmlUnmarshalElement(el, decodedResponse)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal response: %v", err)
}
decodedResponse.SignatureValidated = responseSignatureValidated
if assertionSignaturesValidated {
for idx := 0; idx < len(decodedResponse.Assertions); idx++ {
decodedResponse.Assertions[idx].SignatureValidated = true
}
}
err = sp.Validate(decodedResponse)
if err != nil {
return nil, err
}
return decodedResponse, nil
}
// DecodeUnverifiedBaseResponse decodes several attributes from a SAML response for the purpose
// of determining how to validate the response. This is useful for Service Providers which
// expose a single Assertion Consumer Service URL but consume Responses from many IdPs.
func DecodeUnverifiedBaseResponse(encodedResponse string) (*types.UnverifiedBaseResponse, error) {
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
if err != nil {
return nil, err
}
var response *types.UnverifiedBaseResponse
err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
response = &types.UnverifiedBaseResponse{}
return xml.Unmarshal(maybeXML, response)
})
if err != nil {
return nil, err
}
return response, nil
}
// maybeDeflate invokes the passed decoder over the passed data. If an error is
// returned, it then attempts to deflate the passed data before re-invoking
// the decoder over the deflated data.
func maybeDeflate(data []byte, maxSize int64, decoder func([]byte) error) error {
err := decoder(data)
if err == nil {
return nil
}
// Default to 5MB max size
if maxSize == 0 {
maxSize = defaultMaxDecompressedResponseSize
}
lr := io.LimitReader(flate.NewReader(bytes.NewReader(data)), maxSize+1)
deflated, err := io.ReadAll(lr)
if err != nil {
return err
}
if int64(len(deflated)) > maxSize {
return fmt.Errorf("deflated response exceeds maximum size of %d bytes", maxSize)
}
return decoder(deflated)
}
// parseResponse is a helper function that was refactored out so that the XML parsing behavior can be isolated and unit tested
func parseResponse(xml []byte, maxSize int64) (*etree.Document, *etree.Element, error) {
var doc *etree.Document
var rawXML []byte
err := maybeDeflate(xml, maxSize, func(xml []byte) error {
doc = etree.NewDocument()
rawXML = xml
return doc.ReadFromBytes(xml)
})
if err != nil {
return nil, nil, err
}
el := doc.Root()
if el == nil {
return nil, nil, fmt.Errorf("unable to parse response")
}
// Examine the response for attempts to exploit weaknesses in Go's encoding/xml
err = rtvalidator.Validate(bytes.NewReader(rawXML))
if err != nil {
return nil, nil, err
}
return doc, el, nil
}
// DecodeUnverifiedLogoutResponse decodes several attributes from a SAML Logout response, without doing any verifications.
func DecodeUnverifiedLogoutResponse(encodedResponse string) (*types.LogoutResponse, error) {
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
if err != nil {
return nil, err
}
var response *types.LogoutResponse
err = maybeDeflate(raw, defaultMaxDecompressedResponseSize, func(maybeXML []byte) error {
response = &types.LogoutResponse{}
return xml.Unmarshal(maybeXML, response)
})
if err != nil {
return nil, err
}
return response, nil
}
func (sp *SAMLServiceProvider) ValidateEncodedLogoutResponsePOST(encodedResponse string) (*types.LogoutResponse, error) {
raw, err := base64.StdEncoding.DecodeString(encodedResponse)
if err != nil {
return nil, err
}
// Parse the raw response
doc, el, err := parseResponse(raw, sp.MaximumDecompressedBodySize)
if err != nil {
return nil, err
}
var responseSignatureValidated bool
if !sp.SkipSignatureValidation {
el, err = sp.validateElementSignature(el)
if err == dsig.ErrMissingSignature {
// Unfortunately we just blew away our Response
el = doc.Root()
} else if err != nil {
return nil, err
} else if el == nil {
return nil, fmt.Errorf("missing transformed logout response")
} else {
responseSignatureValidated = true
}
}
decodedResponse := &types.LogoutResponse{}
err = xmlUnmarshalElement(el, decodedResponse)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal logout response: %v", err)
}
decodedResponse.SignatureValidated = responseSignatureValidated
err = sp.ValidateDecodedLogoutResponse(decodedResponse)
if err != nil {
return nil, err
}
return decodedResponse, nil
}

37
vendor/github.com/mattermost/gosaml2/logout_request.go generated vendored Normal file
View File

@ -0,0 +1,37 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"encoding/xml"
"github.com/mattermost/gosaml2/types"
"time"
)
// LogoutRequest is the go struct representation of a logout request
type LogoutRequest struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutRequest"`
ID string `xml:"ID,attr"`
Version string `xml:"Version,attr"`
//ProtocolBinding string `xml:",attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Destination string `xml:"Destination,attr"`
Issuer *types.Issuer `xml:"Issuer"`
NameID *types.NameID `xml:"NameID"`
SignatureValidated bool `xml:"-"` // not read, not dumped
}

View File

@ -0,0 +1,111 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import "fmt"
//ErrMissingElement is the error type that indicates an element and/or attribute is
//missing. It provides a structured error that can be more appropriately acted
//upon.
type ErrMissingElement struct {
Tag, Attribute string
}
type ErrVerification struct {
Cause error
}
func (e ErrVerification) Error() string {
return fmt.Sprintf("error validating response: %s", e.Cause.Error())
}
//ErrMissingAssertion indicates that an appropriate assertion element could not
//be found in the SAML Response
var (
ErrMissingAssertion = ErrMissingElement{Tag: AssertionTag}
)
func (e ErrMissingElement) Error() string {
if e.Attribute != "" {
return fmt.Sprintf("missing %s attribute on %s element", e.Attribute, e.Tag)
}
return fmt.Sprintf("missing %s element", e.Tag)
}
//RetrieveAssertionInfo takes an encoded response and returns the AssertionInfo
//contained, or an error message if an error has been encountered.
func (sp *SAMLServiceProvider) RetrieveAssertionInfo(encodedResponse string) (*AssertionInfo, error) {
assertionInfo := &AssertionInfo{
Values: make(Values),
}
response, err := sp.ValidateEncodedResponse(encodedResponse)
if err != nil {
return nil, ErrVerification{Cause: err}
}
// TODO: Support multiple assertions
if len(response.Assertions) == 0 {
return nil, ErrMissingAssertion
}
assertion := response.Assertions[0]
assertionInfo.Assertions = response.Assertions
assertionInfo.ResponseSignatureValidated = response.SignatureValidated
warningInfo, err := sp.VerifyAssertionConditions(&assertion)
if err != nil {
return nil, err
}
//Get the NameID
subject := assertion.Subject
if subject == nil {
return nil, ErrMissingElement{Tag: SubjectTag}
}
nameID := subject.NameID
if nameID == nil {
return nil, ErrMissingElement{Tag: NameIdTag}
}
assertionInfo.NameID = nameID.Value
//Get the actual assertion attributes
attributeStatement := assertion.AttributeStatement
if attributeStatement == nil && !sp.AllowMissingAttributes {
return nil, ErrMissingElement{Tag: AttributeStatementTag}
}
if attributeStatement != nil {
for _, attribute := range attributeStatement.Attributes {
assertionInfo.Values[attribute.Name] = attribute
}
}
if assertion.AuthnStatement != nil {
if assertion.AuthnStatement.AuthnInstant != nil {
assertionInfo.AuthnInstant = assertion.AuthnStatement.AuthnInstant
}
if assertion.AuthnStatement.SessionNotOnOrAfter != nil {
assertionInfo.SessionNotOnOrAfter = assertion.AuthnStatement.SessionNotOnOrAfter
}
assertionInfo.SessionIndex = assertion.AuthnStatement.SessionIndex
}
assertionInfo.WarningInfo = warningInfo
return assertionInfo, nil
}

12
vendor/github.com/mattermost/gosaml2/run_test.sh generated vendored Normal file
View File

@ -0,0 +1,12 @@
#!/bin/bash
cd `dirname $0`
DIRS=`git grep -l 'func Test' | xargs dirname | sort -u`
for DIR in $DIRS
do
echo
echo "dir: $DIR"
echo "======================================"
pushd $DIR >/dev/null
go test -v || exit 1
popd >/dev/null
done

291
vendor/github.com/mattermost/gosaml2/saml.go generated vendored Normal file
View File

@ -0,0 +1,291 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"encoding/base64"
"sync"
"time"
"github.com/mattermost/gosaml2/types"
dsig "github.com/russellhaering/goxmldsig"
dsigtypes "github.com/russellhaering/goxmldsig/types"
)
type ErrSaml struct {
Message string
System error
}
func (serr ErrSaml) Error() string {
if serr.Message != "" {
return serr.Message
}
return "SAML error"
}
type SAMLServiceProvider struct {
IdentityProviderSSOURL string
IdentityProviderSSOBinding string
IdentityProviderSLOURL string
IdentityProviderSLOBinding string
IdentityProviderIssuer string
AssertionConsumerServiceURL string
ServiceProviderSLOURL string
ServiceProviderIssuer string
SignAuthnRequests bool
SignAuthnRequestsAlgorithm string
SignAuthnRequestsCanonicalizer dsig.Canonicalizer
// RequestedAuthnContext allows service providers to require that the identity
// provider use specific authentication mechanisms. Leaving this unset will
// permit the identity provider to choose the auth method. To maximize compatibility
// with identity providers it is recommended to leave this unset.
RequestedAuthnContext *RequestedAuthnContext
AudienceURI string
IDPCertificateStore dsig.X509CertificateStore
SPKeyStore dsig.X509KeyStore // Required encryption key, default signing key
SPSigningKeyStore dsig.X509KeyStore // Optional signing key
NameIdFormat string
ValidateEncryptionCert bool
SkipSignatureValidation bool
AllowMissingAttributes bool
ScopingIDPProviderId string
ScopingIDPProviderName string
Clock *dsig.Clock
// MaximumDecompressedBodySize is the maximum size to which a compressed
// SAML document will be decompressed. If a compresed document is exceeds
// this size during decompression an error will be returned.
MaximumDecompressedBodySize int64
signingContextMu sync.RWMutex
signingContext *dsig.SigningContext
}
// RequestedAuthnContext controls which authentication mechanisms are requested of
// the identity provider. It is generally sufficient to omit this and let the
// identity provider select an authentication mechansim.
type RequestedAuthnContext struct {
// The RequestedAuthnContext comparison policy to use. See the section 3.3.2.2.1
// of the SAML 2.0 specification for details. Constants named AuthnPolicyMatch*
// contain standardized values.
Comparison string
// Contexts will be passed as AuthnContextClassRefs. For example, to force password
// authentication on some identity providers, Contexts should have a value of
// []string{AuthnContextPasswordProtectedTransport}, and Comparison should have a
// value of AuthnPolicyMatchExact.
Contexts []string
}
func (sp *SAMLServiceProvider) Metadata() (*types.EntityDescriptor, error) {
keyDescriptors := make([]types.KeyDescriptor, 0, 2)
if sp.GetSigningKey() != nil {
signingCertBytes, err := sp.GetSigningCertBytes()
if err != nil {
return nil, err
}
keyDescriptors = append(keyDescriptors, types.KeyDescriptor{
Use: "signing",
KeyInfo: dsigtypes.KeyInfo{
X509Data: dsigtypes.X509Data{
X509Certificates: []dsigtypes.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(signingCertBytes),
}},
},
},
})
}
if sp.GetEncryptionKey() != nil {
encryptionCertBytes, err := sp.GetEncryptionCertBytes()
if err != nil {
return nil, err
}
keyDescriptors = append(keyDescriptors, types.KeyDescriptor{
Use: "encryption",
KeyInfo: dsigtypes.KeyInfo{
X509Data: dsigtypes.X509Data{
X509Certificates: []dsigtypes.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(encryptionCertBytes),
}},
},
},
EncryptionMethods: []types.EncryptionMethod{
{Algorithm: types.MethodAES128GCM},
{Algorithm: types.MethodAES192GCM},
{Algorithm: types.MethodAES256GCM},
{Algorithm: types.MethodAES128CBC},
{Algorithm: types.MethodAES256CBC},
},
})
}
return &types.EntityDescriptor{
ValidUntil: time.Now().UTC().Add(time.Hour * 24 * 7), // 7 days
EntityID: sp.ServiceProviderIssuer,
SPSSODescriptor: &types.SPSSODescriptor{
AuthnRequestsSigned: sp.SignAuthnRequests,
WantAssertionsSigned: !sp.SkipSignatureValidation,
ProtocolSupportEnumeration: SAMLProtocolNamespace,
KeyDescriptors: keyDescriptors,
AssertionConsumerServices: []types.IndexedEndpoint{{
Binding: BindingHttpPost,
Location: sp.AssertionConsumerServiceURL,
Index: 1,
}},
},
}, nil
}
func (sp *SAMLServiceProvider) MetadataWithSLO(validityHours int64) (*types.EntityDescriptor, error) {
signingCertBytes, err := sp.GetSigningCertBytes()
if err != nil {
return nil, err
}
encryptionCertBytes, err := sp.GetEncryptionCertBytes()
if err != nil {
return nil, err
}
if validityHours <= 0 {
//By default let's keep it to 7 days.
validityHours = int64(time.Hour * 24 * 7)
}
return &types.EntityDescriptor{
ValidUntil: time.Now().UTC().Add(time.Duration(validityHours)), // default 7 days
EntityID: sp.ServiceProviderIssuer,
SPSSODescriptor: &types.SPSSODescriptor{
AuthnRequestsSigned: sp.SignAuthnRequests,
WantAssertionsSigned: !sp.SkipSignatureValidation,
ProtocolSupportEnumeration: SAMLProtocolNamespace,
KeyDescriptors: []types.KeyDescriptor{
{
Use: "signing",
KeyInfo: dsigtypes.KeyInfo{
X509Data: dsigtypes.X509Data{
X509Certificates: []dsigtypes.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(signingCertBytes),
}},
},
},
},
{
Use: "encryption",
KeyInfo: dsigtypes.KeyInfo{
X509Data: dsigtypes.X509Data{
X509Certificates: []dsigtypes.X509Certificate{{
Data: base64.StdEncoding.EncodeToString(encryptionCertBytes),
}},
},
},
EncryptionMethods: []types.EncryptionMethod{
{Algorithm: types.MethodAES128GCM, DigestMethod: nil},
{Algorithm: types.MethodAES192GCM, DigestMethod: nil},
{Algorithm: types.MethodAES256GCM, DigestMethod: nil},
{Algorithm: types.MethodAES128CBC, DigestMethod: nil},
{Algorithm: types.MethodAES256CBC, DigestMethod: nil},
},
},
},
AssertionConsumerServices: []types.IndexedEndpoint{{
Binding: BindingHttpPost,
Location: sp.AssertionConsumerServiceURL,
Index: 1,
}},
SingleLogoutServices: []types.Endpoint{{
Binding: BindingHttpPost,
Location: sp.ServiceProviderSLOURL,
}},
},
}, nil
}
func (sp *SAMLServiceProvider) GetEncryptionKey() dsig.X509KeyStore {
return sp.SPKeyStore
}
func (sp *SAMLServiceProvider) GetSigningKey() dsig.X509KeyStore {
if sp.SPSigningKeyStore == nil {
return sp.GetEncryptionKey() // Default is signing key is same as encryption key
}
return sp.SPSigningKeyStore
}
func (sp *SAMLServiceProvider) GetEncryptionCertBytes() ([]byte, error) {
if _, encryptionCert, err := sp.GetEncryptionKey().GetKeyPair(); err != nil {
return nil, ErrSaml{Message: "no SP encryption certificate", System: err}
} else if len(encryptionCert) < 1 {
return nil, ErrSaml{Message: "empty SP encryption certificate"}
} else {
return encryptionCert, nil
}
}
func (sp *SAMLServiceProvider) GetSigningCertBytes() ([]byte, error) {
if _, signingCert, err := sp.GetSigningKey().GetKeyPair(); err != nil {
return nil, ErrSaml{Message: "no SP signing certificate", System: err}
} else if len(signingCert) < 1 {
return nil, ErrSaml{Message: "empty SP signing certificate"}
} else {
return signingCert, nil
}
}
func (sp *SAMLServiceProvider) SigningContext() *dsig.SigningContext {
sp.signingContextMu.RLock()
signingContext := sp.signingContext
sp.signingContextMu.RUnlock()
if signingContext != nil {
return signingContext
}
sp.signingContextMu.Lock()
defer sp.signingContextMu.Unlock()
sp.signingContext = dsig.NewDefaultSigningContext(sp.GetSigningKey())
sp.signingContext.SetSignatureMethod(sp.SignAuthnRequestsAlgorithm)
if sp.SignAuthnRequestsCanonicalizer != nil {
sp.signingContext.Canonicalizer = sp.SignAuthnRequestsCanonicalizer
}
return sp.signingContext
}
type ProxyRestriction struct {
Count int
Audience []string
}
type WarningInfo struct {
OneTimeUse bool
ProxyRestriction *ProxyRestriction
NotInAudience bool
InvalidTime bool
}
type AssertionInfo struct {
NameID string
Values Values
WarningInfo *WarningInfo
SessionIndex string
AuthnInstant *time.Time
SessionNotOnOrAfter *time.Time
Assertions []types.Assertion
ResponseSignatureValidated bool
}

418
vendor/github.com/mattermost/gosaml2/test_constants.go generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,97 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"bytes"
"crypto/cipher"
"crypto/tls"
"encoding/base64"
"encoding/xml"
"fmt"
)
type EncryptedAssertion struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion EncryptedAssertion"`
EncryptionMethod EncryptionMethod `xml:"EncryptedData>EncryptionMethod"`
EncryptedKey EncryptedKey `xml:"EncryptedData>KeyInfo>EncryptedKey"`
DetEncryptedKey EncryptedKey `xml:"EncryptedKey"` // detached EncryptedKey element
CipherValue string `xml:"EncryptedData>CipherData>CipherValue"`
}
func (ea *EncryptedAssertion) DecryptBytes(cert *tls.Certificate) ([]byte, error) {
data, err := base64.StdEncoding.DecodeString(ea.CipherValue)
if err != nil {
return nil, err
}
// EncryptedKey must include CipherValue. EncryptedKey may be part of EncryptedData.
ek := &ea.EncryptedKey
if ek.CipherValue == "" {
// Use detached EncryptedKey element (sibling of EncryptedData). See:
// https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html#sec-Extensions-to-KeyInfo
ek = &ea.DetEncryptedKey
}
k, err := ek.DecryptSymmetricKey(cert)
if err != nil {
return nil, fmt.Errorf("cannot decrypt, error retrieving private key: %s", err)
}
switch ea.EncryptionMethod.Algorithm {
case MethodAES128GCM, MethodAES192GCM, MethodAES256GCM:
c, err := cipher.NewGCM(k)
if err != nil {
return nil, fmt.Errorf("cannot create AES-GCM: %s", err)
}
nonce, data := data[:c.NonceSize()], data[c.NonceSize():]
plainText, err := c.Open(nil, nonce, data, nil)
if err != nil {
return nil, fmt.Errorf("cannot open AES-GCM: %s", err)
}
return plainText, nil
case MethodAES128CBC, MethodAES256CBC, MethodTripleDESCBC:
nonce, data := data[:k.BlockSize()], data[k.BlockSize():]
c := cipher.NewCBCDecrypter(k, nonce)
c.CryptBlocks(data, data)
// Remove zero bytes
data = bytes.TrimRight(data, "\x00")
// Calculate index to remove based on padding
padLength := data[len(data)-1]
lastGoodIndex := len(data) - int(padLength)
return data[:lastGoodIndex], nil
default:
return nil, fmt.Errorf("unknown symmetric encryption method %#v", ea.EncryptionMethod.Algorithm)
}
}
// Decrypt decrypts and unmarshals the EncryptedAssertion.
func (ea *EncryptedAssertion) Decrypt(cert *tls.Certificate) (*Assertion, error) {
plaintext, err := ea.DecryptBytes(cert)
if err != nil {
return nil, fmt.Errorf("Error decrypting assertion: %v", err)
}
assertion := &Assertion{}
err = xml.Unmarshal(plaintext, assertion)
if err != nil {
return nil, fmt.Errorf("Error unmarshaling assertion: %v", err)
}
return assertion, nil
}

View File

@ -0,0 +1,196 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"encoding/base64"
"encoding/hex"
"fmt"
"hash"
"strings"
)
//EncryptedKey contains the decryption key data from the saml2 core and xmlenc
//standards.
type EncryptedKey struct {
// EncryptionMethod string `xml:"EncryptionMethod>Algorithm"`
X509Data string `xml:"KeyInfo>X509Data>X509Certificate"`
CipherValue string `xml:"CipherData>CipherValue"`
EncryptionMethod EncryptionMethod
}
//EncryptionMethod specifies the type of encryption that was used.
type EncryptionMethod struct {
Algorithm string `xml:",attr,omitempty"`
//Digest method is present for algorithms like RSA-OAEP.
//See https://www.w3.org/TR/xmlenc-core1/.
//To convey the digest methods an entity supports,
//DigestMethod in extensions element is used.
//See http://docs.oasis-open.org/security/saml/Post2.0/sstc-saml-metadata-algsupport.html.
DigestMethod *DigestMethod `xml:",omitempty"`
}
//DigestMethod is a digest type specification
type DigestMethod struct {
Algorithm string `xml:",attr,omitempty"`
}
//Well-known public-key encryption methods
const (
MethodRSAOAEP = "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"
MethodRSAOAEP2 = "http://www.w3.org/2009/xmlenc11#rsa-oaep"
MethodRSAv1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
)
//Well-known private key encryption methods
const (
MethodAES128GCM = "http://www.w3.org/2009/xmlenc11#aes128-gcm"
MethodAES192GCM = "http://www.w3.org/2009/xmlenc11#aes192-gcm"
MethodAES256GCM = "http://www.w3.org/2009/xmlenc11#aes256-gcm"
MethodAES128CBC = "http://www.w3.org/2001/04/xmlenc#aes128-cbc"
MethodAES256CBC = "http://www.w3.org/2001/04/xmlenc#aes256-cbc"
MethodTripleDESCBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
)
//Well-known hash methods
const (
MethodSHA1 = "http://www.w3.org/2000/09/xmldsig#sha1"
MethodSHA256 = "http://www.w3.org/2000/09/xmldsig#sha256"
MethodSHA512 = "http://www.w3.org/2000/09/xmldsig#sha512"
)
//SHA-1 is commonly used for certificate fingerprints (openssl -fingerprint and ADFS thumbprint).
//SHA-1 is sufficient for our purposes here (error message).
func debugKeyFp(keyBytes []byte) string {
if len(keyBytes) < 1 {
return ""
}
hashFunc := sha1.New()
hashFunc.Write(keyBytes)
sum := strings.ToLower(hex.EncodeToString(hashFunc.Sum(nil)))
var ret string
for idx := 0; idx+1 < len(sum); idx += 2 {
if idx == 0 {
ret += sum[idx : idx+2]
} else {
ret += ":" + sum[idx:idx+2]
}
}
return ret
}
//DecryptSymmetricKey returns the private key contained in the EncryptedKey document
func (ek *EncryptedKey) DecryptSymmetricKey(cert *tls.Certificate) (cipher.Block, error) {
if len(cert.Certificate) < 1 {
return nil, fmt.Errorf("decryption tls.Certificate has no public certs attached")
}
// The EncryptedKey may or may not include X509Data (certificate).
// If included, the EncryptedKey certificate:
// - is FYI only (fail if it does not match the SP certificate)
// - is NOT used to decrypt CipherData
if ek.X509Data != "" {
if encCert, err := base64.StdEncoding.DecodeString(ek.X509Data); err != nil {
return nil, fmt.Errorf("error decoding EncryptedKey certificate: %v", err)
} else if !bytes.Equal(cert.Certificate[0], encCert) {
return nil, fmt.Errorf("key decryption attempted with mismatched cert, SP cert(%.11s), assertion cert(%.11s)",
debugKeyFp(cert.Certificate[0]), debugKeyFp(encCert))
}
}
cipherText, err := base64.StdEncoding.DecodeString(ek.CipherValue)
if err != nil {
return nil, err
}
switch pk := cert.PrivateKey.(type) {
case *rsa.PrivateKey:
var h hash.Hash
if ek.EncryptionMethod.DigestMethod == nil {
//if digest method is not present lets set default method to SHA1.
//Digest method is used by methods like RSA-OAEP.
h = sha1.New()
} else {
switch ek.EncryptionMethod.DigestMethod.Algorithm {
case "", MethodSHA1:
h = sha1.New() // default
case MethodSHA256:
h = sha256.New()
case MethodSHA512:
h = sha512.New()
default:
return nil, fmt.Errorf("unsupported digest algorithm: %v",
ek.EncryptionMethod.DigestMethod.Algorithm)
}
}
switch ek.EncryptionMethod.Algorithm {
case "":
return nil, fmt.Errorf("missing encryption algorithm")
case MethodRSAOAEP, MethodRSAOAEP2:
pt, err := rsa.DecryptOAEP(h, rand.Reader, pk, cipherText, nil)
if err != nil {
return nil, fmt.Errorf("rsa internal error: %v", err)
}
b, err := aes.NewCipher(pt)
if err != nil {
return nil, err
}
return b, nil
case MethodRSAv1_5:
pt, err := rsa.DecryptPKCS1v15(rand.Reader, pk, cipherText)
if err != nil {
return nil, fmt.Errorf("rsa internal error: %v", err)
}
//From https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf the xml encryption
//methods to be supported are from http://www.w3.org/2001/04/xmlenc#Element.
//https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/Overview.html#Element.
//https://www.w3.org/TR/2002/REC-xmlenc-core-20021210/#sec-Algorithms
//Sec 5.4 Key Transport:
//The RSA v1.5 Key Transport algorithm given below are those used in conjunction with TRIPLEDES
//Please also see https://www.w3.org/TR/xmlenc-core/#sec-Algorithms and
//https://www.w3.org/TR/xmlenc-core/#rsav15note.
b, err := des.NewTripleDESCipher(pt)
if err != nil {
return nil, err
}
// FIXME: The version we had previously in our fork, AES seems more secure from my Googling.
// b, err := aes.NewCipher(pt)
// if err != nil {
// return nil, err
// }
return b, nil
default:
return nil, fmt.Errorf("unsupported encryption algorithm: %s", ek.EncryptionMethod.Algorithm)
}
}
return nil, fmt.Errorf("no cipher for decoding symmetric key")
}

102
vendor/github.com/mattermost/gosaml2/types/metadata.go generated vendored Normal file
View File

@ -0,0 +1,102 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"encoding/xml"
"time"
dsigtypes "github.com/russellhaering/goxmldsig/types"
)
type EntityDescriptor struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
ValidUntil time.Time `xml:"validUntil,attr"`
// SAML 2.0 8.3.6 Entity Identifier could be used to represent issuer
EntityID string `xml:"entityID,attr"`
SPSSODescriptor *SPSSODescriptor `xml:"SPSSODescriptor,omitempty"`
IDPSSODescriptor *IDPSSODescriptor `xml:"IDPSSODescriptor,omitempty"`
Extensions *Extensions `xml:"Extensions,omitempty"`
}
type Endpoint struct {
Binding string `xml:"Binding,attr"`
Location string `xml:"Location,attr"`
ResponseLocation string `xml:"ResponseLocation,attr,omitempty"`
}
type IndexedEndpoint struct {
Binding string `xml:"Binding,attr"`
Location string `xml:"Location,attr"`
Index int `xml:"index,attr"`
}
type SPSSODescriptor struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata SPSSODescriptor"`
AuthnRequestsSigned bool `xml:"AuthnRequestsSigned,attr"`
WantAssertionsSigned bool `xml:"WantAssertionsSigned,attr"`
ProtocolSupportEnumeration string `xml:"protocolSupportEnumeration,attr"`
KeyDescriptors []KeyDescriptor `xml:"KeyDescriptor"`
SingleLogoutServices []Endpoint `xml:"SingleLogoutService"`
NameIDFormats []string `xml:"NameIDFormat"`
AssertionConsumerServices []IndexedEndpoint `xml:"AssertionConsumerService"`
Extensions *Extensions `xml:"Extensions,omitempty"`
}
type IDPSSODescriptor struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata IDPSSODescriptor"`
WantAuthnRequestsSigned bool `xml:"WantAuthnRequestsSigned,attr"`
KeyDescriptors []KeyDescriptor `xml:"KeyDescriptor"`
NameIDFormats []NameIDFormat `xml:"NameIDFormat"`
SingleSignOnServices []SingleSignOnService `xml:"SingleSignOnService"`
SingleLogoutServices []SingleLogoutService `xml:"SingleLogoutService"`
Attributes []Attribute `xml:"Attribute"`
Extensions *Extensions `xml:"Extensions,omitempty"`
}
type KeyDescriptor struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata KeyDescriptor"`
Use string `xml:"use,attr"`
KeyInfo dsigtypes.KeyInfo `xml:"KeyInfo"`
EncryptionMethods []EncryptionMethod `xml:"EncryptionMethod"`
}
type NameIDFormat struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata NameIDFormat"`
Value string `xml:",chardata"`
}
type SingleSignOnService struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleSignOnService"`
Binding string `xml:"Binding,attr"`
Location string `xml:"Location,attr"`
}
type SingleLogoutService struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata SingleLogoutService"`
Binding string `xml:"Binding,attr"`
Location string `xml:"Location,attr"`
}
type SigningMethod struct {
Algorithm string `xml:",attr"`
MinKeySize string `xml:"MinKeySize,attr,omitempty"`
MaxKeySize string `xml:"MaxKeySize,attr,omitempty"`
}
type Extensions struct {
DigestMethod *DigestMethod `xml:",omitempty"`
SigningMethod *SigningMethod `xml:",omitempty"`
}

187
vendor/github.com/mattermost/gosaml2/types/response.go generated vendored Normal file
View File

@ -0,0 +1,187 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"encoding/xml"
"time"
)
// UnverifiedBaseResponse extracts several basic attributes of a SAML Response
// which may be useful in deciding how to validate the Response. An UnverifiedBaseResponse
// is parsed by this library prior to any validation of the Response, so the
// values it contains may have been supplied by an attacker and should not be
// trusted as authoritative from the IdP.
type UnverifiedBaseResponse struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"`
ID string `xml:"ID,attr"`
InResponseTo string `xml:"InResponseTo,attr"`
Destination string `xml:"Destination,attr"`
Version string `xml:"Version,attr"`
Issuer *Issuer `xml:"Issuer"`
}
type Response struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"`
ID string `xml:"ID,attr"`
InResponseTo string `xml:"InResponseTo,attr"`
Destination string `xml:"Destination,attr"`
Version string `xml:"Version,attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Status *Status `xml:"Status"`
Issuer *Issuer `xml:"Issuer"`
Assertions []Assertion `xml:"Assertion"`
EncryptedAssertions []EncryptedAssertion `xml:"EncryptedAssertion"`
SignatureValidated bool `xml:"-"` // not read, not dumped
}
type LogoutResponse struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol LogoutResponse"`
ID string `xml:"ID,attr"`
InResponseTo string `xml:"InResponseTo,attr"`
Destination string `xml:"Destination,attr"`
Version string `xml:"Version,attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Status *Status `xml:"Status"`
Issuer *Issuer `xml:"Issuer"`
SignatureValidated bool `xml:"-"` // not read, not dumped
}
type Status struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Status"`
StatusCode *StatusCode `xml:"StatusCode"`
}
type StatusCode struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol StatusCode"`
Value string `xml:"Value,attr"`
}
type Issuer struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Issuer"`
Value string `xml:",chardata"`
}
type Signature struct {
SignatureDocument []byte `xml:",innerxml"`
}
type Assertion struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Assertion"`
Version string `xml:"Version,attr"`
ID string `xml:"ID,attr"`
IssueInstant time.Time `xml:"IssueInstant,attr"`
Issuer *Issuer `xml:"Issuer"`
Signature *Signature `xml:"Signature"`
Subject *Subject `xml:"Subject"`
Conditions *Conditions `xml:"Conditions"`
AttributeStatement *AttributeStatement `xml:"AttributeStatement"`
AuthnStatement *AuthnStatement `xml:"AuthnStatement"`
SignatureValidated bool `xml:"-"` // not read, not dumped
}
type Subject struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Subject"`
NameID *NameID `xml:"NameID"`
SubjectConfirmation *SubjectConfirmation `xml:"SubjectConfirmation"`
}
type AuthnContext struct {
XMLName xml.Name `xml:urn:oasis:names:tc:SAML:2.0:assertion AuthnContext"`
AuthnContextClassRef *AuthnContextClassRef `xml:"AuthnContextClassRef"`
}
type AuthnContextClassRef struct {
XMLName xml.Name `xml:urn:oasis:names:tc:SAML:2.0:assertion AuthnContextClassRef"`
Value string `xml:",chardata"`
}
type NameID struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion NameID"`
Value string `xml:",chardata"`
}
type SubjectConfirmation struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmation"`
Method string `xml:"Method,attr"`
SubjectConfirmationData *SubjectConfirmationData `xml:"SubjectConfirmationData"`
}
type SubjectConfirmationData struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion SubjectConfirmationData"`
NotOnOrAfter string `xml:"NotOnOrAfter,attr"`
Recipient string `xml:"Recipient,attr"`
InResponseTo string `xml:"InResponseTo,attr"`
}
type Conditions struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Conditions"`
NotBefore string `xml:"NotBefore,attr"`
NotOnOrAfter string `xml:"NotOnOrAfter,attr"`
AudienceRestrictions []AudienceRestriction `xml:"AudienceRestriction"`
OneTimeUse *OneTimeUse `xml:"OneTimeUse"`
ProxyRestriction *ProxyRestriction `xml:"ProxyRestriction"`
}
type AudienceRestriction struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AudienceRestriction"`
Audiences []Audience `xml:"Audience"`
}
type Audience struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Audience"`
Value string `xml:",chardata"`
}
type OneTimeUse struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion OneTimeUse"`
}
type ProxyRestriction struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion ProxyRestriction"`
Count int `xml:"Count,attr"`
Audience []Audience `xml:"Audience"`
}
type AttributeStatement struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeStatement"`
Attributes []Attribute `xml:"Attribute"`
}
type Attribute struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"`
FriendlyName string `xml:"FriendlyName,attr"`
Name string `xml:"Name,attr"`
NameFormat string `xml:"NameFormat,attr"`
Values []AttributeValue `xml:"AttributeValue"`
}
type AttributeValue struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AttributeValue"`
Type string `xml:"xsi:type,attr"`
Value string `xml:",chardata"`
}
type AuthnStatement struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion AuthnStatement"`
//Section 4.1.4.2 - https://docs.oasis-open.org/security/saml/v2.0/saml-profiles-2.0-os.pdf
//If the identity provider supports the Single Logout profile, defined in Section 4.4
//, any such authentication statements MUST include a SessionIndex attribute to enable
//per-session logout requests by the service provider.
SessionIndex string `xml:"SessionIndex,attr,omitempty"`
AuthnInstant *time.Time `xml:"AuthnInstant,attr,omitempty"`
SessionNotOnOrAfter *time.Time `xml:"SessionNotOnOrAfter,attr,omitempty"`
AuthnContext *AuthnContext `xml:"AuthnContext"`
}

41
vendor/github.com/mattermost/gosaml2/uuid/uuid.go generated vendored Normal file
View File

@ -0,0 +1,41 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package uuid
// relevant bits from https://github.com/abneptis/GoUUID/blob/master/uuid.go
import (
"crypto/rand"
"fmt"
)
type UUID [16]byte
// NewV4 returns random generated UUID.
func NewV4() *UUID {
u := &UUID{}
_, err := rand.Read(u[:16])
if err != nil {
panic(err)
}
u[8] = (u[8] | 0x80) & 0xBf
u[6] = (u[6] | 0x40) & 0x4f
return u
}
func (u *UUID) String() string {
return fmt.Sprintf("%x-%x-%x-%x-%x", u[:4], u[4:6], u[6:8], u[8:10], u[10:])
}

309
vendor/github.com/mattermost/gosaml2/validate.go generated vendored Normal file
View File

@ -0,0 +1,309 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
import (
"fmt"
"time"
"github.com/mattermost/gosaml2/types"
)
//ErrParsing indicates that the value present in an assertion could not be
//parsed. It can be inspected for the specific tag name, the contents, and the
//intended type.
type ErrParsing struct {
Tag, Value, Type string
}
func (ep ErrParsing) Error() string {
return fmt.Sprintf("Error parsing %s tag value as type %s", ep.Tag, ep.Value)
}
//Oft-used messages
const (
ReasonUnsupported = "Unsupported"
ReasonExpired = "Expired"
)
//ErrInvalidValue indicates that the expected value did not match the received
//value.
type ErrInvalidValue struct {
Key, Expected, Actual string
Reason string
}
func (e ErrInvalidValue) Error() string {
if e.Reason == "" {
e.Reason = "Unrecognized"
}
return fmt.Sprintf("%s %s value, Expected: %s, Actual: %s", e.Reason, e.Key, e.Expected, e.Actual)
}
//Well-known methods of subject confirmation
const (
SubjMethodBearer = "urn:oasis:names:tc:SAML:2.0:cm:bearer"
)
//VerifyAssertionConditions inspects an assertion element and makes sure that
//all SAML2 contracts are upheld.
func (sp *SAMLServiceProvider) VerifyAssertionConditions(assertion *types.Assertion) (*WarningInfo, error) {
warningInfo := &WarningInfo{}
now := sp.Clock.Now()
conditions := assertion.Conditions
if conditions == nil {
return nil, ErrMissingElement{Tag: ConditionsTag}
}
if conditions.NotBefore == "" {
return nil, ErrMissingElement{Tag: ConditionsTag, Attribute: NotBeforeAttr}
}
notBefore, err := time.Parse(time.RFC3339, conditions.NotBefore)
if err != nil {
return nil, ErrParsing{Tag: NotBeforeAttr, Value: conditions.NotBefore, Type: "time.RFC3339"}
}
if now.Before(notBefore) {
warningInfo.InvalidTime = true
}
if conditions.NotOnOrAfter == "" {
return nil, ErrMissingElement{Tag: ConditionsTag, Attribute: NotOnOrAfterAttr}
}
notOnOrAfter, err := time.Parse(time.RFC3339, conditions.NotOnOrAfter)
if err != nil {
return nil, ErrParsing{Tag: NotOnOrAfterAttr, Value: conditions.NotOnOrAfter, Type: "time.RFC3339"}
}
if now.After(notOnOrAfter) {
warningInfo.InvalidTime = true
}
for _, audienceRestriction := range conditions.AudienceRestrictions {
matched := false
for _, audience := range audienceRestriction.Audiences {
if audience.Value == sp.AudienceURI {
matched = true
break
}
}
if !matched {
warningInfo.NotInAudience = true
break
}
}
if conditions.OneTimeUse != nil {
warningInfo.OneTimeUse = true
}
proxyRestriction := conditions.ProxyRestriction
if proxyRestriction != nil {
proxyRestrictionInfo := &ProxyRestriction{
Count: proxyRestriction.Count,
Audience: []string{},
}
for _, audience := range proxyRestriction.Audience {
proxyRestrictionInfo.Audience = append(proxyRestrictionInfo.Audience, audience.Value)
}
warningInfo.ProxyRestriction = proxyRestrictionInfo
}
return warningInfo, nil
}
//Validate ensures that the assertion passed is valid for the current Service
//Provider.
func (sp *SAMLServiceProvider) Validate(response *types.Response) error {
err := sp.validateResponseAttributes(response)
if err != nil {
return err
}
if len(response.Assertions) == 0 {
return ErrMissingAssertion
}
issuer := response.Issuer
if issuer == nil {
// FIXME?: SAML Core 2.0 Section 3.2.2 has Response.Issuer as [Optional]
return ErrMissingElement{Tag: IssuerTag}
}
if sp.IdentityProviderIssuer != "" && response.Issuer.Value != sp.IdentityProviderIssuer {
return ErrInvalidValue{
Key: IssuerTag,
Expected: sp.IdentityProviderIssuer,
Actual: response.Issuer.Value,
}
}
status := response.Status
if status == nil {
return ErrMissingElement{Tag: StatusTag}
}
statusCode := status.StatusCode
if statusCode == nil {
return ErrMissingElement{Tag: StatusCodeTag}
}
if statusCode.Value != StatusCodeSuccess {
return ErrInvalidValue{
Key: StatusCodeTag,
Expected: StatusCodeSuccess,
Actual: statusCode.Value,
}
}
for _, assertion := range response.Assertions {
issuer = assertion.Issuer
if issuer == nil {
return ErrMissingElement{Tag: IssuerTag}
}
if sp.IdentityProviderIssuer != "" && assertion.Issuer.Value != sp.IdentityProviderIssuer {
return ErrInvalidValue{
Key: IssuerTag,
Expected: sp.IdentityProviderIssuer,
Actual: issuer.Value,
}
}
subject := assertion.Subject
if subject == nil {
return ErrMissingElement{Tag: SubjectTag}
}
subjectConfirmation := subject.SubjectConfirmation
if subjectConfirmation == nil {
return ErrMissingElement{Tag: SubjectConfirmationTag}
}
if subjectConfirmation.Method != SubjMethodBearer {
return ErrInvalidValue{
Reason: ReasonUnsupported,
Key: SubjectConfirmationTag,
Expected: SubjMethodBearer,
Actual: subjectConfirmation.Method,
}
}
subjectConfirmationData := subjectConfirmation.SubjectConfirmationData
if subjectConfirmationData == nil {
return ErrMissingElement{Tag: SubjectConfirmationDataTag}
}
if subjectConfirmationData.Recipient != sp.AssertionConsumerServiceURL {
return ErrInvalidValue{
Key: RecipientAttr,
Expected: sp.AssertionConsumerServiceURL,
Actual: subjectConfirmationData.Recipient,
}
}
if subjectConfirmationData.NotOnOrAfter == "" {
return ErrMissingElement{Tag: SubjectConfirmationDataTag, Attribute: NotOnOrAfterAttr}
}
notOnOrAfter, err := time.Parse(time.RFC3339, subjectConfirmationData.NotOnOrAfter)
if err != nil {
return ErrParsing{Tag: NotOnOrAfterAttr, Value: subjectConfirmationData.NotOnOrAfter, Type: "time.RFC3339"}
}
now := sp.Clock.Now()
if now.After(notOnOrAfter) {
return ErrInvalidValue{
Reason: ReasonExpired,
Key: NotOnOrAfterAttr,
Expected: now.Format(time.RFC3339),
Actual: subjectConfirmationData.NotOnOrAfter,
}
}
}
return nil
}
func (sp *SAMLServiceProvider) ValidateDecodedLogoutResponse(response *types.LogoutResponse) error {
err := sp.validateLogoutResponseAttributes(response)
if err != nil {
return err
}
issuer := response.Issuer
if issuer == nil {
// FIXME?: SAML Core 2.0 Section 3.2.2 has Response.Issuer as [Optional]
return ErrMissingElement{Tag: IssuerTag}
}
if sp.IdentityProviderIssuer != "" && response.Issuer.Value != sp.IdentityProviderIssuer {
return ErrInvalidValue{
Key: IssuerTag,
Expected: sp.IdentityProviderIssuer,
Actual: response.Issuer.Value,
}
}
status := response.Status
if status == nil {
return ErrMissingElement{Tag: StatusTag}
}
statusCode := status.StatusCode
if statusCode == nil {
return ErrMissingElement{Tag: StatusCodeTag}
}
if statusCode.Value != StatusCodeSuccess {
return ErrInvalidValue{
Key: StatusCodeTag,
Expected: StatusCodeSuccess,
Actual: statusCode.Value,
}
}
return nil
}
func (sp *SAMLServiceProvider) ValidateDecodedLogoutRequest(request *LogoutRequest) error {
err := sp.validateLogoutRequestAttributes(request)
if err != nil {
return err
}
issuer := request.Issuer
if issuer == nil {
// FIXME?: SAML Core 2.0 Section 3.2.2 has Response.Issuer as [Optional]
return ErrMissingElement{Tag: IssuerTag}
}
if sp.IdentityProviderIssuer != "" && request.Issuer.Value != sp.IdentityProviderIssuer {
return ErrInvalidValue{
Key: IssuerTag,
Expected: sp.IdentityProviderIssuer,
Actual: request.Issuer.Value,
}
}
return nil
}

74
vendor/github.com/mattermost/gosaml2/xml_constants.go generated vendored Normal file
View File

@ -0,0 +1,74 @@
// Copyright 2016 Russell Haering et al.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package saml2
const (
ResponseTag = "Response"
AssertionTag = "Assertion"
EncryptedAssertionTag = "EncryptedAssertion"
SubjectTag = "Subject"
NameIdTag = "NameID"
SubjectConfirmationTag = "SubjectConfirmation"
SubjectConfirmationDataTag = "SubjectConfirmationData"
AttributeStatementTag = "AttributeStatement"
AttributeValueTag = "AttributeValue"
ConditionsTag = "Conditions"
AudienceRestrictionTag = "AudienceRestriction"
AudienceTag = "Audience"
OneTimeUseTag = "OneTimeUse"
ProxyRestrictionTag = "ProxyRestriction"
IssuerTag = "Issuer"
StatusTag = "Status"
StatusCodeTag = "StatusCode"
)
const (
DestinationAttr = "Destination"
VersionAttr = "Version"
IdAttr = "ID"
MethodAttr = "Method"
RecipientAttr = "Recipient"
NameAttr = "Name"
NotBeforeAttr = "NotBefore"
NotOnOrAfterAttr = "NotOnOrAfter"
CountAttr = "Count"
)
const (
NameIdFormatPersistent = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
NameIdFormatTransient = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"
NameIdFormatEmailAddress = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
NameIdFormatUnspecified = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
NameIdFormatX509SubjectName = "urn:oasis:names:tc:SAML:1.1:nameid-format:x509SubjectName"
AuthnContextPasswordProtectedTransport = "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport"
AuthnPolicyMatchExact = "exact"
AuthnPolicyMatchMinimum = "minimum"
AuthnPolicyMatchMaximum = "maximum"
AuthnPolicyMatchBetter = "better"
StatusCodeSuccess = "urn:oasis:names:tc:SAML:2.0:status:Success"
StatusCodePartialLogout = "urn:oasis:names:tc:SAML:2.0:status:PartialLogout"
StatusCodeUnknownPrincipal = "urn:oasis:names:tc:SAML:2.0:status:UnknownPrincipal"
BindingHttpPost = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
BindingHttpRedirect = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
)
const (
SAMLAssertionNamespace = "urn:oasis:names:tc:SAML:2.0:assertion"
SAMLProtocolNamespace = "urn:oasis:names:tc:SAML:2.0:protocol"
)

0
vendor/github.com/mattermost/ldap/.gitignore generated vendored Normal file
View File

32
vendor/github.com/mattermost/ldap/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,32 @@
sudo: false
language: go
go:
- "1.5.x"
- "1.6.x"
- "1.7.x"
- "1.8.x"
- "1.9.x"
- "1.10.x"
- "1.11.x"
- "1.12.x"
- "1.13.x"
- tip
git:
depth: 1
matrix:
fast_finish: true
allow_failures:
- go: tip
go_import_path: github.com/go-ldap/ldap
install:
- go get github.com/go-asn1-ber/asn1-ber
- go get code.google.com/p/go.tools/cmd/cover || go get golang.org/x/tools/cmd/cover
- go get github.com/golang/lint/golint || go get golang.org/x/lint/golint || true
- go build -v ./...
script:
- make test
- make fmt
- make vet
- make lint

12
vendor/github.com/mattermost/ldap/CONTRIBUTING.md generated vendored Normal file
View File

@ -0,0 +1,12 @@
# Contribution Guidelines
We welcome contribution and improvements.
## Guiding Principles
To begin with here is a draft from an email exchange:
* take compatibility seriously (our semvers, compatibility with older go versions, etc)
* don't tag untested code for release
* beware of baking in implicit behavior based on other libraries/tools choices
* be as high-fidelity as possible in plumbing through LDAP data (don't mask errors or reduce power of someone using the library)

22
vendor/github.com/mattermost/ldap/LICENSE generated vendored Normal file
View File

@ -0,0 +1,22 @@
The MIT License (MIT)
Copyright (c) 2011-2015 Michael Mitton (mmitton@gmail.com)
Portions copyright (c) 2015-2016 go-ldap Authors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

82
vendor/github.com/mattermost/ldap/Makefile generated vendored Normal file
View File

@ -0,0 +1,82 @@
.PHONY: default install build test quicktest fmt vet lint
# List of all release tags "supported" by our current Go version
# E.g. ":go1.1:go1.2:go1.3:go1.4:go1.5:go1.6:go1.7:go1.8:go1.9:go1.10:go1.11:go1.12:"
GO_RELEASE_TAGS := $(shell go list -f ':{{join (context.ReleaseTags) ":"}}:' runtime)
# Only use the `-race` flag on newer versions of Go (version 1.3 and newer)
ifeq (,$(findstring :go1.3:,$(GO_RELEASE_TAGS)))
RACE_FLAG :=
else
RACE_FLAG := -race -cpu 1,2,4
endif
# Run `go vet` on Go 1.12 and newer. For Go 1.5-1.11, use `go tool vet`
ifneq (,$(findstring :go1.12:,$(GO_RELEASE_TAGS)))
GO_VET := go vet \
-atomic \
-bool \
-copylocks \
-nilfunc \
-printf \
-rangeloops \
-unreachable \
-unsafeptr \
-unusedresult \
.
else ifneq (,$(findstring :go1.5:,$(GO_RELEASE_TAGS)))
GO_VET := go tool vet \
-atomic \
-bool \
-copylocks \
-nilfunc \
-printf \
-shadow \
-rangeloops \
-unreachable \
-unsafeptr \
-unusedresult \
.
else
GO_VET := @echo "go vet skipped -- not supported on this version of Go"
endif
default: fmt vet lint build quicktest
install:
go get -t -v ./...
build:
go build -v ./...
test:
go test -v $(RACE_FLAG) -cover ./...
quicktest:
go test ./...
# Capture output and force failure when there is non-empty output
fmt:
@echo gofmt -l .
@OUTPUT=`gofmt -l . 2>&1`; \
if [ "$$OUTPUT" ]; then \
echo "gofmt must be run on the following files:"; \
echo "$$OUTPUT"; \
exit 1; \
fi
vet:
$(GO_VET)
# https://github.com/golang/lint
# go get github.com/golang/lint/golint
# Capture output and force failure when there is non-empty output
# Only run on go1.5+
lint:
@echo golint ./...
@OUTPUT=`command -v golint >/dev/null 2>&1 && golint ./... 2>&1`; \
if [ "$$OUTPUT" ]; then \
echo "golint errors:"; \
echo "$$OUTPUT"; \
exit 1; \
fi

61
vendor/github.com/mattermost/ldap/README.md generated vendored Normal file
View File

@ -0,0 +1,61 @@
[![Go Reference](https://pkg.go.dev/badge/github.com/mattermost/ldap.svg)](https://pkg.go.dev/github.com/mattermost/ldap)
[![Build Status](https://travis-ci.org/go-ldap/ldap.svg)](https://travis-ci.org/go-ldap/ldap)
# Basic LDAP v3 functionality for the GO programming language.
## Features:
- Connecting to LDAP server (non-TLS, TLS, STARTTLS)
- Binding to LDAP server
- Searching for entries
- Filter Compile / Decompile
- Paging Search Results
- Modify Requests / Responses
- Add Requests / Responses
- Delete Requests / Responses
- Modify DN Requests / Responses
## Examples:
- search
- modify
## Go Modules:
`go get github.com/go-ldap/ldap/v3`
As go-ldap was v2+ when Go Modules came out, updating to Go Modules would be considered a breaking change.
To maintain backwards compatability, we ultimately decided to use subfolders (as v3 was already a branch).
Whilst this duplicates the code, we can move toward implementing a backwards-compatible versioning system that allows for code reuse.
The alternative would be to increment the version number, however we believe that this would confuse users as v3 is in line with LDAPv3 (RFC-4511)
https://tools.ietf.org/html/rfc4511
For more info, please visit the pull request that updated to modules.
https://github.com/go-ldap/ldap/pull/247
To install with `GOMODULE111=off`, use `go get github.com/go-ldap/ldap`
https://golang.org/cmd/go/#hdr-Legacy_GOPATH_go_get
As always, we are looking for contributors with great ideas on how to best move forward.
## Contributing:
Bug reports and pull requests are welcome!
Before submitting a pull request, please make sure tests and verification scripts pass:
```
make all
```
To set up a pre-push hook to run the tests and verify scripts before pushing:
```
ln -s ../../.githooks/pre-push .git/hooks/pre-push
```
---
The Go gopher was designed by Renee French. (http://reneefrench.blogspot.com/)
The design is licensed under the Creative Commons 3.0 Attributions license.
Read this article for more details: http://blog.golang.org/gopher

100
vendor/github.com/mattermost/ldap/add.go generated vendored Normal file
View File

@ -0,0 +1,100 @@
//
// https://tools.ietf.org/html/rfc4511
//
// AddRequest ::= [APPLICATION 8] SEQUENCE {
// entry LDAPDN,
// attributes AttributeList }
//
// AttributeList ::= SEQUENCE OF attribute Attribute
package ldap
import (
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
// Attribute represents an LDAP attribute
type Attribute struct {
// Type is the name of the LDAP attribute
Type string
// Vals are the LDAP attribute values
Vals []string
}
func (a *Attribute) encode() *ber.Packet {
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, a.Type, "Type"))
set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
for _, value := range a.Vals {
set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
}
seq.AppendChild(set)
return seq
}
// AddRequest represents an LDAP AddRequest operation
type AddRequest struct {
// DN identifies the entry being added
DN string
// Attributes list the attributes of the new entry
Attributes []Attribute
// Controls hold optional controls to send with the request
Controls []Control
}
func (req *AddRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationAddRequest, nil, "Add Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
attributes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
for _, attribute := range req.Attributes {
attributes.AppendChild(attribute.encode())
}
pkt.AppendChild(attributes)
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// Attribute adds an attribute with the given type and values
func (req *AddRequest) Attribute(attrType string, attrVals []string) {
req.Attributes = append(req.Attributes, Attribute{Type: attrType, Vals: attrVals})
}
// NewAddRequest returns an AddRequest for the given DN, with no attributes
func NewAddRequest(dn string, controls []Control) *AddRequest {
return &AddRequest{
DN: dn,
Controls: controls,
}
}
// Add performs the given AddRequest
func (l *Conn) Add(addRequest *AddRequest) error {
msgCtx, err := l.doRequest(addRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}
tag := packet.Children[1].Tag
if tag == ApplicationAddResponse {
err := GetLDAPError(packet)
if err != nil {
return err
}
} else {
l.Debug.Log("Unexpected Response", mlog.Uint("tag", tag))
}
return nil
}

152
vendor/github.com/mattermost/ldap/bind.go generated vendored Normal file
View File

@ -0,0 +1,152 @@
package ldap
import (
"errors"
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
// SimpleBindRequest represents a username/password bind operation
type SimpleBindRequest struct {
// Username is the name of the Directory object that the client wishes to bind as
Username string
// Password is the credentials to bind with
Password string
// Controls are optional controls to send with the bind request
Controls []Control
// AllowEmptyPassword sets whether the client allows binding with an empty password
// (normally used for unauthenticated bind).
AllowEmptyPassword bool
}
// SimpleBindResult contains the response from the server
type SimpleBindResult struct {
Controls []Control
}
// NewSimpleBindRequest returns a bind request
func NewSimpleBindRequest(username string, password string, controls []Control) *SimpleBindRequest {
return &SimpleBindRequest{
Username: username,
Password: password,
Controls: controls,
AllowEmptyPassword: false,
}
}
func (req *SimpleBindRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Username, "User Name"))
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.Password, "Password"))
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// SimpleBind performs the simple bind operation defined in the given request
func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error) {
if simpleBindRequest.Password == "" && !simpleBindRequest.AllowEmptyPassword {
return nil, NewError(ErrorEmptyPassword, errors.New("ldap: empty password not allowed by the client"))
}
msgCtx, err := l.doRequest(simpleBindRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
result := &SimpleBindResult{
Controls: make([]Control, 0),
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
decodedChild, decodeErr := DecodeControl(child)
if decodeErr != nil {
return nil, fmt.Errorf("failed to decode child control: %s", decodeErr)
}
result.Controls = append(result.Controls, decodedChild)
}
}
err = GetLDAPError(packet)
return result, err
}
// Bind performs a bind with the given username and password.
//
// It does not allow unauthenticated bind (i.e. empty password). Use the UnauthenticatedBind method
// for that.
func (l *Conn) Bind(username, password string) error {
req := &SimpleBindRequest{
Username: username,
Password: password,
AllowEmptyPassword: false,
}
_, err := l.SimpleBind(req)
return err
}
// UnauthenticatedBind performs an unauthenticated bind.
//
// A username may be provided for trace (e.g. logging) purpose only, but it is normally not
// authenticated or otherwise validated by the LDAP server.
//
// See https://tools.ietf.org/html/rfc4513#section-5.1.2 .
// See https://tools.ietf.org/html/rfc4513#section-6.3.1 .
func (l *Conn) UnauthenticatedBind(username string) error {
req := &SimpleBindRequest{
Username: username,
Password: "",
AllowEmptyPassword: true,
}
_, err := l.SimpleBind(req)
return err
}
var externalBindRequest = requestFunc(func(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "User Name"))
saslAuth := ber.Encode(ber.ClassContext, ber.TypeConstructed, 3, "", "authentication")
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "EXTERNAL", "SASL Mech"))
saslAuth.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "SASL Cred"))
pkt.AppendChild(saslAuth)
envelope.AppendChild(pkt)
return nil
})
// ExternalBind performs SASL/EXTERNAL authentication.
//
// Use ldap.DialURL("ldapi://") to connect to the Unix socket before ExternalBind.
//
// See https://tools.ietf.org/html/rfc4422#appendix-A
func (l *Conn) ExternalBind() error {
msgCtx, err := l.doRequest(externalBindRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}
return GetLDAPError(packet)
}

30
vendor/github.com/mattermost/ldap/client.go generated vendored Normal file
View File

@ -0,0 +1,30 @@
package ldap
import (
"crypto/tls"
"time"
)
// Client knows how to interact with an LDAP server
type Client interface {
Start()
StartTLS(*tls.Config) error
Close()
SetTimeout(time.Duration)
Bind(username, password string) error
UnauthenticatedBind(username string) error
SimpleBind(*SimpleBindRequest) (*SimpleBindResult, error)
ExternalBind() error
Add(*AddRequest) error
Del(*DelRequest) error
Modify(*ModifyRequest) error
ModifyDN(*ModifyDNRequest) error
Compare(dn, attribute, value string) (bool, error)
PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error)
Search(*SearchRequest) (*SearchResult, error)
SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error)
}

80
vendor/github.com/mattermost/ldap/compare.go generated vendored Normal file
View File

@ -0,0 +1,80 @@
// File contains Compare functionality
//
// https://tools.ietf.org/html/rfc4511
//
// CompareRequest ::= [APPLICATION 14] SEQUENCE {
// entry LDAPDN,
// ava AttributeValueAssertion }
//
// AttributeValueAssertion ::= SEQUENCE {
// attributeDesc AttributeDescription,
// assertionValue AssertionValue }
//
// AttributeDescription ::= LDAPString
// -- Constrained to <attributedescription>
// -- [RFC4512]
//
// AttributeValue ::= OCTET STRING
//
package ldap
import (
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
// CompareRequest represents an LDAP CompareRequest operation.
type CompareRequest struct {
DN string
Attribute string
Value string
}
func (req *CompareRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationCompareRequest, nil, "Compare Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
ava := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "AttributeValueAssertion")
ava.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Attribute, "AttributeDesc"))
ava.AppendChild(ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.Value, "AssertionValue"))
pkt.AppendChild(ava)
envelope.AppendChild(pkt)
return nil
}
// Compare checks to see if the attribute of the dn matches value. Returns true if it does otherwise
// false with any error that occurs if any.
func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
msgCtx, err := l.doRequest(&CompareRequest{
DN: dn,
Attribute: attribute,
Value: value})
if err != nil {
return false, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return false, err
}
if packet.Children[1].Tag == ApplicationCompareResponse {
err := GetLDAPError(packet)
switch {
case IsErrorWithCode(err, LDAPResultCompareTrue):
return true, nil
case IsErrorWithCode(err, LDAPResultCompareFalse):
return false, nil
default:
return false, err
}
}
return false, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag)
}

522
vendor/github.com/mattermost/ldap/conn.go generated vendored Normal file
View File

@ -0,0 +1,522 @@
package ldap
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"sync"
"sync/atomic"
"time"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
const (
// MessageQuit causes the processMessages loop to exit
MessageQuit = 0
// MessageRequest sends a request to the server
MessageRequest = 1
// MessageResponse receives a response from the server
MessageResponse = 2
// MessageFinish indicates the client considers a particular message ID to be finished
MessageFinish = 3
// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
MessageTimeout = 4
)
const (
// DefaultLdapPort default ldap port for pure TCP connection
DefaultLdapPort = "389"
// DefaultLdapsPort default ldap port for SSL connection
DefaultLdapsPort = "636"
)
// PacketResponse contains the packet or error encountered reading a response
type PacketResponse struct {
// Packet is the packet read from the server
Packet *ber.Packet
// Error is an error encountered while reading
Error error
}
// ReadPacket returns the packet or an error
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
}
return pr.Packet, pr.Error
}
type messageContext struct {
id int64
// close(done) should only be called from finishMessage()
done chan struct{}
// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
responses chan *PacketResponse
}
// sendResponse should only be called within the processMessages() loop which
// is also responsible for closing the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
select {
case msgCtx.responses <- packet:
// Successfully sent packet to message handler.
case <-msgCtx.done:
// The request handler is done and will not receive more
// packets.
}
}
type messagePacket struct {
Op int
MessageID int64
Packet *ber.Packet
Context *messageContext
}
type sendMessageFlags uint
const (
startTLS sendMessageFlags = 1 << iota
)
// Conn represents an LDAP Connection
type Conn struct {
// requestTimeout is loaded atomically
// so we need to ensure 64-bit alignment on 32-bit platforms.
requestTimeout int64
conn net.Conn
isTLS bool
closing uint32
closeErr atomic.Value
isStartingTLS bool
Debug debugging
chanConfirm chan struct{}
messageContexts map[int64]*messageContext
chanMessage chan *messagePacket
chanMessageID chan int64
wgClose sync.WaitGroup
outstandingRequests uint
messageMutex sync.Mutex
}
var _ Client = &Conn{}
// DefaultTimeout is a package-level variable that sets the timeout value
// used for the Dial and DialTLS methods.
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = 60 * time.Second
// Dial connects to the given address on the given network using net.Dial
// and then returns a new Conn for the connection.
func Dial(network, addr string) (*Conn, error) {
c, err := net.DialTimeout(network, addr, DefaultTimeout)
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
conn := NewConn(c, false)
return conn, nil
}
// DialTLS connects to the given address on the given network using tls.Dial
// and then returns a new Conn for the connection.
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
conn := NewConn(c, true)
return conn, nil
}
// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps://
// or ldap:// specified as protocol. On success a new Conn for the connection
// is returned.
func DialURL(addr string) (*Conn, error) {
lurl, err := url.Parse(addr)
if err != nil {
return nil, NewError(ErrorNetwork, err)
}
host, port, err := net.SplitHostPort(lurl.Host)
if err != nil {
// we asume that error is due to missing port
host = lurl.Host
port = ""
}
switch lurl.Scheme {
case "ldapi":
if lurl.Path == "" || lurl.Path == "/" {
lurl.Path = "/var/run/slapd/ldapi"
}
return Dial("unix", lurl.Path)
case "ldap":
if port == "" {
port = DefaultLdapPort
}
return Dial("tcp", net.JoinHostPort(host, port))
case "ldaps":
if port == "" {
port = DefaultLdapsPort
}
tlsConf := &tls.Config{
ServerName: host,
}
return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
}
return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
}
// NewConn returns a new Conn using conn for network I/O.
func NewConn(conn net.Conn, isTLS bool) *Conn {
return &Conn{
conn: conn,
chanConfirm: make(chan struct{}),
chanMessageID: make(chan int64),
chanMessage: make(chan *messagePacket, 10),
messageContexts: map[int64]*messageContext{},
requestTimeout: 0,
isTLS: isTLS,
}
}
// Start initializes goroutines to read responses and process messages
func (l *Conn) Start() {
l.wgClose.Add(1)
go l.reader()
go l.processMessages()
}
// IsClosing returns whether or not we're currently closing.
func (l *Conn) IsClosing() bool {
return atomic.LoadUint32(&l.closing) == 1
}
// setClosing sets the closing value to true
func (l *Conn) setClosing() bool {
return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
}
// Close closes the connection.
func (l *Conn) Close() {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.setClosing() {
l.Debug.Log("Sending quit message and waiting for confirmation")
l.chanMessage <- &messagePacket{Op: MessageQuit}
<-l.chanConfirm
close(l.chanMessage)
l.Debug.Log("Closing network connection")
if err := l.conn.Close(); err != nil {
l.Debug.Log("Error closing network connection", mlog.Err(err))
}
l.wgClose.Done()
}
l.wgClose.Wait()
}
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func (l *Conn) SetTimeout(timeout time.Duration) {
if timeout > 0 {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}
}
// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if messageID, ok := <-l.chanMessageID; ok {
return messageID
}
return 0
}
// StartTLS sends the command to start a TLS session and then creates a new TLS Client
func (l *Conn) StartTLS(config *tls.Config) error {
if l.isTLS {
return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
}
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
packet.AppendChild(request)
l.Debug.Log("Sending StartTLS packet", PacketToField(packet))
msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
l.Debug.Log("Waiting for StartTLS response", mlog.Int("id", msgCtx.id))
packetResponse, ok := <-msgCtx.responses
if !ok {
return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
}
packet, err = packetResponse.ReadPacket()
if l.Debug.Enabled() {
if err := addLDAPDescriptions(packet); err != nil {
l.Close()
return err
}
l.Debug.Log("Got response %p", mlog.Err(err), mlog.Int("id", msgCtx.id), PacketToField(packet), mlog.Err(err))
}
if err != nil {
return err
}
if err := GetLDAPError(packet); err == nil {
conn := tls.Client(l.conn, config)
if connErr := conn.Handshake(); connErr != nil {
l.Close()
return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
}
l.isTLS = true
l.conn = conn
} else {
return err
}
go l.reader()
return nil
}
// TLSConnectionState returns the client's TLS connection state.
// The return values are their zero values if StartTLS did
// not succeed.
func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
tc, ok := l.conn.(*tls.Conn)
if !ok {
return
}
return tc.ConnectionState(), true
}
func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
return l.sendMessageWithFlags(packet, 0)
}
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
if l.IsClosing() {
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
}
l.messageMutex.Lock()
if l.isStartingTLS {
l.messageMutex.Unlock()
return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
}
if flags&startTLS != 0 {
if l.outstandingRequests != 0 {
l.messageMutex.Unlock()
return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
}
l.isStartingTLS = true
}
l.outstandingRequests++
l.messageMutex.Unlock()
responses := make(chan *PacketResponse)
messageID := packet.Children[0].Value.(int64)
message := &messagePacket{
Op: MessageRequest,
MessageID: messageID,
Packet: packet,
Context: &messageContext{
id: messageID,
done: make(chan struct{}),
responses: responses,
},
}
l.sendProcessMessage(message)
return message.Context, nil
}
func (l *Conn) finishMessage(msgCtx *messageContext) {
close(msgCtx.done)
if l.IsClosing() {
return
}
l.messageMutex.Lock()
l.outstandingRequests--
if l.isStartingTLS {
l.isStartingTLS = false
}
l.messageMutex.Unlock()
message := &messagePacket{
Op: MessageFinish,
MessageID: msgCtx.id,
}
l.sendProcessMessage(message)
}
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
l.messageMutex.Lock()
defer l.messageMutex.Unlock()
if l.IsClosing() {
return false
}
l.chanMessage <- message
return true
}
func (l *Conn) processMessages() {
defer func() {
if r := recover(); r != nil {
l.Debug.Log("Recovered panic in processMessages", mlog.Any("panic", r))
}
for messageID, msgCtx := range l.messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
}
l.Debug.Log("Closing channel for MessageID", mlog.Int("message_id", messageID))
close(msgCtx.responses)
delete(l.messageContexts, messageID)
}
close(l.chanMessageID)
close(l.chanConfirm)
}()
var messageID int64 = 1
for {
select {
case l.chanMessageID <- messageID:
messageID++
case message := <-l.chanMessage:
switch message.Op {
case MessageQuit:
l.Debug.Log("Quit message received: Shutting down")
return
case MessageRequest:
// Add to message list and write to network
buf := message.Packet.Bytes()
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Log("Error Sending Message", mlog.Err(err))
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
close(message.Context.responses)
break
}
// Only add to messageContexts if we were able to
// successfully write the message.
l.messageContexts[message.MessageID] = message.Context
// Add timeout if defined
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
if requestTimeout > 0 {
go func() {
defer func() {
if r := recover(); r != nil {
l.Debug.Log("Recovered panic in RequestTimeout", mlog.Any("panic", r))
}
}()
time.Sleep(requestTimeout)
timeoutMessage := &messagePacket{
Op: MessageTimeout,
MessageID: message.MessageID,
}
l.sendProcessMessage(timeoutMessage)
}()
}
case MessageResponse:
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
} else {
l.Debug.Log(
"Received unexpected message",
mlog.Int("message_id", message.MessageID),
mlog.Bool("is_closing", l.IsClosing()),
PacketToField(message.Packet),
)
}
case MessageTimeout:
// Handle the timeout by closing the channel
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Log("Receiving message timeout", mlog.Int("message_id", message.MessageID))
msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
case MessageFinish:
l.Debug.Log("Finished message", mlog.Int("message_id", message.MessageID))
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
}
}
}
}
func (l *Conn) reader() {
cleanstop := false
defer func() {
if r := recover(); r != nil {
l.Debug.Log("Recovered panic in reader", mlog.Any("panic", r))
}
if !cleanstop {
l.Close()
}
}()
for {
if cleanstop {
l.Debug.Log("Reader clean stopping (without closing the connection)")
return
}
packet, err := ber.ReadPacket(l.conn)
if err != nil {
// A read error is expected here if we are closing the connection...
if !l.IsClosing() {
l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
l.Debug.Log("Reader error", mlog.Err(err))
}
return
}
if err := addLDAPDescriptions(packet); err != nil {
l.Debug.Log("Descriptions error", mlog.Err(err))
}
if len(packet.Children) == 0 {
l.Debug.Log("Received bad ldap packet")
continue
}
l.messageMutex.Lock()
if l.isStartingTLS {
cleanstop = true
}
l.messageMutex.Unlock()
message := &messagePacket{
Op: MessageResponse,
MessageID: packet.Children[0].Value.(int64),
Packet: packet,
}
if !l.sendProcessMessage(message) {
return
}
}
}

499
vendor/github.com/mattermost/ldap/control.go generated vendored Normal file
View File

@ -0,0 +1,499 @@
package ldap
import (
"fmt"
"strconv"
"github.com/go-asn1-ber/asn1-ber"
)
const (
// ControlTypePaging - https://www.ietf.org/rfc/rfc2696.txt
ControlTypePaging = "1.2.840.113556.1.4.319"
// ControlTypeBeheraPasswordPolicy - https://tools.ietf.org/html/draft-behera-ldap-password-policy-10
ControlTypeBeheraPasswordPolicy = "1.3.6.1.4.1.42.2.27.8.5.1"
// ControlTypeVChuPasswordMustChange - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
ControlTypeVChuPasswordMustChange = "2.16.840.1.113730.3.4.4"
// ControlTypeVChuPasswordWarning - https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
ControlTypeVChuPasswordWarning = "2.16.840.1.113730.3.4.5"
// ControlTypeManageDsaIT - https://tools.ietf.org/html/rfc3296
ControlTypeManageDsaIT = "2.16.840.1.113730.3.4.2"
// ControlTypeMicrosoftNotification - https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
ControlTypeMicrosoftNotification = "1.2.840.113556.1.4.528"
// ControlTypeMicrosoftShowDeleted - https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx
ControlTypeMicrosoftShowDeleted = "1.2.840.113556.1.4.417"
)
// ControlTypeMap maps controls to text descriptions
var ControlTypeMap = map[string]string{
ControlTypePaging: "Paging",
ControlTypeBeheraPasswordPolicy: "Password Policy - Behera Draft",
ControlTypeManageDsaIT: "Manage DSA IT",
ControlTypeMicrosoftNotification: "Change Notification - Microsoft",
ControlTypeMicrosoftShowDeleted: "Show Deleted Objects - Microsoft",
}
// Control defines an interface controls provide to encode and describe themselves
type Control interface {
// GetControlType returns the OID
GetControlType() string
// Encode returns the ber packet representation
Encode() *ber.Packet
// String returns a human-readable description
String() string
}
// ControlString implements the Control interface for simple controls
type ControlString struct {
ControlType string
Criticality bool
ControlValue string
}
// GetControlType returns the OID
func (c *ControlString) GetControlType() string {
return c.ControlType
}
// Encode returns the ber packet representation
func (c *ControlString) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")"))
if c.Criticality {
packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality"))
}
if c.ControlValue != "" {
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, string(c.ControlValue), "Control Value"))
}
return packet
}
// String returns a human-readable description
func (c *ControlString) String() string {
return fmt.Sprintf("Control Type: %s (%q) Criticality: %t Control Value: %s", ControlTypeMap[c.ControlType], c.ControlType, c.Criticality, c.ControlValue)
}
// ControlPaging implements the paging control described in https://www.ietf.org/rfc/rfc2696.txt
type ControlPaging struct {
// PagingSize indicates the page size
PagingSize uint32
// Cookie is an opaque value returned by the server to track a paging cursor
Cookie []byte
}
// GetControlType returns the OID
func (c *ControlPaging) GetControlType() string {
return ControlTypePaging
}
// Encode returns the ber packet representation
func (c *ControlPaging) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")"))
p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Control Value (Paging)")
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value")
seq.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, int64(c.PagingSize), "Paging Size"))
cookie := ber.Encode(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, nil, "Cookie")
cookie.Value = c.Cookie
cookie.Data.Write(c.Cookie)
seq.AppendChild(cookie)
p2.AppendChild(seq)
packet.AppendChild(p2)
return packet
}
// String returns a human-readable description
func (c *ControlPaging) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t PagingSize: %d Cookie: %q",
ControlTypeMap[ControlTypePaging],
ControlTypePaging,
false,
c.PagingSize,
c.Cookie)
}
// SetCookie stores the given cookie in the paging control
func (c *ControlPaging) SetCookie(cookie []byte) {
c.Cookie = cookie
}
// ControlBeheraPasswordPolicy implements the control described in https://tools.ietf.org/html/draft-behera-ldap-password-policy-10
type ControlBeheraPasswordPolicy struct {
// Expire contains the number of seconds before a password will expire
Expire int64
// Grace indicates the remaining number of times a user will be allowed to authenticate with an expired password
Grace int64
// Error indicates the error code
Error int8
// ErrorString is a human readable error
ErrorString string
}
// GetControlType returns the OID
func (c *ControlBeheraPasswordPolicy) GetControlType() string {
return ControlTypeBeheraPasswordPolicy
}
// Encode returns the ber packet representation
func (c *ControlBeheraPasswordPolicy) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeBeheraPasswordPolicy, "Control Type ("+ControlTypeMap[ControlTypeBeheraPasswordPolicy]+")"))
return packet
}
// String returns a human-readable description
func (c *ControlBeheraPasswordPolicy) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Expire: %d Grace: %d Error: %d, ErrorString: %s",
ControlTypeMap[ControlTypeBeheraPasswordPolicy],
ControlTypeBeheraPasswordPolicy,
false,
c.Expire,
c.Grace,
c.Error,
c.ErrorString)
}
// ControlVChuPasswordMustChange implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
type ControlVChuPasswordMustChange struct {
// MustChange indicates if the password is required to be changed
MustChange bool
}
// GetControlType returns the OID
func (c *ControlVChuPasswordMustChange) GetControlType() string {
return ControlTypeVChuPasswordMustChange
}
// Encode returns the ber packet representation
func (c *ControlVChuPasswordMustChange) Encode() *ber.Packet {
return nil
}
// String returns a human-readable description
func (c *ControlVChuPasswordMustChange) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t MustChange: %v",
ControlTypeMap[ControlTypeVChuPasswordMustChange],
ControlTypeVChuPasswordMustChange,
false,
c.MustChange)
}
// ControlVChuPasswordWarning implements the control described in https://tools.ietf.org/html/draft-vchu-ldap-pwd-policy-00
type ControlVChuPasswordWarning struct {
// Expire indicates the time in seconds until the password expires
Expire int64
}
// GetControlType returns the OID
func (c *ControlVChuPasswordWarning) GetControlType() string {
return ControlTypeVChuPasswordWarning
}
// Encode returns the ber packet representation
func (c *ControlVChuPasswordWarning) Encode() *ber.Packet {
return nil
}
// String returns a human-readable description
func (c *ControlVChuPasswordWarning) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t Expire: %b",
ControlTypeMap[ControlTypeVChuPasswordWarning],
ControlTypeVChuPasswordWarning,
false,
c.Expire)
}
// ControlManageDsaIT implements the control described in https://tools.ietf.org/html/rfc3296
type ControlManageDsaIT struct {
// Criticality indicates if this control is required
Criticality bool
}
// GetControlType returns the OID
func (c *ControlManageDsaIT) GetControlType() string {
return ControlTypeManageDsaIT
}
// Encode returns the ber packet representation
func (c *ControlManageDsaIT) Encode() *ber.Packet {
//FIXME
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeManageDsaIT, "Control Type ("+ControlTypeMap[ControlTypeManageDsaIT]+")"))
if c.Criticality {
packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, c.Criticality, "Criticality"))
}
return packet
}
// String returns a human-readable description
func (c *ControlManageDsaIT) String() string {
return fmt.Sprintf(
"Control Type: %s (%q) Criticality: %t",
ControlTypeMap[ControlTypeManageDsaIT],
ControlTypeManageDsaIT,
c.Criticality)
}
// NewControlManageDsaIT returns a ControlManageDsaIT control
func NewControlManageDsaIT(Criticality bool) *ControlManageDsaIT {
return &ControlManageDsaIT{Criticality: Criticality}
}
// ControlMicrosoftNotification implements the control described in https://msdn.microsoft.com/en-us/library/aa366983(v=vs.85).aspx
type ControlMicrosoftNotification struct{}
// GetControlType returns the OID
func (c *ControlMicrosoftNotification) GetControlType() string {
return ControlTypeMicrosoftNotification
}
// Encode returns the ber packet representation
func (c *ControlMicrosoftNotification) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftNotification, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftNotification]+")"))
return packet
}
// String returns a human-readable description
func (c *ControlMicrosoftNotification) String() string {
return fmt.Sprintf(
"Control Type: %s (%q)",
ControlTypeMap[ControlTypeMicrosoftNotification],
ControlTypeMicrosoftNotification)
}
// NewControlMicrosoftNotification returns a ControlMicrosoftNotification control
func NewControlMicrosoftNotification() *ControlMicrosoftNotification {
return &ControlMicrosoftNotification{}
}
// ControlMicrosoftShowDeleted implements the control described in https://msdn.microsoft.com/en-us/library/aa366989(v=vs.85).aspx
type ControlMicrosoftShowDeleted struct{}
// GetControlType returns the OID
func (c *ControlMicrosoftShowDeleted) GetControlType() string {
return ControlTypeMicrosoftShowDeleted
}
// Encode returns the ber packet representation
func (c *ControlMicrosoftShowDeleted) Encode() *ber.Packet {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control")
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, ControlTypeMicrosoftShowDeleted, "Control Type ("+ControlTypeMap[ControlTypeMicrosoftShowDeleted]+")"))
return packet
}
// String returns a human-readable description
func (c *ControlMicrosoftShowDeleted) String() string {
return fmt.Sprintf(
"Control Type: %s (%q)",
ControlTypeMap[ControlTypeMicrosoftShowDeleted],
ControlTypeMicrosoftShowDeleted)
}
// NewControlMicrosoftShowDeleted returns a ControlMicrosoftShowDeleted control
func NewControlMicrosoftShowDeleted() *ControlMicrosoftShowDeleted {
return &ControlMicrosoftShowDeleted{}
}
// FindControl returns the first control of the given type in the list, or nil
func FindControl(controls []Control, controlType string) Control {
for _, c := range controls {
if c.GetControlType() == controlType {
return c
}
}
return nil
}
// DecodeControl returns a control read from the given packet, or nil if no recognized control can be made
func DecodeControl(packet *ber.Packet) (Control, error) {
var (
ControlType = ""
Criticality = false
value *ber.Packet
)
switch len(packet.Children) {
case 0:
// at least one child is required for control type
return nil, fmt.Errorf("at least one child is required for control type")
case 1:
// just type, no criticality or value
packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
ControlType = packet.Children[0].Value.(string)
case 2:
packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
ControlType = packet.Children[0].Value.(string)
// Children[1] could be criticality or value (both are optional)
// duck-type on whether this is a boolean
if _, ok := packet.Children[1].Value.(bool); ok {
packet.Children[1].Description = "Criticality"
Criticality = packet.Children[1].Value.(bool)
} else {
packet.Children[1].Description = "Control Value"
value = packet.Children[1]
}
case 3:
packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
ControlType = packet.Children[0].Value.(string)
packet.Children[1].Description = "Criticality"
Criticality = packet.Children[1].Value.(bool)
packet.Children[2].Description = "Control Value"
value = packet.Children[2]
default:
// more than 3 children is invalid
return nil, fmt.Errorf("more than 3 children is invalid for controls")
}
switch ControlType {
case ControlTypeManageDsaIT:
return NewControlManageDsaIT(Criticality), nil
case ControlTypePaging:
value.Description += " (Paging)"
c := new(ControlPaging)
if value.Value != nil {
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
value.Data.Truncate(0)
value.Value = nil
value.AppendChild(valueChildren)
}
value = value.Children[0]
value.Description = "Search Control Value"
value.Children[0].Description = "Paging Size"
value.Children[1].Description = "Cookie"
c.PagingSize = uint32(value.Children[0].Value.(int64))
c.Cookie = value.Children[1].Data.Bytes()
value.Children[1].Value = c.Cookie
return c, nil
case ControlTypeBeheraPasswordPolicy:
value.Description += " (Password Policy - Behera)"
c := NewControlBeheraPasswordPolicy()
if value.Value != nil {
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
value.Data.Truncate(0)
value.Value = nil
value.AppendChild(valueChildren)
}
sequence := value.Children[0]
for _, child := range sequence.Children {
if child.Tag == 0 {
//Warning
warningPacket := child.Children[0]
packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int64)
if ok {
if warningPacket.Tag == 0 {
//timeBeforeExpiration
c.Expire = val
warningPacket.Value = c.Expire
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
c.Grace = val
warningPacket.Value = c.Grace
}
}
} else if child.Tag == 1 {
// Error
packet, err := ber.DecodePacketErr(child.Data.Bytes())
if err != nil {
return nil, fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int8)
if !ok {
// what to do?
val = -1
}
c.Error = val
child.Value = c.Error
c.ErrorString = BeheraPasswordPolicyErrorMap[c.Error]
}
}
return c, nil
case ControlTypeVChuPasswordMustChange:
c := &ControlVChuPasswordMustChange{MustChange: true}
return c, nil
case ControlTypeVChuPasswordWarning:
c := &ControlVChuPasswordWarning{Expire: -1}
expireStr := ber.DecodeString(value.Data.Bytes())
expire, err := strconv.ParseInt(expireStr, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse value as int: %s", err)
}
c.Expire = expire
value.Value = c.Expire
return c, nil
case ControlTypeMicrosoftNotification:
return NewControlMicrosoftNotification(), nil
case ControlTypeMicrosoftShowDeleted:
return NewControlMicrosoftShowDeleted(), nil
default:
c := new(ControlString)
c.ControlType = ControlType
c.Criticality = Criticality
if value != nil {
c.ControlValue = value.Value.(string)
}
return c, nil
}
}
// NewControlString returns a generic control
func NewControlString(controlType string, criticality bool, controlValue string) *ControlString {
return &ControlString{
ControlType: controlType,
Criticality: criticality,
ControlValue: controlValue,
}
}
// NewControlPaging returns a paging control
func NewControlPaging(pagingSize uint32) *ControlPaging {
return &ControlPaging{PagingSize: pagingSize}
}
// NewControlBeheraPasswordPolicy returns a ControlBeheraPasswordPolicy
func NewControlBeheraPasswordPolicy() *ControlBeheraPasswordPolicy {
return &ControlBeheraPasswordPolicy{
Expire: -1,
Grace: -1,
Error: -1,
}
}
func encodeControls(controls []Control) *ber.Packet {
packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls")
for _, control := range controls {
packet.AppendChild(control.Encode())
}
return packet
}

49
vendor/github.com/mattermost/ldap/debug.go generated vendored Normal file
View File

@ -0,0 +1,49 @@
package ldap
import (
"bytes"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
type debugging struct {
logger mlog.LoggerIFace
levels []mlog.Level
}
// Enable controls debugging mode.
func (debug *debugging) Enable(logger mlog.LoggerIFace, levels ...mlog.Level) {
*debug = debugging{
logger: logger,
levels: levels,
}
}
func (debug debugging) Enabled() bool {
return debug.logger != nil
}
// Log writes debug output.
func (debug debugging) Log(msg string, fields ...mlog.Field) {
if debug.Enabled() {
debug.logger.LogM(debug.levels, msg, fields...)
}
}
type Packet ber.Packet
func (p Packet) LogClone() any {
bp := ber.Packet(p)
var b bytes.Buffer
ber.WritePacket(&b, &bp)
return b.String()
}
func PacketToField(packet *ber.Packet) mlog.Field {
if packet == nil {
return mlog.Any("packet", nil)
}
return mlog.Any("packet", Packet(*packet))
}

65
vendor/github.com/mattermost/ldap/del.go generated vendored Normal file
View File

@ -0,0 +1,65 @@
//
// https://tools.ietf.org/html/rfc4511
//
// DelRequest ::= [APPLICATION 10] LDAPDN
package ldap
import (
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
// DelRequest implements an LDAP deletion request
type DelRequest struct {
// DN is the name of the directory entry to delete
DN string
// Controls hold optional controls to send with the request
Controls []Control
}
func (req *DelRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypePrimitive, ApplicationDelRequest, req.DN, "Del Request")
pkt.Data.Write([]byte(req.DN))
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// NewDelRequest creates a delete request for the given DN and controls
func NewDelRequest(DN string, Controls []Control) *DelRequest {
return &DelRequest{
DN: DN,
Controls: Controls,
}
}
// Del executes the given delete request
func (l *Conn) Del(delRequest *DelRequest) error {
msgCtx, err := l.doRequest(delRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}
tag := packet.Children[1].Tag
if tag == ApplicationDelResponse {
err := GetLDAPError(packet)
if err != nil {
return err
}
} else {
l.Debug.Log("Unexpected Response tag", mlog.Uint("tag", tag))
}
return nil
}

247
vendor/github.com/mattermost/ldap/dn.go generated vendored Normal file
View File

@ -0,0 +1,247 @@
// File contains DN parsing functionality
//
// https://tools.ietf.org/html/rfc4514
//
// distinguishedName = [ relativeDistinguishedName
// *( COMMA relativeDistinguishedName ) ]
// relativeDistinguishedName = attributeTypeAndValue
// *( PLUS attributeTypeAndValue )
// attributeTypeAndValue = attributeType EQUALS attributeValue
// attributeType = descr / numericoid
// attributeValue = string / hexstring
//
// ; The following characters are to be escaped when they appear
// ; in the value to be encoded: ESC, one of <escaped>, leading
// ; SHARP or SPACE, trailing SPACE, and NULL.
// string = [ ( leadchar / pair ) [ *( stringchar / pair )
// ( trailchar / pair ) ] ]
//
// leadchar = LUTF1 / UTFMB
// LUTF1 = %x01-1F / %x21 / %x24-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// trailchar = TUTF1 / UTFMB
// TUTF1 = %x01-1F / %x21 / %x23-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// stringchar = SUTF1 / UTFMB
// SUTF1 = %x01-21 / %x23-2A / %x2D-3A /
// %x3D / %x3F-5B / %x5D-7F
//
// pair = ESC ( ESC / special / hexpair )
// special = escaped / SPACE / SHARP / EQUALS
// escaped = DQUOTE / PLUS / COMMA / SEMI / LANGLE / RANGLE
// hexstring = SHARP 1*hexpair
// hexpair = HEX HEX
//
// where the productions <descr>, <numericoid>, <COMMA>, <DQUOTE>,
// <EQUALS>, <ESC>, <HEX>, <LANGLE>, <NULL>, <PLUS>, <RANGLE>, <SEMI>,
// <SPACE>, <SHARP>, and <UTFMB> are defined in [RFC4512].
//
package ldap
import (
"bytes"
enchex "encoding/hex"
"errors"
"fmt"
"strings"
"github.com/go-asn1-ber/asn1-ber"
)
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514
type AttributeTypeAndValue struct {
// Type is the attribute type
Type string
// Value is the attribute value
Value string
}
// RelativeDN represents a relativeDistinguishedName from https://tools.ietf.org/html/rfc4514
type RelativeDN struct {
Attributes []*AttributeTypeAndValue
}
// DN represents a distinguishedName from https://tools.ietf.org/html/rfc4514
type DN struct {
RDNs []*RelativeDN
}
// ParseDN returns a distinguishedName or an error
func ParseDN(str string) (*DN, error) {
dn := new(DN)
dn.RDNs = make([]*RelativeDN, 0)
rdn := new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
buffer := bytes.Buffer{}
attribute := new(AttributeTypeAndValue)
escaping := false
unescapedTrailingSpaces := 0
stringFromBuffer := func() string {
s := buffer.String()
s = s[0 : len(s)-unescapedTrailingSpaces]
buffer.Reset()
unescapedTrailingSpaces = 0
return s
}
for i := 0; i < len(str); i++ {
char := str[i]
switch {
case escaping:
unescapedTrailingSpaces = 0
escaping = false
switch char {
case ' ', '"', '#', '+', ',', ';', '<', '=', '>', '\\':
buffer.WriteByte(char)
continue
}
// Not a special character, assume hex encoded octet
if len(str) == i+1 {
return nil, errors.New("got corrupted escaped character")
}
dst := []byte{0}
n, err := enchex.Decode([]byte(dst), []byte(str[i:i+2]))
if err != nil {
return nil, fmt.Errorf("failed to decode escaped character: %s", err)
} else if n != 1 {
return nil, fmt.Errorf("expected 1 byte when un-escaping, got %d", n)
}
buffer.WriteByte(dst[0])
i++
case char == '\\':
unescapedTrailingSpaces = 0
escaping = true
case char == '=':
attribute.Type = stringFromBuffer()
// Special case: If the first character in the value is # the
// following data is BER encoded so we can just fast forward
// and decode.
if len(str) > i+1 && str[i+1] == '#' {
i += 2
index := strings.IndexAny(str[i:], ",+")
data := str
if index > 0 {
data = str[i : i+index]
} else {
data = str[i:]
}
rawBER, err := enchex.DecodeString(data)
if err != nil {
return nil, fmt.Errorf("failed to decode BER encoding: %s", err)
}
packet, err := ber.DecodePacketErr(rawBER)
if err != nil {
return nil, fmt.Errorf("failed to decode BER packet: %s", err)
}
buffer.WriteString(packet.Data.String())
i += len(data) - 1
}
case char == ',' || char == '+':
// We're done with this RDN or value, push it
if len(attribute.Type) == 0 {
return nil, errors.New("incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
attribute = new(AttributeTypeAndValue)
if char == ',' {
dn.RDNs = append(dn.RDNs, rdn)
rdn = new(RelativeDN)
rdn.Attributes = make([]*AttributeTypeAndValue, 0)
}
case char == ' ' && buffer.Len() == 0:
// ignore unescaped leading spaces
continue
default:
if char == ' ' {
// Track unescaped spaces in case they are trailing and we need to remove them
unescapedTrailingSpaces++
} else {
// Reset if we see a non-space char
unescapedTrailingSpaces = 0
}
buffer.WriteByte(char)
}
}
if buffer.Len() > 0 {
if len(attribute.Type) == 0 {
return nil, errors.New("DN ended with incomplete type, value pair")
}
attribute.Value = stringFromBuffer()
rdn.Attributes = append(rdn.Attributes, attribute)
dn.RDNs = append(dn.RDNs, rdn)
}
return dn, nil
}
// Equal returns true if the DNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Returns true if they have the same number of relative distinguished names
// and corresponding relative distinguished names (by position) are the same.
func (d *DN) Equal(other *DN) bool {
if len(d.RDNs) != len(other.RDNs) {
return false
}
for i := range d.RDNs {
if !d.RDNs[i].Equal(other.RDNs[i]) {
return false
}
}
return true
}
// AncestorOf returns true if the other DN consists of at least one RDN followed by all the RDNs of the current DN.
// "ou=widgets,o=acme.com" is an ancestor of "ou=sprockets,ou=widgets,o=acme.com"
// "ou=widgets,o=acme.com" is not an ancestor of "ou=sprockets,ou=widgets,o=foo.com"
// "ou=widgets,o=acme.com" is not an ancestor of "ou=widgets,o=acme.com"
func (d *DN) AncestorOf(other *DN) bool {
if len(d.RDNs) >= len(other.RDNs) {
return false
}
// Take the last `len(d.RDNs)` RDNs from the other DN to compare against
otherRDNs := other.RDNs[len(other.RDNs)-len(d.RDNs):]
for i := range d.RDNs {
if !d.RDNs[i].Equal(otherRDNs[i]) {
return false
}
}
return true
}
// Equal returns true if the RelativeDNs are equal as defined by rfc4517 4.2.15 (distinguishedNameMatch).
// Relative distinguished names are the same if and only if they have the same number of AttributeTypeAndValues
// and each attribute of the first RDN is the same as the attribute of the second RDN with the same attribute type.
// The order of attributes is not significant.
// Case of attribute types is not significant.
func (r *RelativeDN) Equal(other *RelativeDN) bool {
if len(r.Attributes) != len(other.Attributes) {
return false
}
return r.hasAllAttributes(other.Attributes) && other.hasAllAttributes(r.Attributes)
}
func (r *RelativeDN) hasAllAttributes(attrs []*AttributeTypeAndValue) bool {
for _, attr := range attrs {
found := false
for _, myattr := range r.Attributes {
if myattr.Equal(attr) {
found = true
break
}
}
if !found {
return false
}
}
return true
}
// Equal returns true if the AttributeTypeAndValue is equivalent to the specified AttributeTypeAndValue
// Case of the attribute type is not significant
func (a *AttributeTypeAndValue) Equal(other *AttributeTypeAndValue) bool {
return strings.EqualFold(a.Type, other.Type) && a.Value == other.Value
}

4
vendor/github.com/mattermost/ldap/doc.go generated vendored Normal file
View File

@ -0,0 +1,4 @@
/*
Package ldap provides basic LDAP v3 functionality.
*/
package ldap

236
vendor/github.com/mattermost/ldap/error.go generated vendored Normal file
View File

@ -0,0 +1,236 @@
package ldap
import (
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
// LDAP Result Codes
const (
LDAPResultSuccess = 0
LDAPResultOperationsError = 1
LDAPResultProtocolError = 2
LDAPResultTimeLimitExceeded = 3
LDAPResultSizeLimitExceeded = 4
LDAPResultCompareFalse = 5
LDAPResultCompareTrue = 6
LDAPResultAuthMethodNotSupported = 7
LDAPResultStrongAuthRequired = 8
LDAPResultReferral = 10
LDAPResultAdminLimitExceeded = 11
LDAPResultUnavailableCriticalExtension = 12
LDAPResultConfidentialityRequired = 13
LDAPResultSaslBindInProgress = 14
LDAPResultNoSuchAttribute = 16
LDAPResultUndefinedAttributeType = 17
LDAPResultInappropriateMatching = 18
LDAPResultConstraintViolation = 19
LDAPResultAttributeOrValueExists = 20
LDAPResultInvalidAttributeSyntax = 21
LDAPResultNoSuchObject = 32
LDAPResultAliasProblem = 33
LDAPResultInvalidDNSyntax = 34
LDAPResultIsLeaf = 35
LDAPResultAliasDereferencingProblem = 36
LDAPResultInappropriateAuthentication = 48
LDAPResultInvalidCredentials = 49
LDAPResultInsufficientAccessRights = 50
LDAPResultBusy = 51
LDAPResultUnavailable = 52
LDAPResultUnwillingToPerform = 53
LDAPResultLoopDetect = 54
LDAPResultSortControlMissing = 60
LDAPResultOffsetRangeError = 61
LDAPResultNamingViolation = 64
LDAPResultObjectClassViolation = 65
LDAPResultNotAllowedOnNonLeaf = 66
LDAPResultNotAllowedOnRDN = 67
LDAPResultEntryAlreadyExists = 68
LDAPResultObjectClassModsProhibited = 69
LDAPResultResultsTooLarge = 70
LDAPResultAffectsMultipleDSAs = 71
LDAPResultVirtualListViewErrorOrControlError = 76
LDAPResultOther = 80
LDAPResultServerDown = 81
LDAPResultLocalError = 82
LDAPResultEncodingError = 83
LDAPResultDecodingError = 84
LDAPResultTimeout = 85
LDAPResultAuthUnknown = 86
LDAPResultFilterError = 87
LDAPResultUserCanceled = 88
LDAPResultParamError = 89
LDAPResultNoMemory = 90
LDAPResultConnectError = 91
LDAPResultNotSupported = 92
LDAPResultControlNotFound = 93
LDAPResultNoResultsReturned = 94
LDAPResultMoreResultsToReturn = 95
LDAPResultClientLoop = 96
LDAPResultReferralLimitExceeded = 97
LDAPResultInvalidResponse = 100
LDAPResultAmbiguousResponse = 101
LDAPResultTLSNotSupported = 112
LDAPResultIntermediateResponse = 113
LDAPResultUnknownType = 114
LDAPResultCanceled = 118
LDAPResultNoSuchOperation = 119
LDAPResultTooLate = 120
LDAPResultCannotCancel = 121
LDAPResultAssertionFailed = 122
LDAPResultAuthorizationDenied = 123
LDAPResultSyncRefreshRequired = 4096
ErrorNetwork = 200
ErrorFilterCompile = 201
ErrorFilterDecompile = 202
ErrorDebugging = 203
ErrorUnexpectedMessage = 204
ErrorUnexpectedResponse = 205
ErrorEmptyPassword = 206
)
// LDAPResultCodeMap contains string descriptions for LDAP error codes
var LDAPResultCodeMap = map[uint16]string{
LDAPResultSuccess: "Success",
LDAPResultOperationsError: "Operations Error",
LDAPResultProtocolError: "Protocol Error",
LDAPResultTimeLimitExceeded: "Time Limit Exceeded",
LDAPResultSizeLimitExceeded: "Size Limit Exceeded",
LDAPResultCompareFalse: "Compare False",
LDAPResultCompareTrue: "Compare True",
LDAPResultAuthMethodNotSupported: "Auth Method Not Supported",
LDAPResultStrongAuthRequired: "Strong Auth Required",
LDAPResultReferral: "Referral",
LDAPResultAdminLimitExceeded: "Admin Limit Exceeded",
LDAPResultUnavailableCriticalExtension: "Unavailable Critical Extension",
LDAPResultConfidentialityRequired: "Confidentiality Required",
LDAPResultSaslBindInProgress: "Sasl Bind In Progress",
LDAPResultNoSuchAttribute: "No Such Attribute",
LDAPResultUndefinedAttributeType: "Undefined Attribute Type",
LDAPResultInappropriateMatching: "Inappropriate Matching",
LDAPResultConstraintViolation: "Constraint Violation",
LDAPResultAttributeOrValueExists: "Attribute Or Value Exists",
LDAPResultInvalidAttributeSyntax: "Invalid Attribute Syntax",
LDAPResultNoSuchObject: "No Such Object",
LDAPResultAliasProblem: "Alias Problem",
LDAPResultInvalidDNSyntax: "Invalid DN Syntax",
LDAPResultIsLeaf: "Is Leaf",
LDAPResultAliasDereferencingProblem: "Alias Dereferencing Problem",
LDAPResultInappropriateAuthentication: "Inappropriate Authentication",
LDAPResultInvalidCredentials: "Invalid Credentials",
LDAPResultInsufficientAccessRights: "Insufficient Access Rights",
LDAPResultBusy: "Busy",
LDAPResultUnavailable: "Unavailable",
LDAPResultUnwillingToPerform: "Unwilling To Perform",
LDAPResultLoopDetect: "Loop Detect",
LDAPResultSortControlMissing: "Sort Control Missing",
LDAPResultOffsetRangeError: "Result Offset Range Error",
LDAPResultNamingViolation: "Naming Violation",
LDAPResultObjectClassViolation: "Object Class Violation",
LDAPResultResultsTooLarge: "Results Too Large",
LDAPResultNotAllowedOnNonLeaf: "Not Allowed On Non Leaf",
LDAPResultNotAllowedOnRDN: "Not Allowed On RDN",
LDAPResultEntryAlreadyExists: "Entry Already Exists",
LDAPResultObjectClassModsProhibited: "Object Class Mods Prohibited",
LDAPResultAffectsMultipleDSAs: "Affects Multiple DSAs",
LDAPResultVirtualListViewErrorOrControlError: "Failed because of a problem related to the virtual list view",
LDAPResultOther: "Other",
LDAPResultServerDown: "Cannot establish a connection",
LDAPResultLocalError: "An error occurred",
LDAPResultEncodingError: "LDAP encountered an error while encoding",
LDAPResultDecodingError: "LDAP encountered an error while decoding",
LDAPResultTimeout: "LDAP timeout while waiting for a response from the server",
LDAPResultAuthUnknown: "The auth method requested in a bind request is unknown",
LDAPResultFilterError: "An error occurred while encoding the given search filter",
LDAPResultUserCanceled: "The user canceled the operation",
LDAPResultParamError: "An invalid parameter was specified",
LDAPResultNoMemory: "Out of memory error",
LDAPResultConnectError: "A connection to the server could not be established",
LDAPResultNotSupported: "An attempt has been made to use a feature not supported LDAP",
LDAPResultControlNotFound: "The controls required to perform the requested operation were not found",
LDAPResultNoResultsReturned: "No results were returned from the server",
LDAPResultMoreResultsToReturn: "There are more results in the chain of results",
LDAPResultClientLoop: "A loop has been detected. For example when following referrals",
LDAPResultReferralLimitExceeded: "The referral hop limit has been exceeded",
LDAPResultCanceled: "Operation was canceled",
LDAPResultNoSuchOperation: "Server has no knowledge of the operation requested for cancellation",
LDAPResultTooLate: "Too late to cancel the outstanding operation",
LDAPResultCannotCancel: "The identified operation does not support cancellation or the cancel operation cannot be performed",
LDAPResultAssertionFailed: "An assertion control given in the LDAP operation evaluated to false causing the operation to not be performed",
LDAPResultSyncRefreshRequired: "Refresh Required",
LDAPResultInvalidResponse: "Invalid Response",
LDAPResultAmbiguousResponse: "Ambiguous Response",
LDAPResultTLSNotSupported: "Tls Not Supported",
LDAPResultIntermediateResponse: "Intermediate Response",
LDAPResultUnknownType: "Unknown Type",
LDAPResultAuthorizationDenied: "Authorization Denied",
ErrorNetwork: "Network Error",
ErrorFilterCompile: "Filter Compile Error",
ErrorFilterDecompile: "Filter Decompile Error",
ErrorDebugging: "Debugging Error",
ErrorUnexpectedMessage: "Unexpected Message",
ErrorUnexpectedResponse: "Unexpected Response",
ErrorEmptyPassword: "Empty password not allowed by the client",
}
// Error holds LDAP error information
type Error struct {
// Err is the underlying error
Err error
// ResultCode is the LDAP error code
ResultCode uint16
// MatchedDN is the matchedDN returned if any
MatchedDN string
}
func (e *Error) Error() string {
return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
}
// GetLDAPError creates an Error out of a BER packet representing a LDAPResult
// The return is an error object. It can be casted to a Error structure.
// This function returns nil if resultCode in the LDAPResult sequence is success(0).
func GetLDAPError(packet *ber.Packet) error {
if packet == nil {
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty packet")}
}
if len(packet.Children) >= 2 {
response := packet.Children[1]
if response == nil {
return &Error{ResultCode: ErrorUnexpectedResponse, Err: fmt.Errorf("Empty response in packet")}
}
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) >= 3 {
resultCode := uint16(response.Children[0].Value.(int64))
if resultCode == 0 { // No error
return nil
}
return &Error{ResultCode: resultCode, MatchedDN: response.Children[1].Value.(string),
Err: fmt.Errorf("%s", response.Children[2].Value.(string))}
}
}
return &Error{ResultCode: ErrorNetwork, Err: fmt.Errorf("Invalid packet format")}
}
// NewError creates an LDAP error with the given code and underlying error
func NewError(resultCode uint16, err error) error {
return &Error{ResultCode: resultCode, Err: err}
}
// IsErrorWithCode returns true if the given error is an LDAP error with the given result code
func IsErrorWithCode(err error, desiredResultCode uint16) bool {
if err == nil {
return false
}
serverError, ok := err.(*Error)
if !ok {
return false
}
return serverError.ResultCode == desiredResultCode
}

465
vendor/github.com/mattermost/ldap/filter.go generated vendored Normal file
View File

@ -0,0 +1,465 @@
package ldap
import (
"bytes"
hexpac "encoding/hex"
"errors"
"fmt"
"strings"
"unicode/utf8"
"github.com/go-asn1-ber/asn1-ber"
)
// Filter choices
const (
FilterAnd = 0
FilterOr = 1
FilterNot = 2
FilterEqualityMatch = 3
FilterSubstrings = 4
FilterGreaterOrEqual = 5
FilterLessOrEqual = 6
FilterPresent = 7
FilterApproxMatch = 8
FilterExtensibleMatch = 9
)
// FilterMap contains human readable descriptions of Filter choices
var FilterMap = map[uint64]string{
FilterAnd: "And",
FilterOr: "Or",
FilterNot: "Not",
FilterEqualityMatch: "Equality Match",
FilterSubstrings: "Substrings",
FilterGreaterOrEqual: "Greater Or Equal",
FilterLessOrEqual: "Less Or Equal",
FilterPresent: "Present",
FilterApproxMatch: "Approx Match",
FilterExtensibleMatch: "Extensible Match",
}
// SubstringFilter options
const (
FilterSubstringsInitial = 0
FilterSubstringsAny = 1
FilterSubstringsFinal = 2
)
// FilterSubstringsMap contains human readable descriptions of SubstringFilter choices
var FilterSubstringsMap = map[uint64]string{
FilterSubstringsInitial: "Substrings Initial",
FilterSubstringsAny: "Substrings Any",
FilterSubstringsFinal: "Substrings Final",
}
// MatchingRuleAssertion choices
const (
MatchingRuleAssertionMatchingRule = 1
MatchingRuleAssertionType = 2
MatchingRuleAssertionMatchValue = 3
MatchingRuleAssertionDNAttributes = 4
)
// MatchingRuleAssertionMap contains human readable descriptions of MatchingRuleAssertion choices
var MatchingRuleAssertionMap = map[uint64]string{
MatchingRuleAssertionMatchingRule: "Matching Rule Assertion Matching Rule",
MatchingRuleAssertionType: "Matching Rule Assertion Type",
MatchingRuleAssertionMatchValue: "Matching Rule Assertion Match Value",
MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
}
// CompileFilter converts a string representation of a filter into a BER-encoded packet
func CompileFilter(filter string) (*ber.Packet, error) {
if len(filter) == 0 || filter[0] != '(' {
return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
}
packet, pos, err := compileFilter(filter, 1)
if err != nil {
return nil, err
}
switch {
case pos > len(filter):
return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
case pos < len(filter):
return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
}
return packet, nil
}
// DecompileFilter converts a packet representation of a filter into a string representation
func DecompileFilter(packet *ber.Packet) (ret string, err error) {
defer func() {
if r := recover(); r != nil {
err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
}
}()
ret = "("
err = nil
childStr := ""
switch packet.Tag {
case FilterAnd:
ret += "&"
for _, child := range packet.Children {
childStr, err = DecompileFilter(child)
if err != nil {
return
}
ret += childStr
}
case FilterOr:
ret += "|"
for _, child := range packet.Children {
childStr, err = DecompileFilter(child)
if err != nil {
return
}
ret += childStr
}
case FilterNot:
ret += "!"
childStr, err = DecompileFilter(packet.Children[0])
if err != nil {
return
}
ret += childStr
case FilterSubstrings:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
for i, child := range packet.Children[1].Children {
if i == 0 && child.Tag != FilterSubstringsInitial {
ret += "*"
}
ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
if child.Tag != FilterSubstringsFinal {
ret += "*"
}
}
case FilterEqualityMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
case FilterGreaterOrEqual:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += ">="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
case FilterLessOrEqual:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "<="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
case FilterPresent:
ret += ber.DecodeString(packet.Data.Bytes())
ret += "=*"
case FilterApproxMatch:
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
ret += "~="
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
case FilterExtensibleMatch:
attr := ""
dnAttributes := false
matchingRule := ""
value := ""
for _, child := range packet.Children {
switch child.Tag {
case MatchingRuleAssertionMatchingRule:
matchingRule = ber.DecodeString(child.Data.Bytes())
case MatchingRuleAssertionType:
attr = ber.DecodeString(child.Data.Bytes())
case MatchingRuleAssertionMatchValue:
value = ber.DecodeString(child.Data.Bytes())
case MatchingRuleAssertionDNAttributes:
dnAttributes = child.Value.(bool)
}
}
if len(attr) > 0 {
ret += attr
}
if dnAttributes {
ret += ":dn"
}
if len(matchingRule) > 0 {
ret += ":"
ret += matchingRule
}
ret += ":="
ret += EscapeFilter(value)
}
ret += ")"
return
}
func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
for pos < len(filter) && filter[pos] == '(' {
child, newPos, err := compileFilter(filter, pos+1)
if err != nil {
return pos, err
}
pos = newPos
parent.AppendChild(child)
}
if pos == len(filter) {
return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
}
return pos + 1, nil
}
func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
var (
packet *ber.Packet
err error
)
defer func() {
if r := recover(); r != nil {
err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
}
}()
newPos := pos
currentRune, currentWidth := utf8.DecodeRuneInString(filter[newPos:])
switch currentRune {
case utf8.RuneError:
return nil, 0, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos))
case '(':
packet, newPos, err = compileFilter(filter, pos+currentWidth)
newPos++
return packet, newPos, err
case '&':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
newPos, err = compileFilterSet(filter, pos+currentWidth, packet)
return packet, newPos, err
case '|':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
newPos, err = compileFilterSet(filter, pos+currentWidth, packet)
return packet, newPos, err
case '!':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
var child *ber.Packet
child, newPos, err = compileFilter(filter, pos+currentWidth)
packet.AppendChild(child)
return packet, newPos, err
default:
const (
stateReadingAttr = 0
stateReadingExtensibleMatchingRule = 1
stateReadingCondition = 2
)
state := stateReadingAttr
attribute := ""
extensibleDNAttributes := false
extensibleMatchingRule := ""
condition := ""
for newPos < len(filter) {
remainingFilter := filter[newPos:]
currentRune, currentWidth = utf8.DecodeRuneInString(remainingFilter)
if currentRune == ')' {
break
}
if currentRune == utf8.RuneError {
return packet, newPos, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos))
}
switch state {
case stateReadingAttr:
switch {
// Extensible rule, with only DN-matching
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
extensibleDNAttributes = true
state = stateReadingCondition
newPos += 5
// Extensible rule, with DN-matching and a matching OID
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:"):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
extensibleDNAttributes = true
state = stateReadingExtensibleMatchingRule
newPos += 4
// Extensible rule, with attr only
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
state = stateReadingCondition
newPos += 2
// Extensible rule, with no DN attribute matching
case currentRune == ':':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
state = stateReadingExtensibleMatchingRule
newPos++
// Equality condition
case currentRune == '=':
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
state = stateReadingCondition
newPos++
// Greater-than or equal
case currentRune == '>' && strings.HasPrefix(remainingFilter, ">="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
state = stateReadingCondition
newPos += 2
// Less-than or equal
case currentRune == '<' && strings.HasPrefix(remainingFilter, "<="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
state = stateReadingCondition
newPos += 2
// Approx
case currentRune == '~' && strings.HasPrefix(remainingFilter, "~="):
packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch])
state = stateReadingCondition
newPos += 2
// Still reading the attribute name
default:
attribute += fmt.Sprintf("%c", currentRune)
newPos += currentWidth
}
case stateReadingExtensibleMatchingRule:
switch {
// Matching rule OID is done
case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
state = stateReadingCondition
newPos += 2
// Still reading the matching rule oid
default:
extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
newPos += currentWidth
}
case stateReadingCondition:
// append to the condition
condition += fmt.Sprintf("%c", currentRune)
newPos += currentWidth
}
}
if newPos == len(filter) {
err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
return packet, newPos, err
}
if packet == nil {
err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
return packet, newPos, err
}
switch {
case packet.Tag == FilterExtensibleMatch:
// MatchingRuleAssertion ::= SEQUENCE {
// matchingRule [1] MatchingRuleID OPTIONAL,
// type [2] AttributeDescription OPTIONAL,
// matchValue [3] AssertionValue,
// dnAttributes [4] BOOLEAN DEFAULT FALSE
// }
// Include the matching rule oid, if specified
if len(extensibleMatchingRule) > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
}
// Include the attribute, if specified
if len(attribute) > 0 {
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
}
// Add the value (only required child)
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
if encodeErr != nil {
return packet, newPos, encodeErr
}
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchValue, encodedString, MatchingRuleAssertionMap[MatchingRuleAssertionMatchValue]))
// Defaults to false, so only include in the sequence if true
if extensibleDNAttributes {
packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
}
case packet.Tag == FilterEqualityMatch && condition == "*":
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.Tag = FilterSubstrings
packet.Description = FilterMap[uint64(packet.Tag)]
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
parts := strings.Split(condition, "*")
for i, part := range parts {
if part == "" {
continue
}
var tag ber.Tag
switch i {
case 0:
tag = FilterSubstringsInitial
case len(parts) - 1:
tag = FilterSubstringsFinal
default:
tag = FilterSubstringsAny
}
encodedString, encodeErr := escapedStringToEncodedBytes(part)
if encodeErr != nil {
return packet, newPos, encodeErr
}
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)]))
}
packet.AppendChild(seq)
default:
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
if encodeErr != nil {
return packet, newPos, encodeErr
}
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
}
newPos += currentWidth
return packet, newPos, err
}
}
// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
func escapedStringToEncodedBytes(escapedString string) (string, error) {
var buffer bytes.Buffer
i := 0
for i < len(escapedString) {
currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
if currentRune == utf8.RuneError {
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
}
// Check for escaped hex characters and convert them to their literal value for transport.
if currentRune == '\\' {
// http://tools.ietf.org/search/rfc4515
// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
// being a member of UTF1SUBSET.
if i+2 > len(escapedString) {
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
}
escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
if decodeErr != nil {
return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
}
buffer.WriteByte(escByte[0])
i += 2 // +1 from end of loop, so 3 total for \xx.
} else {
buffer.WriteRune(currentRune)
}
i += currentWidth
}
return buffer.String(), nil
}

345
vendor/github.com/mattermost/ldap/ldap.go generated vendored Normal file
View File

@ -0,0 +1,345 @@
package ldap
import (
"fmt"
"io/ioutil"
"os"
ber "github.com/go-asn1-ber/asn1-ber"
)
// LDAP Application Codes
const (
ApplicationBindRequest = 0
ApplicationBindResponse = 1
ApplicationUnbindRequest = 2
ApplicationSearchRequest = 3
ApplicationSearchResultEntry = 4
ApplicationSearchResultDone = 5
ApplicationModifyRequest = 6
ApplicationModifyResponse = 7
ApplicationAddRequest = 8
ApplicationAddResponse = 9
ApplicationDelRequest = 10
ApplicationDelResponse = 11
ApplicationModifyDNRequest = 12
ApplicationModifyDNResponse = 13
ApplicationCompareRequest = 14
ApplicationCompareResponse = 15
ApplicationAbandonRequest = 16
ApplicationSearchResultReference = 19
ApplicationExtendedRequest = 23
ApplicationExtendedResponse = 24
)
// ApplicationMap contains human readable descriptions of LDAP Application Codes
var ApplicationMap = map[uint8]string{
ApplicationBindRequest: "Bind Request",
ApplicationBindResponse: "Bind Response",
ApplicationUnbindRequest: "Unbind Request",
ApplicationSearchRequest: "Search Request",
ApplicationSearchResultEntry: "Search Result Entry",
ApplicationSearchResultDone: "Search Result Done",
ApplicationModifyRequest: "Modify Request",
ApplicationModifyResponse: "Modify Response",
ApplicationAddRequest: "Add Request",
ApplicationAddResponse: "Add Response",
ApplicationDelRequest: "Del Request",
ApplicationDelResponse: "Del Response",
ApplicationModifyDNRequest: "Modify DN Request",
ApplicationModifyDNResponse: "Modify DN Response",
ApplicationCompareRequest: "Compare Request",
ApplicationCompareResponse: "Compare Response",
ApplicationAbandonRequest: "Abandon Request",
ApplicationSearchResultReference: "Search Result Reference",
ApplicationExtendedRequest: "Extended Request",
ApplicationExtendedResponse: "Extended Response",
}
// Ldap Behera Password Policy Draft 10 (https://tools.ietf.org/html/draft-behera-ldap-password-policy-10)
const (
BeheraPasswordExpired = 0
BeheraAccountLocked = 1
BeheraChangeAfterReset = 2
BeheraPasswordModNotAllowed = 3
BeheraMustSupplyOldPassword = 4
BeheraInsufficientPasswordQuality = 5
BeheraPasswordTooShort = 6
BeheraPasswordTooYoung = 7
BeheraPasswordInHistory = 8
)
// BeheraPasswordPolicyErrorMap contains human readable descriptions of Behera Password Policy error codes
var BeheraPasswordPolicyErrorMap = map[int8]string{
BeheraPasswordExpired: "Password expired",
BeheraAccountLocked: "Account locked",
BeheraChangeAfterReset: "Password must be changed",
BeheraPasswordModNotAllowed: "Policy prevents password modification",
BeheraMustSupplyOldPassword: "Policy requires old password in order to change password",
BeheraInsufficientPasswordQuality: "Password fails quality checks",
BeheraPasswordTooShort: "Password is too short for policy",
BeheraPasswordTooYoung: "Password has been changed too recently",
BeheraPasswordInHistory: "New password is in list of old passwords",
}
// Adds descriptions to an LDAP Response packet for debugging
func addLDAPDescriptions(packet *ber.Packet) (err error) {
defer func() {
if r := recover(); r != nil {
err = NewError(ErrorDebugging, fmt.Errorf("ldap: cannot process packet to add descriptions: %s", r))
}
}()
packet.Description = "LDAP Response"
packet.Children[0].Description = "Message ID"
application := uint8(packet.Children[1].Tag)
packet.Children[1].Description = ApplicationMap[application]
switch application {
case ApplicationBindRequest:
err = addRequestDescriptions(packet)
case ApplicationBindResponse:
err = addDefaultLDAPResponseDescriptions(packet)
case ApplicationUnbindRequest:
err = addRequestDescriptions(packet)
case ApplicationSearchRequest:
err = addRequestDescriptions(packet)
case ApplicationSearchResultEntry:
packet.Children[1].Children[0].Description = "Object Name"
packet.Children[1].Children[1].Description = "Attributes"
for _, child := range packet.Children[1].Children[1].Children {
child.Description = "Attribute"
child.Children[0].Description = "Attribute Name"
child.Children[1].Description = "Attribute Values"
for _, grandchild := range child.Children[1].Children {
grandchild.Description = "Attribute Value"
}
}
if len(packet.Children) == 3 {
err = addControlDescriptions(packet.Children[2])
}
case ApplicationSearchResultDone:
err = addDefaultLDAPResponseDescriptions(packet)
case ApplicationModifyRequest:
err = addRequestDescriptions(packet)
case ApplicationModifyResponse:
case ApplicationAddRequest:
err = addRequestDescriptions(packet)
case ApplicationAddResponse:
case ApplicationDelRequest:
err = addRequestDescriptions(packet)
case ApplicationDelResponse:
case ApplicationModifyDNRequest:
err = addRequestDescriptions(packet)
case ApplicationModifyDNResponse:
case ApplicationCompareRequest:
err = addRequestDescriptions(packet)
case ApplicationCompareResponse:
case ApplicationAbandonRequest:
err = addRequestDescriptions(packet)
case ApplicationSearchResultReference:
case ApplicationExtendedRequest:
err = addRequestDescriptions(packet)
case ApplicationExtendedResponse:
}
return err
}
func addControlDescriptions(packet *ber.Packet) error {
packet.Description = "Controls"
for _, child := range packet.Children {
var value *ber.Packet
controlType := ""
child.Description = "Control"
switch len(child.Children) {
case 0:
// at least one child is required for control type
return fmt.Errorf("at least one child is required for control type")
case 1:
// just type, no criticality or value
controlType = child.Children[0].Value.(string)
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
case 2:
controlType = child.Children[0].Value.(string)
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
// Children[1] could be criticality or value (both are optional)
// duck-type on whether this is a boolean
if _, ok := child.Children[1].Value.(bool); ok {
child.Children[1].Description = "Criticality"
} else {
child.Children[1].Description = "Control Value"
value = child.Children[1]
}
case 3:
// criticality and value present
controlType = child.Children[0].Value.(string)
child.Children[0].Description = "Control Type (" + ControlTypeMap[controlType] + ")"
child.Children[1].Description = "Criticality"
child.Children[2].Description = "Control Value"
value = child.Children[2]
default:
// more than 3 children is invalid
return fmt.Errorf("more than 3 children for control packet found")
}
if value == nil {
continue
}
switch controlType {
case ControlTypePaging:
value.Description += " (Paging)"
if value.Value != nil {
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
value.Data.Truncate(0)
value.Value = nil
valueChildren.Children[1].Value = valueChildren.Children[1].Data.Bytes()
value.AppendChild(valueChildren)
}
value.Children[0].Description = "Real Search Control Value"
value.Children[0].Children[0].Description = "Paging Size"
value.Children[0].Children[1].Description = "Cookie"
case ControlTypeBeheraPasswordPolicy:
value.Description += " (Password Policy - Behera Draft)"
if value.Value != nil {
valueChildren, err := ber.DecodePacketErr(value.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
value.Data.Truncate(0)
value.Value = nil
value.AppendChild(valueChildren)
}
sequence := value.Children[0]
for _, child := range sequence.Children {
if child.Tag == 0 {
//Warning
warningPacket := child.Children[0]
packet, err := ber.DecodePacketErr(warningPacket.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int64)
if ok {
if warningPacket.Tag == 0 {
//timeBeforeExpiration
value.Description += " (TimeBeforeExpiration)"
warningPacket.Value = val
} else if warningPacket.Tag == 1 {
//graceAuthNsRemaining
value.Description += " (GraceAuthNsRemaining)"
warningPacket.Value = val
}
}
} else if child.Tag == 1 {
// Error
packet, err := ber.DecodePacketErr(child.Data.Bytes())
if err != nil {
return fmt.Errorf("failed to decode data bytes: %s", err)
}
val, ok := packet.Value.(int8)
if !ok {
val = -1
}
child.Description = "Error"
child.Value = val
}
}
}
}
return nil
}
func addRequestDescriptions(packet *ber.Packet) error {
packet.Description = "LDAP Request"
packet.Children[0].Description = "Message ID"
packet.Children[1].Description = ApplicationMap[uint8(packet.Children[1].Tag)]
if len(packet.Children) == 3 {
return addControlDescriptions(packet.Children[2])
}
return nil
}
func addDefaultLDAPResponseDescriptions(packet *ber.Packet) error {
resultCode := uint16(LDAPResultSuccess)
matchedDN := ""
description := "Success"
if err := GetLDAPError(packet); err != nil {
resultCode = err.(*Error).ResultCode
matchedDN = err.(*Error).MatchedDN
description = "Error Message"
}
packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[resultCode] + ")"
packet.Children[1].Children[1].Description = "Matched DN (" + matchedDN + ")"
packet.Children[1].Children[2].Description = description
if len(packet.Children[1].Children) > 3 {
packet.Children[1].Children[3].Description = "Referral"
}
if len(packet.Children) == 3 {
return addControlDescriptions(packet.Children[2])
}
return nil
}
// DebugBinaryFile reads and prints packets from the given filename
func DebugBinaryFile(fileName string) error {
file, err := ioutil.ReadFile(fileName)
if err != nil {
return NewError(ErrorDebugging, err)
}
ber.PrintBytes(os.Stdout, file, "")
packet, err := ber.DecodePacketErr(file)
if err != nil {
return fmt.Errorf("failed to decode packet: %s", err)
}
if err := addLDAPDescriptions(packet); err != nil {
return err
}
ber.PrintPacket(packet)
return nil
}
var hex = "0123456789abcdef"
func mustEscape(c byte) bool {
return c > 0x7f || c == '(' || c == ')' || c == '\\' || c == '*' || c == 0
}
// EscapeFilter escapes from the provided LDAP filter string the special
// characters in the set `()*\` and those out of the range 0 < c < 0x80,
// as defined in RFC4515.
func EscapeFilter(filter string) string {
escape := 0
for i := 0; i < len(filter); i++ {
if mustEscape(filter[i]) {
escape++
}
}
if escape == 0 {
return filter
}
buf := make([]byte, len(filter)+escape*2)
for i, j := 0, 0; i < len(filter); i++ {
c := filter[i]
if mustEscape(c) {
buf[j+0] = '\\'
buf[j+1] = hex[c>>4]
buf[j+2] = hex[c&0xf]
j += 3
} else {
buf[j] = c
j++
}
}
return string(buf)
}

86
vendor/github.com/mattermost/ldap/moddn.go generated vendored Normal file
View File

@ -0,0 +1,86 @@
// Package ldap - moddn.go contains ModifyDN functionality
//
// https://tools.ietf.org/html/rfc4511
//
// ModifyDNRequest ::= [APPLICATION 12] SEQUENCE {
// entry LDAPDN,
// newrdn RelativeLDAPDN,
// deleteoldrdn BOOLEAN,
// newSuperior [0] LDAPDN OPTIONAL }
package ldap
import (
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
// ModifyDNRequest holds the request to modify a DN
type ModifyDNRequest struct {
DN string
NewRDN string
DeleteOldRDN bool
NewSuperior string
}
// NewModifyDNRequest creates a new request which can be passed to ModifyDN().
//
// To move an object in the tree, set the "newSup" to the new parent entry DN. Use an
// empty string for just changing the object's RDN.
//
// For moving the object without renaming, the "rdn" must be the first
// RDN of the given DN.
//
// A call like
//
// mdnReq := NewModifyDNRequest("uid=someone,dc=example,dc=org", "uid=newname", true, "")
//
// will setup the request to just rename uid=someone,dc=example,dc=org to
// uid=newname,dc=example,dc=org.
func NewModifyDNRequest(dn string, rdn string, delOld bool, newSup string) *ModifyDNRequest {
return &ModifyDNRequest{
DN: dn,
NewRDN: rdn,
DeleteOldRDN: delOld,
NewSuperior: newSup,
}
}
func (req *ModifyDNRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyDNRequest, nil, "Modify DN Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.NewRDN, "New RDN"))
pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.DeleteOldRDN, "Delete old RDN"))
if req.NewSuperior != "" {
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.NewSuperior, "New Superior"))
}
envelope.AppendChild(pkt)
return nil
}
// ModifyDN renames the given DN and optionally move to another base (when the "newSup" argument
// to NewModifyDNRequest() is not "").
func (l *Conn) ModifyDN(m *ModifyDNRequest) error {
msgCtx, err := l.doRequest(m)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}
tag := packet.Children[1].Tag
if tag == ApplicationModifyDNResponse {
err := GetLDAPError(packet)
if err != nil {
return err
}
} else {
l.Debug.Log("Unexpected Response tag", mlog.Uint("tag", tag))
}
return nil
}

151
vendor/github.com/mattermost/ldap/modify.go generated vendored Normal file
View File

@ -0,0 +1,151 @@
// File contains Modify functionality
//
// https://tools.ietf.org/html/rfc4511
//
// ModifyRequest ::= [APPLICATION 6] SEQUENCE {
// object LDAPDN,
// changes SEQUENCE OF change SEQUENCE {
// operation ENUMERATED {
// add (0),
// delete (1),
// replace (2),
// ... },
// modification PartialAttribute } }
//
// PartialAttribute ::= SEQUENCE {
// type AttributeDescription,
// vals SET OF value AttributeValue }
//
// AttributeDescription ::= LDAPString
// -- Constrained to <attributedescription>
// -- [RFC4512]
//
// AttributeValue ::= OCTET STRING
//
package ldap
import (
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
// Change operation choices
const (
AddAttribute = 0
DeleteAttribute = 1
ReplaceAttribute = 2
)
// PartialAttribute for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
type PartialAttribute struct {
// Type is the type of the partial attribute
Type string
// Vals are the values of the partial attribute
Vals []string
}
func (p *PartialAttribute) encode() *ber.Packet {
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "PartialAttribute")
seq.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, p.Type, "Type"))
set := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "AttributeValue")
for _, value := range p.Vals {
set.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Vals"))
}
seq.AppendChild(set)
return seq
}
// Change for a ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
type Change struct {
// Operation is the type of change to be made
Operation uint
// Modification is the attribute to be modified
Modification PartialAttribute
}
func (c *Change) encode() *ber.Packet {
change := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Change")
change.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(c.Operation), "Operation"))
change.AppendChild(c.Modification.encode())
return change
}
// ModifyRequest as defined in https://tools.ietf.org/html/rfc4511
type ModifyRequest struct {
// DN is the distinguishedName of the directory entry to modify
DN string
// Changes contain the attributes to modify
Changes []Change
// Controls hold optional controls to send with the request
Controls []Control
}
// Add appends the given attribute to the list of changes to be made
func (req *ModifyRequest) Add(attrType string, attrVals []string) {
req.appendChange(AddAttribute, attrType, attrVals)
}
// Delete appends the given attribute to the list of changes to be made
func (req *ModifyRequest) Delete(attrType string, attrVals []string) {
req.appendChange(DeleteAttribute, attrType, attrVals)
}
// Replace appends the given attribute to the list of changes to be made
func (req *ModifyRequest) Replace(attrType string, attrVals []string) {
req.appendChange(ReplaceAttribute, attrType, attrVals)
}
func (req *ModifyRequest) appendChange(operation uint, attrType string, attrVals []string) {
req.Changes = append(req.Changes, Change{operation, PartialAttribute{Type: attrType, Vals: attrVals}})
}
func (req *ModifyRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationModifyRequest, nil, "Modify Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.DN, "DN"))
changes := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Changes")
for _, change := range req.Changes {
changes.AppendChild(change.encode())
}
pkt.AppendChild(changes)
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// NewModifyRequest creates a modify request for the given DN
func NewModifyRequest(dn string, controls []Control) *ModifyRequest {
return &ModifyRequest{
DN: dn,
Controls: controls,
}
}
// Modify performs the ModifyRequest
func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
msgCtx, err := l.doRequest(modifyRequest)
if err != nil {
return err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return err
}
tag := packet.Children[1].Tag
if tag == ApplicationModifyResponse {
err := GetLDAPError(packet)
if err != nil {
return err
}
} else {
l.Debug.Log("Unexpected Response tag", mlog.Uint("tag", tag))
}
return nil
}

131
vendor/github.com/mattermost/ldap/passwdmodify.go generated vendored Normal file
View File

@ -0,0 +1,131 @@
// This file contains the password modify extended operation as specified in rfc 3062
//
// https://tools.ietf.org/html/rfc3062
//
package ldap
import (
"fmt"
ber "github.com/go-asn1-ber/asn1-ber"
)
const (
passwordModifyOID = "1.3.6.1.4.1.4203.1.11.1"
)
// PasswordModifyRequest implements the Password Modify Extended Operation as defined in https://www.ietf.org/rfc/rfc3062.txt
type PasswordModifyRequest struct {
// UserIdentity is an optional string representation of the user associated with the request.
// This string may or may not be an LDAPDN [RFC2253].
// If no UserIdentity field is present, the request acts up upon the password of the user currently associated with the LDAP session
UserIdentity string
// OldPassword, if present, contains the user's current password
OldPassword string
// NewPassword, if present, contains the desired password for this user
NewPassword string
}
// PasswordModifyResult holds the server response to a PasswordModifyRequest
type PasswordModifyResult struct {
// GeneratedPassword holds a password generated by the server, if present
GeneratedPassword string
// Referral are the returned referral
Referral string
}
func (req *PasswordModifyRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Password Modify Extended Operation")
pkt.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, passwordModifyOID, "Extended Request Name: Password Modify OID"))
extendedRequestValue := ber.Encode(ber.ClassContext, ber.TypePrimitive, 1, nil, "Extended Request Value: Password Modify Request")
passwordModifyRequestValue := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Password Modify Request")
if req.UserIdentity != "" {
passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, req.UserIdentity, "User Identity"))
}
if req.OldPassword != "" {
passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 1, req.OldPassword, "Old Password"))
}
if req.NewPassword != "" {
passwordModifyRequestValue.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 2, req.NewPassword, "New Password"))
}
extendedRequestValue.AppendChild(passwordModifyRequestValue)
pkt.AppendChild(extendedRequestValue)
envelope.AppendChild(pkt)
return nil
}
// NewPasswordModifyRequest creates a new PasswordModifyRequest
//
// According to the RFC 3602:
// userIdentity is a string representing the user associated with the request.
// This string may or may not be an LDAPDN (RFC 2253).
// If userIdentity is empty then the operation will act on the user associated
// with the session.
//
// oldPassword is the current user's password, it can be empty or it can be
// needed depending on the session user access rights (usually an administrator
// can change a user's password without knowing the current one) and the
// password policy (see pwdSafeModify password policy's attribute)
//
// newPassword is the desired user's password. If empty the server can return
// an error or generate a new password that will be available in the
// PasswordModifyResult.GeneratedPassword
//
func NewPasswordModifyRequest(userIdentity string, oldPassword string, newPassword string) *PasswordModifyRequest {
return &PasswordModifyRequest{
UserIdentity: userIdentity,
OldPassword: oldPassword,
NewPassword: newPassword,
}
}
// PasswordModify performs the modification request
func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*PasswordModifyResult, error) {
msgCtx, err := l.doRequest(passwordModifyRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
result := &PasswordModifyResult{}
if packet.Children[1].Tag == ApplicationExtendedResponse {
err := GetLDAPError(packet)
if err != nil {
if IsErrorWithCode(err, LDAPResultReferral) {
for _, child := range packet.Children[1].Children {
if child.Tag == 3 {
result.Referral = child.Children[0].Value.(string)
}
}
}
return result, err
}
} else {
return nil, NewError(ErrorUnexpectedResponse, fmt.Errorf("unexpected Response: %d", packet.Children[1].Tag))
}
extendedResponse := packet.Children[1]
for _, child := range extendedResponse.Children {
if child.Tag == 11 {
passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
if len(passwordModifyResponseValue.Children) == 1 {
if passwordModifyResponseValue.Children[0].Tag == 0 {
result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
}
}
}
}
return result, nil
}

66
vendor/github.com/mattermost/ldap/request.go generated vendored Normal file
View File

@ -0,0 +1,66 @@
package ldap
import (
"errors"
ber "github.com/go-asn1-ber/asn1-ber"
"github.com/mattermost/mattermost/server/public/shared/mlog"
)
var (
errRespChanClosed = errors.New("ldap: response channel closed")
errCouldNotRetMsg = errors.New("ldap: could not retrieve message")
)
type request interface {
appendTo(*ber.Packet) error
}
type requestFunc func(*ber.Packet) error
func (f requestFunc) appendTo(p *ber.Packet) error {
return f(p)
}
func (l *Conn) doRequest(req request) (*messageContext, error) {
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
if err := req.appendTo(packet); err != nil {
return nil, err
}
l.Debug.Log("Sending package", PacketToField(packet))
msgCtx, err := l.sendMessage(packet)
if err != nil {
return nil, err
}
l.Debug.Log("Send package", mlog.Int("id", msgCtx.id))
return msgCtx, nil
}
func (l *Conn) readPacket(msgCtx *messageContext) (*ber.Packet, error) {
l.Debug.Log("Waiting for response", mlog.Int("id", msgCtx.id))
packetResponse, ok := <-msgCtx.responses
if !ok {
return nil, NewError(ErrorNetwork, errRespChanClosed)
}
packet, err := packetResponse.ReadPacket()
if l.Debug.Enabled() {
if err := addLDAPDescriptions(packet); err != nil {
return nil, err
}
l.Debug.Log("Got response", mlog.Int("id", msgCtx.id), PacketToField(packet), mlog.Err(err))
}
if err != nil {
return nil, err
}
if packet == nil {
return nil, NewError(ErrorNetwork, errCouldNotRetMsg)
}
return packet, nil
}

421
vendor/github.com/mattermost/ldap/search.go generated vendored Normal file
View File

@ -0,0 +1,421 @@
// File contains Search functionality
//
// https://tools.ietf.org/html/rfc4511
//
// SearchRequest ::= [APPLICATION 3] SEQUENCE {
// baseObject LDAPDN,
// scope ENUMERATED {
// baseObject (0),
// singleLevel (1),
// wholeSubtree (2),
// ... },
// derefAliases ENUMERATED {
// neverDerefAliases (0),
// derefInSearching (1),
// derefFindingBaseObj (2),
// derefAlways (3) },
// sizeLimit INTEGER (0 .. maxInt),
// timeLimit INTEGER (0 .. maxInt),
// typesOnly BOOLEAN,
// filter Filter,
// attributes AttributeSelection }
//
// AttributeSelection ::= SEQUENCE OF selector LDAPString
// -- The LDAPString is constrained to
// -- <attributeSelector> in Section 4.5.1.8
//
// Filter ::= CHOICE {
// and [0] SET SIZE (1..MAX) OF filter Filter,
// or [1] SET SIZE (1..MAX) OF filter Filter,
// not [2] Filter,
// equalityMatch [3] AttributeValueAssertion,
// substrings [4] SubstringFilter,
// greaterOrEqual [5] AttributeValueAssertion,
// lessOrEqual [6] AttributeValueAssertion,
// present [7] AttributeDescription,
// approxMatch [8] AttributeValueAssertion,
// extensibleMatch [9] MatchingRuleAssertion,
// ... }
//
// SubstringFilter ::= SEQUENCE {
// type AttributeDescription,
// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE {
// initial [0] AssertionValue, -- can occur at most once
// any [1] AssertionValue,
// final [2] AssertionValue } -- can occur at most once
// }
//
// MatchingRuleAssertion ::= SEQUENCE {
// matchingRule [1] MatchingRuleId OPTIONAL,
// type [2] AttributeDescription OPTIONAL,
// matchValue [3] AssertionValue,
// dnAttributes [4] BOOLEAN DEFAULT FALSE }
//
//
package ldap
import (
"errors"
"fmt"
"sort"
"strings"
ber "github.com/go-asn1-ber/asn1-ber"
)
// scope choices
const (
ScopeBaseObject = 0
ScopeSingleLevel = 1
ScopeWholeSubtree = 2
)
// ScopeMap contains human readable descriptions of scope choices
var ScopeMap = map[int]string{
ScopeBaseObject: "Base Object",
ScopeSingleLevel: "Single Level",
ScopeWholeSubtree: "Whole Subtree",
}
// derefAliases
const (
NeverDerefAliases = 0
DerefInSearching = 1
DerefFindingBaseObj = 2
DerefAlways = 3
)
// DerefMap contains human readable descriptions of derefAliases choices
var DerefMap = map[int]string{
NeverDerefAliases: "NeverDerefAliases",
DerefInSearching: "DerefInSearching",
DerefFindingBaseObj: "DerefFindingBaseObj",
DerefAlways: "DerefAlways",
}
// NewEntry returns an Entry object with the specified distinguished name and attribute key-value pairs.
// The map of attributes is accessed in alphabetical order of the keys in order to ensure that, for the
// same input map of attributes, the output entry will contain the same order of attributes
func NewEntry(dn string, attributes map[string][]string) *Entry {
var attributeNames []string
for attributeName := range attributes {
attributeNames = append(attributeNames, attributeName)
}
sort.Strings(attributeNames)
var encodedAttributes []*EntryAttribute
for _, attributeName := range attributeNames {
encodedAttributes = append(encodedAttributes, NewEntryAttribute(attributeName, attributes[attributeName]))
}
return &Entry{
DN: dn,
Attributes: encodedAttributes,
}
}
// Entry represents a single search result entry
type Entry struct {
// DN is the distinguished name of the entry
DN string
// Attributes are the returned attributes for the entry
Attributes []*EntryAttribute
}
// GetAttributeValues returns the values for the named attribute, or an empty list
func (e *Entry) GetAttributeValues(attribute string) []string {
for _, attr := range e.Attributes {
if attr.Name == attribute {
return attr.Values
}
}
return []string{}
}
// GetRawAttributeValues returns the byte values for the named attribute, or an empty list
func (e *Entry) GetRawAttributeValues(attribute string) [][]byte {
for _, attr := range e.Attributes {
if attr.Name == attribute {
return attr.ByteValues
}
}
return [][]byte{}
}
// GetAttributeValue returns the first value for the named attribute, or ""
func (e *Entry) GetAttributeValue(attribute string) string {
values := e.GetAttributeValues(attribute)
if len(values) == 0 {
return ""
}
return values[0]
}
// GetRawAttributeValue returns the first value for the named attribute, or an empty slice
func (e *Entry) GetRawAttributeValue(attribute string) []byte {
values := e.GetRawAttributeValues(attribute)
if len(values) == 0 {
return []byte{}
}
return values[0]
}
// Print outputs a human-readable description
func (e *Entry) Print() {
fmt.Printf("DN: %s\n", e.DN)
for _, attr := range e.Attributes {
attr.Print()
}
}
// PrettyPrint outputs a human-readable description indenting
func (e *Entry) PrettyPrint(indent int) {
fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN)
for _, attr := range e.Attributes {
attr.PrettyPrint(indent + 2)
}
}
// NewEntryAttribute returns a new EntryAttribute with the desired key-value pair
func NewEntryAttribute(name string, values []string) *EntryAttribute {
var bytes [][]byte
for _, value := range values {
bytes = append(bytes, []byte(value))
}
return &EntryAttribute{
Name: name,
Values: values,
ByteValues: bytes,
}
}
// EntryAttribute holds a single attribute
type EntryAttribute struct {
// Name is the name of the attribute
Name string
// Values contain the string values of the attribute
Values []string
// ByteValues contain the raw values of the attribute
ByteValues [][]byte
}
// Print outputs a human-readable description
func (e *EntryAttribute) Print() {
fmt.Printf("%s: %s\n", e.Name, e.Values)
}
// PrettyPrint outputs a human-readable description with indenting
func (e *EntryAttribute) PrettyPrint(indent int) {
fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values)
}
// SearchResult holds the server's response to a search request
type SearchResult struct {
// Entries are the returned entries
Entries []*Entry
// Referrals are the returned referrals
Referrals []string
// Controls are the returned controls
Controls []Control
}
// Print outputs a human-readable description
func (s *SearchResult) Print() {
for _, entry := range s.Entries {
entry.Print()
}
}
// PrettyPrint outputs a human-readable description with indenting
func (s *SearchResult) PrettyPrint(indent int) {
for _, entry := range s.Entries {
entry.PrettyPrint(indent)
}
}
// SearchRequest represents a search request to send to the server
type SearchRequest struct {
BaseDN string
Scope int
DerefAliases int
SizeLimit int
TimeLimit int
TypesOnly bool
Filter string
Attributes []string
Controls []Control
}
func (req *SearchRequest) appendTo(envelope *ber.Packet) error {
pkt := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request")
pkt.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, req.BaseDN, "Base DN"))
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.Scope), "Scope"))
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(req.DerefAliases), "Deref Aliases"))
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.SizeLimit), "Size Limit"))
pkt.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, uint64(req.TimeLimit), "Time Limit"))
pkt.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimitive, ber.TagBoolean, req.TypesOnly, "Types Only"))
// compile and encode filter
filterPacket, err := CompileFilter(req.Filter)
if err != nil {
return err
}
pkt.AppendChild(filterPacket)
// encode attributes
attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes")
for _, attribute := range req.Attributes {
attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
}
pkt.AppendChild(attributesPacket)
envelope.AppendChild(pkt)
if len(req.Controls) > 0 {
envelope.AppendChild(encodeControls(req.Controls))
}
return nil
}
// NewSearchRequest creates a new search request
func NewSearchRequest(
BaseDN string,
Scope, DerefAliases, SizeLimit, TimeLimit int,
TypesOnly bool,
Filter string,
Attributes []string,
Controls []Control,
) *SearchRequest {
return &SearchRequest{
BaseDN: BaseDN,
Scope: Scope,
DerefAliases: DerefAliases,
SizeLimit: SizeLimit,
TimeLimit: TimeLimit,
TypesOnly: TypesOnly,
Filter: Filter,
Attributes: Attributes,
Controls: Controls,
}
}
// SearchWithPaging accepts a search request and desired page size in order to execute LDAP queries to fulfill the
// search request. All paged LDAP query responses will be buffered and the final result will be returned atomically.
// The following four cases are possible given the arguments:
// - given SearchRequest missing a control of type ControlTypePaging: we will add one with the desired paging size
// - given SearchRequest contains a control of type ControlTypePaging that isn't actually a ControlPaging: fail without issuing any queries
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize equal to the size requested: no change to the search request
// - given SearchRequest contains a control of type ControlTypePaging with pagingSize not equal to the size requested: fail without issuing any queries
//
// A requested pagingSize of 0 is interpreted as no limit by LDAP servers.
func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) {
var pagingControl *ControlPaging
control := FindControl(searchRequest.Controls, ControlTypePaging)
if control == nil {
pagingControl = NewControlPaging(pagingSize)
searchRequest.Controls = append(searchRequest.Controls, pagingControl)
} else {
castControl, ok := control.(*ControlPaging)
if !ok {
return nil, fmt.Errorf("expected paging control to be of type *ControlPaging, got %v", control)
}
if castControl.PagingSize != pagingSize {
return nil, fmt.Errorf("paging size given in search request (%d) conflicts with size given in search call (%d)", castControl.PagingSize, pagingSize)
}
pagingControl = castControl
}
searchResult := new(SearchResult)
for {
result, err := l.Search(searchRequest)
if err != nil {
return searchResult, err
}
if result == nil {
return searchResult, NewError(ErrorNetwork, errors.New("ldap: packet not received"))
}
for _, entry := range result.Entries {
searchResult.Entries = append(searchResult.Entries, entry)
}
for _, referral := range result.Referrals {
searchResult.Referrals = append(searchResult.Referrals, referral)
}
for _, control := range result.Controls {
searchResult.Controls = append(searchResult.Controls, control)
}
pagingResult := FindControl(result.Controls, ControlTypePaging)
if pagingResult == nil {
pagingControl = nil
break
}
cookie := pagingResult.(*ControlPaging).Cookie
if len(cookie) == 0 {
pagingControl = nil
break
}
pagingControl.SetCookie(cookie)
}
if pagingControl != nil {
pagingControl.PagingSize = 0
l.Search(searchRequest)
}
return searchResult, nil
}
// Search performs the given search request
func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) {
msgCtx, err := l.doRequest(searchRequest)
if err != nil {
return nil, err
}
defer l.finishMessage(msgCtx)
result := &SearchResult{
Entries: make([]*Entry, 0),
Referrals: make([]string, 0),
Controls: make([]Control, 0)}
for {
packet, err := l.readPacket(msgCtx)
if err != nil {
return nil, err
}
switch packet.Children[1].Tag {
case 4:
entry := new(Entry)
entry.DN = packet.Children[1].Children[0].Value.(string)
for _, child := range packet.Children[1].Children[1].Children {
attr := new(EntryAttribute)
attr.Name = child.Children[0].Value.(string)
for _, value := range child.Children[1].Children {
attr.Values = append(attr.Values, value.Value.(string))
attr.ByteValues = append(attr.ByteValues, value.ByteValue)
}
entry.Attributes = append(entry.Attributes, attr)
}
result.Entries = append(result.Entries, entry)
case 5:
err := GetLDAPError(packet)
if err != nil {
return nil, err
}
if len(packet.Children) == 3 {
for _, child := range packet.Children[2].Children {
decodedChild, err := DecodeControl(child)
if err != nil {
return nil, fmt.Errorf("failed to decode child control: %s", err)
}
result.Controls = append(result.Controls, decodedChild)
}
}
return result, nil
case 19:
result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string))
}
}
}

42
vendor/github.com/mattermost/logr/v2/.gitignore generated vendored Normal file
View File

@ -0,0 +1,42 @@
# Binaries for programs and plugins
*.exe
*.dll
*.so
*.dylib
debug
dynip
# Test binary, build with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Output of profiler
*.prof
# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
.glide/
# IntelliJ config
.idea
# log files
*.log
# transient directories
vendor
output
build
app
logs
# test apps
test/cmd/testapp1/testapp1
test/cmd/simple/simple
test/cmd/gelf/gelf
# tools
.aider*
!.aider.conf.yml
!.aiderignore

21
vendor/github.com/mattermost/logr/v2/LICENSE generated vendored Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2019 wiggin77
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

205
vendor/github.com/mattermost/logr/v2/README.md generated vendored Normal file
View File

@ -0,0 +1,205 @@
![Logr_Logo](https://user-images.githubusercontent.com/7295363/200433587-ae9df127-9427-4753-a0a0-85723a216e0e.png)
> A fully asynchronous, contextual logger for Go.
# logr
[![GoDoc](https://godoc.org/github.com/mattermost/logr?status.svg)](http://godoc.org/github.com/mattermost/logr)
[![Report Card](https://goreportcard.com/badge/github.com/mattermost/logr)](https://goreportcard.com/report/github.com/mattermost/logr)
Logr is inspired by [Logrus](https://github.com/sirupsen/logrus) and [Zap](https://github.com/uber-go/zap) but addresses a number of issues:
1. Logr is fully asynchronous, meaning that all formatting and writing is done in the background. Latency sensitive applications benefit from not waiting for logging to complete.
2. Logr provides custom filters which provide more flexibility than Trace, Debug, Info... levels. If you need to temporarily increase verbosity of logging while tracking down a problem you can avoid the fire-hose that typically comes from Debug or Trace by using custom filters.
3. Logr generates much less allocations than Logrus, and is close to Zap in allocations.
## Concepts
<!-- markdownlint-disable MD033 -->
| entity | description |
| ------ | ----------- |
| Logr | Engine instance typically instantiated once; used to configure logging.<br>```lgr,_ := logr.New()```|
| Logger | Provides contextual logging via fields; lightweight, can be created once and accessed globally, or created on demand.<br>```logger := lgr.NewLogger()```<br>```logger2 := logger.With(logr.String("user", "Sam"))```|
| Target | A destination for log items such as console, file, database or just about anything that can be written to. Each target has its own filter/level and formatter, and any number of targets can be added to a Logr. Targets for file, syslog and any io.Writer are built-in and it is easy to create your own. You can also use any [Logrus hooks](https://github.com/sirupsen/logrus/wiki/Hooks) via a simple [adapter](https://github.com/wiggin77/logrus4logr).|
| Filter | Determines which logging calls get written versus filtered out. Also determines which logging calls generate a stack trace.<br>```filter := &logr.StdFilter{Lvl: logr.Warn, Stacktrace: logr.Fatal}```|
| Formatter | Formats the output. Logr includes built-in formatters for JSON and plain text with delimiters. It is easy to create your own formatters or you can also use any [Logrus formatters](https://github.com/sirupsen/logrus#formatters) via a simple [adapter](https://github.com/wiggin77/logrus4logr).<br>```formatter := &format.Plain{Delim: " \| "}```|
## Usage
```go
// Create Logr instance.
lgr,_ := logr.New()
// Create a filter and formatter. Both can be shared by multiple
// targets.
filter := &logr.StdFilter{Lvl: logr.Warn, Stacktrace: logr.Error}
formatter := &formatters.Plain{Delim: " | "}
// WriterTarget outputs to any io.Writer
t := targets.NewWriterTarget(filter, formatter, os.StdOut, 1000)
lgr.AddTarget(t)
// One or more Loggers can be created, shared, used concurrently,
// or created on demand.
logger := lgr.NewLogger().With("user", "Sarah")
// Now we can log to the target(s).
logger.Debug("login attempt")
logger.Error("login failed")
// Ensure targets are drained before application exit.
lgr.Shutdown()
```
## Fields
Fields allow for contextual logging, meaning information can be added to log statements without changing the statements themselves. Information can be shared across multiple logging statements thus allowing log analysis tools to group them.
Fields can be added to a Logger via `Logger.With` or included with each log record:
```go
lgr,_ := logr.New()
// ... add targets ...
logger := lgr.NewLogger().With(
logr.Any("user": user),
logr.String("role", role)
)
logger.Info("login attempt", logr.Int("attempt_count", count))
// ... later ...
logger.Info("login", logr.String("result", result))
```
Logr fields are inspired by and work the same as [Zap fields](https://pkg.go.dev/go.uber.org/zap#Field).
## Filters
Logr supports the traditional seven log levels via `logr.StdFilter`: Panic, Fatal, Error, Warning, Info, Debug, and Trace.
```go
// When added to a target, this filter will only allow
// log statements with level severity Warn or higher.
// It will also generate stack traces for Error or higher.
filter := &logr.StdFilter{Lvl: logr.Warn, Stacktrace: logr.Error}
```
Logr also supports custom filters (logr.CustomFilter) which allow fine grained inclusion of log items without turning on the fire-hose.
```go
// create custom levels; use IDs > 10.
LoginLevel := logr.Level{ID: 100, Name: "login ", Stacktrace: false}
LogoutLevel := logr.Level{ID: 101, Name: "logout", Stacktrace: false}
lgr,_ := logr.New()
// create a custom filter with custom levels.
filter := &logr.CustomFilter{}
filter.Add(LoginLevel, LogoutLevel)
formatter := &formatters.Plain{Delim: " | "}
tgr := targets.NewWriterTarget(filter, formatter, os.StdOut, 1000)
lgr.AddTarget(tgr)
logger := lgr.NewLogger().With(logr.String("user": "Bob"), logr.String("role": "admin"))
logger.Log(LoginLevel, "this item will get logged")
logger.Debug("won't be logged since Debug wasn't added to custom filter")
```
Both filter types allow you to determine which levels force a stack trace to be output. Note that generating stack traces cannot happen fully asynchronously and thus add some latency to the calling goroutine.
## Targets
There are built-in targets for outputting to syslog, file, TCP, or any `io.Writer`. More will be added.
You can use any [Logrus hooks](https://github.com/sirupsen/logrus/wiki/Hooks) via a simple [adapter](https://github.com/wiggin77/logrus4logr).
You can create your own target by implementing the simple [Target](./target.go) interface.
Example target that outputs to `io.Writer`:
```go
type Writer struct {
out io.Writer
}
func NewWriterTarget(out io.Writer) *Writer {
w := &Writer{out: out}
return w
}
// Called once to initialize target.
func (w *Writer) Init() error {
return nil
}
// Write will always be called by a single internal Logr goroutine, so no locking needed.
func (w *Writer) Write(p []byte, rec *logr.LogRec) (int, error) {
return w.out.Write(buf.Bytes())
}
// Called once to cleanup/free resources for target.
func (w *Writer) Shutdown() error {
return nil
}
```
## Formatters
Logr has two built-in formatters, one for JSON and the other plain, delimited text.
You can use any [Logrus formatters](https://github.com/sirupsen/logrus#formatters) via a simple [adapter](https://github.com/wiggin77/logrus4logr).
You can create your own formatter by implementing the [Formatter](./formatter.go) interface:
```go
Format(rec *LogRec, stacktrace bool, buf *bytes.Buffer) (*bytes.Buffer, error)
```
## Configuration options
When creating the Logr instance, you can set configuration options. For example:
```go
lgr, err := logr.New(
logr.MaxQueueSize(1000),
logr.StackFilter("mypackage1", "mypackage2"),
)
```
Some options are documented below. See [options.go](./options.go) for all available configuration options.
### ```Logr.OnLoggerError(err error)```
Called any time an internal logging error occurs. For example, this can happen when a target cannot connect to its data sink.
It may be tempting to log this error, however there is a danger that logging this will simply generate another error and so on. If you must log it, use a target and custom level specifically for this event and ensure it cannot generate more errors.
### ```Logr.OnQueueFull func(rec *LogRec, maxQueueSize int) bool```
Called on an attempt to add a log record to a full Logr queue. This generally means the Logr maximum queue size is too small, or at least one target is very slow. Logr maximum queue size can be changed before adding any targets via:
```go
lgr, err := logr.New(logr.MaxQueueSize(2000))
```
Returning true will drop the log record. False will block until the log record can be added, which creates a natural throttle at the expense of latency for the calling goroutine. The default is to block.
### ```Logr.OnTargetQueueFull func(target Target, rec *LogRec, maxQueueSize int) bool```
Called on an attempt to add a log record to a full target queue. This generally means your target's max queue size is too small, or the target is very slow to output.
As with the Logr queue, returning true will drop the log record. False will block until the log record can be added, which creates a natural throttle at the expense of latency for the calling goroutine. The default is to block.
### ```Logr.OnExit func(code int) and Logr.OnPanic func(err interface{})```
OnExit and OnPanic are called when the Logger.FatalXXX and Logger.PanicXXX functions are called respectively.
In both cases the default behavior is to shut down gracefully, draining all targets, and calling `os.Exit` or `panic` respectively.
When adding your own handlers, be sure to call `Logr.Shutdown` before exiting the application to avoid losing log records.
### ```Logr.StackFilter(pkg ...string)```
StackFilter sets a list of package names to exclude from the top of stack traces. The `Logr` packages are automatically filtered.

28
vendor/github.com/mattermost/logr/v2/buffer.go generated vendored Normal file
View File

@ -0,0 +1,28 @@
package logr
import (
"bytes"
"sync"
)
// Buffer provides a thread-safe buffer useful for logging to memory in unit tests.
type Buffer struct {
buf bytes.Buffer
mux sync.Mutex
}
func (b *Buffer) Read(p []byte) (n int, err error) {
b.mux.Lock()
defer b.mux.Unlock()
return b.buf.Read(p)
}
func (b *Buffer) Write(p []byte) (n int, err error) {
b.mux.Lock()
defer b.mux.Unlock()
return b.buf.Write(p)
}
func (b *Buffer) String() string {
b.mux.Lock()
defer b.mux.Unlock()
return b.buf.String()
}

209
vendor/github.com/mattermost/logr/v2/config/config.go generated vendored Normal file
View File

@ -0,0 +1,209 @@
package config
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/mattermost/logr/v2"
"github.com/mattermost/logr/v2/formatters"
"github.com/mattermost/logr/v2/targets"
)
type TargetCfg struct {
Type string `json:"type"` // one of "console", "file", "tcp", "syslog", "none".
Options json.RawMessage `json:"options,omitempty"`
Format string `json:"format"` // one of "json", "plain", "gelf"
FormatOptions json.RawMessage `json:"format_options,omitempty"`
Levels []logr.Level `json:"levels"`
MaxQueueSize int `json:"maxqueuesize,omitempty"`
}
type ConsoleOptions struct {
Out string `json:"out"` // one of "stdout", "stderr"
}
type TargetFactory func(targetType string, options json.RawMessage) (logr.Target, error)
type FormatterFactory func(format string, options json.RawMessage) (logr.Formatter, error)
type Factories struct {
TargetFactory TargetFactory // can be nil
FormatterFactory FormatterFactory // can be nil
}
var removeAll = func(ti logr.TargetInfo) bool { return true }
// ConfigureTargets replaces the current list of log targets with a new one based on a map
// of name->TargetCfg. The map of TargetCfg's would typically be serialized from a JSON
// source or can be programmatically created.
//
// An optional set of factories can be provided which will be called to create any target
// types or formatters not built-in.
//
// To append log targets to an existing config, use `(*Logr).AddTarget` or
// `(*Logr).AddTargetFromConfig` instead.
func ConfigureTargets(lgr *logr.Logr, config map[string]TargetCfg, factories *Factories) error {
if err := lgr.RemoveTargets(context.Background(), removeAll); err != nil {
return fmt.Errorf("error removing existing log targets: %w", err)
}
if factories == nil {
factories = &Factories{nil, nil}
}
for name, tcfg := range config {
target, err := newTarget(tcfg.Type, tcfg.Options, factories.TargetFactory)
if err != nil {
return fmt.Errorf("error creating log target %s: %w", name, err)
}
if target == nil {
continue
}
formatter, err := newFormatter(tcfg.Format, tcfg.FormatOptions, factories.FormatterFactory)
if err != nil {
return fmt.Errorf("error creating formatter for log target %s: %w", name, err)
}
filter := newFilter(tcfg.Levels)
qSize := tcfg.MaxQueueSize
if qSize == 0 {
qSize = logr.DefaultMaxQueueSize
}
if err = lgr.AddTarget(target, name, filter, formatter, qSize); err != nil {
return fmt.Errorf("error adding log target %s: %w", name, err)
}
}
return nil
}
func newFilter(levels []logr.Level) logr.Filter {
filter := &logr.CustomFilter{}
for _, lvl := range levels {
filter.Add(lvl)
}
return filter
}
func newTarget(targetType string, options json.RawMessage, factory TargetFactory) (logr.Target, error) {
switch strings.ToLower(targetType) {
case "console":
c := ConsoleOptions{}
if len(options) != 0 {
if err := json.Unmarshal(options, &c); err != nil {
return nil, fmt.Errorf("error decoding console target options: %w", err)
}
}
var w io.Writer
switch c.Out {
case "stderr":
w = os.Stderr
case "stdout", "":
w = os.Stdout
default:
return nil, fmt.Errorf("invalid console target option '%s'", c.Out)
}
return targets.NewWriterTarget(w), nil
case "file":
fo := targets.FileOptions{}
if len(options) == 0 {
return nil, errors.New("missing file target options")
}
if err := json.Unmarshal(options, &fo); err != nil {
return nil, fmt.Errorf("error decoding file target options: %w", err)
}
if err := fo.CheckValid(); err != nil {
return nil, fmt.Errorf("invalid file target options: %w", err)
}
return targets.NewFileTarget(fo), nil
case "tcp":
to := targets.TcpOptions{}
if len(options) == 0 {
return nil, errors.New("missing TCP target options")
}
if err := json.Unmarshal(options, &to); err != nil {
return nil, fmt.Errorf("error decoding TCP target options: %w", err)
}
if err := to.CheckValid(); err != nil {
return nil, fmt.Errorf("invalid TCP target options: %w", err)
}
return targets.NewTcpTarget(&to), nil
case "syslog":
so := targets.SyslogOptions{}
if len(options) == 0 {
return nil, errors.New("missing SysLog target options")
}
if err := json.Unmarshal(options, &so); err != nil {
return nil, fmt.Errorf("error decoding Syslog target options: %w", err)
}
if err := so.CheckValid(); err != nil {
return nil, fmt.Errorf("invalid SysLog target options: %w", err)
}
return targets.NewSyslogTarget(&so)
case "none":
return nil, nil
default:
if factory != nil {
t, err := factory(targetType, options)
if err != nil || t == nil {
return nil, fmt.Errorf("error from target factory: %w", err)
}
return t, nil
}
}
return nil, fmt.Errorf("target type '%s' is unrecognized", targetType)
}
func newFormatter(format string, options json.RawMessage, factory FormatterFactory) (logr.Formatter, error) {
switch strings.ToLower(format) {
case "json":
j := formatters.JSON{}
if len(options) != 0 {
if err := json.Unmarshal(options, &j); err != nil {
return nil, fmt.Errorf("error decoding JSON formatter options: %w", err)
}
if err := j.CheckValid(); err != nil {
return nil, fmt.Errorf("invalid JSON formatter options: %w", err)
}
}
return &j, nil
case "plain":
p := formatters.Plain{}
if len(options) != 0 {
if err := json.Unmarshal(options, &p); err != nil {
return nil, fmt.Errorf("error decoding Plain formatter options: %w", err)
}
if err := p.CheckValid(); err != nil {
return nil, fmt.Errorf("invalid plain formatter options: %w", err)
}
}
return &p, nil
case "gelf":
g := formatters.Gelf{}
if len(options) != 0 {
if err := json.Unmarshal(options, &g); err != nil {
return nil, fmt.Errorf("error decoding Gelf formatter options: %w", err)
}
if err := g.CheckValid(); err != nil {
return nil, fmt.Errorf("invalid GELF formatter options: %w", err)
}
}
return &g, nil
default:
if factory != nil {
f, err := factory(format, options)
if err != nil || f == nil {
return nil, fmt.Errorf("error from formatter factory: %w", err)
}
return f, nil
}
}
return nil, fmt.Errorf("format '%s' is unrecognized", format)
}

View File

@ -0,0 +1,90 @@
{
"sample-console": {
"type": "console",
"options": {
"out": "stdout"
},
"format": "plain",
"format_options": {
"delim": " | "
},
"levels": [
{"id": 5, "name": "debug"},
{"id": 4, "name": "info"},
{"id": 3, "name": "warn"},
{"id": 2, "name": "error", "stacktrace": true},
{"id": 1, "name": "fatal", "stacktrace": true},
{"id": 0, "name": "panic", "stacktrace": true}
],
"maxqueuesize": 1000
},
"sample-file": {
"type": "file",
"options": {
"filename": "test.log",
"max_size": 1000000,
"max_age": 1,
"max_backups": 10,
"compress": true
},
"format": "json",
"format_options": {
},
"levels": [
{"id": 5, "name": "debug"},
{"id": 4, "name": "info"},
{"id": 3, "name": "warn"},
{"id": 2, "name": "error", "stacktrace": true},
{"id": 1, "name": "fatal", "stacktrace": true},
{"id": 0, "name": "panic", "stacktrace": true}
],
"maxqueuesize": 1000
},
"sample-tcp": {
"type": "tcp",
"options": {
"host": "localhost",
"port": 18066,
"tls": false,
"cert": "",
"insecure": false
},
"format": "gelf",
"format_options": {
"hostname": "server01"
},
"levels": [
{"id": 5, "name": "debug"},
{"id": 4, "name": "info"},
{"id": 3, "name": "warn"},
{"id": 2, "name": "error", "stacktrace": true},
{"id": 1, "name": "fatal", "stacktrace": true},
{"id": 0, "name": "panic", "stacktrace": true}
],
"maxqueuesize": 1000
},
"sample-syslog": {
"type": "syslog",
"options": {
"host": "localhost",
"port": 18066,
"tls": false,
"cert": "",
"insecure": false,
"tag": "testapp"
},
"format": "plain",
"format_options": {
"delim": " "
},
"levels": [
{"id": 5, "name": "debug"},
{"id": 4, "name": "info"},
{"id": 3, "name": "warn"},
{"id": 2, "name": "error", "stacktrace": true},
{"id": 1, "name": "fatal", "stacktrace": true},
{"id": 0, "name": "panic", "stacktrace": true}
],
"maxqueuesize": 1000
}
}

37
vendor/github.com/mattermost/logr/v2/const.go generated vendored Normal file
View File

@ -0,0 +1,37 @@
package logr
import "time"
// Defaults.
const (
// DefaultMaxQueueSize is the default maximum queue size for Logr instances.
DefaultMaxQueueSize = 1000
// DefaultMaxStackFrames is the default maximum max number of stack frames collected
// when generating stack traces for logging.
DefaultMaxStackFrames = 30
// MaxLevelID is the maximum value of a level ID. Some level cache implementations will
// allocate a cache of this size. Cannot exceed uint.
MaxLevelID = 65535
// DefaultEnqueueTimeout is the default amount of time a log record can take to be queued.
// This only applies to blocking enqueue which happen after `logr.OnQueueFull` is called
// and returns false.
DefaultEnqueueTimeout = time.Second * 30
// DefaultShutdownTimeout is the default amount of time `logr.Shutdown` can execute before
// timing out.
DefaultShutdownTimeout = time.Second * 30
// DefaultFlushTimeout is the default amount of time `logr.Flush` can execute before
// timing out.
DefaultFlushTimeout = time.Second * 30
// DefaultMaxPooledBuffer is the maximum size a pooled buffer can be.
// Buffers that grow beyond this size are garbage collected.
DefaultMaxPooledBuffer = 1024 * 1024
// DefaultMaxFieldLength is the maximum size of a String or fmt.Stringer field can be.
DefaultMaxFieldLength = -1
)

415
vendor/github.com/mattermost/logr/v2/field.go generated vendored Normal file
View File

@ -0,0 +1,415 @@
package logr
import (
"errors"
"fmt"
"io"
"reflect"
"strconv"
"time"
)
var (
Comma = []byte{','}
Equals = []byte{'='}
Space = []byte{' '}
Newline = []byte{'\n'}
Quote = []byte{'"'}
Colon = []byte{':'}
)
// LogCloner is implemented by `Any` types that require a clone to be provided
// to the logger because the original may mutate.
type LogCloner interface {
LogClone() interface{}
}
// LogWriter is implemented by `Any` types that provide custom formatting for
// log output. A string representation of the type should be written directly to
// the `io.Writer`.
type LogWriter interface {
LogWrite(w io.Writer) error
}
type FieldType uint8
const (
UnknownType FieldType = iota
StringType
StringerType
StructType
ErrorType
BoolType
TimestampMillisType
TimeType
DurationType
Int64Type
Int32Type
IntType
Uint64Type
Uint32Type
UintType
Float64Type
Float32Type
BinaryType
ArrayType
MapType
)
type Field struct {
Key string
Type FieldType
Integer int64
Float float64
String string
Interface interface{}
}
func quoteString(w io.Writer, s string, shouldQuote func(s string) bool) error {
b := shouldQuote(s)
if b {
if _, err := w.Write(Quote); err != nil {
return err
}
}
if _, err := w.Write([]byte(s)); err != nil {
return err
}
if b {
if _, err := w.Write(Quote); err != nil {
return err
}
}
return nil
}
// ValueString converts a known type to a string using default formatting.
// This is called lazily by a formatter.
// Formatters can provide custom formatting or types passed via `Any` can implement
// the `LogString` interface to generate output for logging.
// If the optional shouldQuote callback is provided, then it will be called for any
// string output that could potentially need to be quoted.
func (f Field) ValueString(w io.Writer, shouldQuote func(s string) bool) error {
if shouldQuote == nil {
shouldQuote = func(s string) bool { return false }
}
var err error
switch f.Type {
case StringType:
err = quoteString(w, f.String, shouldQuote)
case StringerType:
s, ok := f.Interface.(fmt.Stringer)
if ok {
err = quoteString(w, s.String(), shouldQuote)
} else if f.Interface == nil {
err = quoteString(w, "", shouldQuote)
} else {
err = fmt.Errorf("invalid fmt.Stringer for key %s", f.Key)
}
case StructType:
s, ok := f.Interface.(LogWriter)
if ok {
err = s.LogWrite(w)
break
}
// structs that do not implement LogWriter fall back to reflection via Printf.
// TODO: create custom reflection-based encoder.
_, err = fmt.Fprintf(w, "%v", f.Interface)
case ErrorType:
// TODO: create custom error encoder.
err = quoteString(w, fmt.Sprintf("%v", f.Interface), shouldQuote)
case BoolType:
var b bool
if f.Integer != 0 {
b = true
}
_, err = io.WriteString(w, strconv.FormatBool(b))
case TimestampMillisType:
ts := time.Unix(f.Integer/1000, (f.Integer%1000)*int64(time.Millisecond))
err = quoteString(w, ts.UTC().Format(TimestampMillisFormat), shouldQuote)
case TimeType:
t, ok := f.Interface.(time.Time)
if !ok {
err = errors.New("invalid time")
break
}
err = quoteString(w, t.Format(DefTimestampFormat), shouldQuote)
case DurationType:
_, err = fmt.Fprintf(w, "%s", time.Duration(f.Integer))
case Int64Type, Int32Type, IntType:
_, err = io.WriteString(w, strconv.FormatInt(f.Integer, 10))
case Uint64Type, Uint32Type, UintType:
_, err = io.WriteString(w, strconv.FormatUint(uint64(f.Integer), 10))
case Float64Type, Float32Type:
size := 64
if f.Type == Float32Type {
size = 32
}
err = quoteString(w, strconv.FormatFloat(f.Float, 'f', -1, size), shouldQuote)
case BinaryType:
b, ok := f.Interface.([]byte)
if ok {
_, err = fmt.Fprintf(w, "[%X]", b)
break
}
_, err = fmt.Fprintf(w, "[%v]", f.Interface)
case ArrayType:
a := reflect.ValueOf(f.Interface)
arr:
for i := 0; i < a.Len(); i++ {
item := a.Index(i)
switch v := item.Interface().(type) {
case LogWriter:
if err = v.LogWrite(w); err != nil {
break arr
}
case fmt.Stringer:
if err = quoteString(w, v.String(), shouldQuote); err != nil {
break arr
}
default:
s := fmt.Sprintf("%v", v)
if err = quoteString(w, s, shouldQuote); err != nil {
break arr
}
}
if i != a.Len()-1 {
if _, err = w.Write(Comma); err != nil {
break arr
}
}
}
case MapType:
a := reflect.ValueOf(f.Interface)
iter := a.MapRange()
// Already advance to first element
if !iter.Next() {
return nil
}
it:
for {
if _, err = io.WriteString(w, iter.Key().String()); err != nil {
break it
}
if _, err = w.Write(Equals); err != nil {
break it
}
val := iter.Value().Interface()
switch v := val.(type) {
case LogWriter:
if err = v.LogWrite(w); err != nil {
break it
}
case fmt.Stringer:
if err = quoteString(w, v.String(), shouldQuote); err != nil {
break it
}
default:
s := fmt.Sprintf("%v", v)
if err = quoteString(w, s, shouldQuote); err != nil {
break it
}
}
if !iter.Next() {
break it
}
if _, err = w.Write(Comma); err != nil {
break it
}
}
case UnknownType:
_, err = fmt.Fprintf(w, "%v", f.Interface)
default:
err = fmt.Errorf("invalid type %d", f.Type)
}
return err
}
func nilField(key string) Field {
return String(key, "")
}
func fieldForAny(key string, val interface{}) Field {
switch v := val.(type) {
case LogCloner:
if v == nil {
return nilField(key)
}
c := v.LogClone()
return Field{Key: key, Type: StructType, Interface: c}
case *LogCloner:
if v == nil {
return nilField(key)
}
c := (*v).LogClone()
return Field{Key: key, Type: StructType, Interface: c}
case LogWriter:
if v == nil {
return nilField(key)
}
return Field{Key: key, Type: StructType, Interface: v}
case *LogWriter:
if v == nil {
return nilField(key)
}
return Field{Key: key, Type: StructType, Interface: *v}
case bool:
return Bool(key, v)
case *bool:
if v == nil {
return nilField(key)
}
return Bool(key, *v)
case float64:
return Float(key, v)
case *float64:
if v == nil {
return nilField(key)
}
return Float(key, *v)
case float32:
return Float(key, v)
case *float32:
if v == nil {
return nilField(key)
}
return Float(key, *v)
case int:
return Int(key, v)
case *int:
if v == nil {
return nilField(key)
}
return Int(key, *v)
case int64:
return Int(key, v)
case *int64:
if v == nil {
return nilField(key)
}
return Int(key, *v)
case int32:
return Int(key, v)
case *int32:
if v == nil {
return nilField(key)
}
return Int(key, *v)
case int16:
return Int(key, int32(v))
case *int16:
if v == nil {
return nilField(key)
}
return Int(key, int32(*v))
case int8:
return Int(key, int32(v))
case *int8:
if v == nil {
return nilField(key)
}
return Int(key, int32(*v))
case string:
return String(key, v)
case *string:
if v == nil {
return nilField(key)
}
return String(key, *v)
case uint:
return Uint(key, v)
case *uint:
if v == nil {
return nilField(key)
}
return Uint(key, *v)
case uint64:
return Uint(key, v)
case *uint64:
if v == nil {
return nilField(key)
}
return Uint(key, *v)
case uint32:
return Uint(key, v)
case *uint32:
if v == nil {
return nilField(key)
}
return Uint(key, *v)
case uint16:
return Uint(key, uint32(v))
case *uint16:
if v == nil {
return nilField(key)
}
return Uint(key, uint32(*v))
case uint8:
return Uint(key, uint32(v))
case *uint8:
if v == nil {
return nilField(key)
}
return Uint(key, uint32(*v))
case []byte:
if v == nil {
return nilField(key)
}
return Field{Key: key, Type: BinaryType, Interface: v}
case time.Time:
return Time(key, v)
case *time.Time:
if v == nil {
return nilField(key)
}
return Time(key, *v)
case time.Duration:
return Duration(key, v)
case *time.Duration:
if v == nil {
return nilField(key)
}
return Duration(key, *v)
case error:
return NamedErr(key, v)
case fmt.Stringer:
if v == nil {
return nilField(key)
}
return Field{Key: key, Type: StringerType, Interface: v}
case *fmt.Stringer:
if v == nil {
return nilField(key)
}
return Field{Key: key, Type: StringerType, Interface: *v}
default:
return Field{Key: key, Type: UnknownType, Interface: val}
}
}
// FieldSorter provides sorting of an array of fields by key.
type FieldSorter []Field
func (fs FieldSorter) Len() int { return len(fs) }
func (fs FieldSorter) Less(i, j int) bool { return fs[i].Key < fs[j].Key }
func (fs FieldSorter) Swap(i, j int) { fs[i], fs[j] = fs[j], fs[i] }

127
vendor/github.com/mattermost/logr/v2/fieldapi.go generated vendored Normal file
View File

@ -0,0 +1,127 @@
package logr
import (
"fmt"
"time"
)
// Any picks the best supported field type based on type of val.
// For best performance when passing a struct (or struct pointer),
// implement `logr.LogWriter` on the struct, otherwise reflection
// will be used to generate a string representation.
func Any(key string, val any) Field {
return fieldForAny(key, val)
}
// Int64 constructs a field containing a key and Int64 value.
//
// Deprecated: Use [logr.Int] instead.
func Int64(key string, val int64) Field {
return Field{Key: key, Type: Int64Type, Integer: val}
}
// Int32 constructs a field containing a key and Int32 value.
//
// Deprecated: Use [logr.Int] instead.
func Int32(key string, val int32) Field {
return Field{Key: key, Type: Int32Type, Integer: int64(val)}
}
// Int constructs a field containing a key and int value.
func Int[T ~int | ~int8 | ~int16 | ~int32 | ~int64](key string, val T) Field {
return Field{Key: key, Type: IntType, Integer: int64(val)}
}
// Uint64 constructs a field containing a key and Uint64 value.
//
// Deprecated: Use [logr.Uint] instead.
func Uint64(key string, val uint64) Field {
return Field{Key: key, Type: Uint64Type, Integer: int64(val)}
}
// Uint32 constructs a field containing a key and Uint32 value.
//
// Deprecated: Use [logr.Uint] instead
func Uint32(key string, val uint32) Field {
return Field{Key: key, Type: Uint32Type, Integer: int64(val)}
}
// Uint constructs a field containing a key and uint value.
func Uint[T ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr](key string, val T) Field {
return Field{Key: key, Type: UintType, Integer: int64(val)}
}
// Float64 constructs a field containing a key and Float64 value.
//
// Deprecated: Use [logr.Float] instead
func Float64(key string, val float64) Field {
return Field{Key: key, Type: Float64Type, Float: val}
}
// Float32 constructs a field containing a key and Float32 value.
//
// Deprecated: Use [logr.Float] instead
func Float32(key string, val float32) Field {
return Field{Key: key, Type: Float32Type, Float: float64(val)}
}
// Float32 constructs a field containing a key and float value.
func Float[T ~float32 | ~float64](key string, val T) Field {
return Field{Key: key, Type: Float32Type, Float: float64(val)}
}
// String constructs a field containing a key and String value.
func String[T ~string | ~[]byte](key string, val T) Field {
return Field{Key: key, Type: StringType, String: string(val)}
}
// Stringer constructs a field containing a key and a `fmt.Stringer` value.
// The `String` method will be called in lazy fashion.
func Stringer(key string, val fmt.Stringer) Field {
return Field{Key: key, Type: StringerType, Interface: val}
}
// Err constructs a field containing a default key ("error") and error value.
func Err(err error) Field {
return NamedErr("error", err)
}
// NamedErr constructs a field containing a key and error value.
func NamedErr(key string, err error) Field {
return Field{Key: key, Type: ErrorType, Interface: err}
}
// Bool constructs a field containing a key and bool value.
func Bool[T ~bool](key string, val T) Field {
var b int64
if val {
b = 1
}
return Field{Key: key, Type: BoolType, Integer: b}
}
// Time constructs a field containing a key and time.Time value.
func Time(key string, val time.Time) Field {
return Field{Key: key, Type: TimeType, Interface: val}
}
// Duration constructs a field containing a key and time.Duration value.
func Duration(key string, val time.Duration) Field {
return Field{Key: key, Type: DurationType, Integer: int64(val)}
}
// Millis constructs a field containing a key and timestamp value.
// The timestamp is expected to be milliseconds since Jan 1, 1970 UTC.
func Millis(key string, val int64) Field {
return Field{Key: key, Type: TimestampMillisType, Integer: val}
}
// Array constructs a field containing a key and array value.
func Array[S ~[]E, E any](key string, val S) Field {
return Field{Key: key, Type: ArrayType, Interface: val}
}
// Map constructs a field containing a key and map value.
func Map[M ~map[K]V, K comparable, V any](key string, val M) Field {
return Field{Key: key, Type: MapType, Interface: val}
}

10
vendor/github.com/mattermost/logr/v2/filter.go generated vendored Normal file
View File

@ -0,0 +1,10 @@
package logr
// Filter allows targets to determine which Level(s) are active
// for logging and which Level(s) require a stack trace to be output.
// A default implementation using "panic, fatal..." is provided, and
// a more flexible alternative implementation is also provided that
// allows any number of custom levels.
type Filter interface {
GetEnabledLevel(level Level) (Level, bool)
}

47
vendor/github.com/mattermost/logr/v2/filtercustom.go generated vendored Normal file
View File

@ -0,0 +1,47 @@
package logr
import (
"sync"
)
// CustomFilter allows targets to enable logging via a list of discrete levels.
type CustomFilter struct {
mux sync.RWMutex
levels map[LevelID]Level
}
// NewCustomFilter creates a filter supporting discrete log levels.
func NewCustomFilter(levels ...Level) *CustomFilter {
filter := &CustomFilter{}
filter.Add(levels...)
return filter
}
// GetEnabledLevel returns the Level with the specified Level.ID and whether the level
// is enabled for this filter.
func (cf *CustomFilter) GetEnabledLevel(level Level) (Level, bool) {
cf.mux.RLock()
defer cf.mux.RUnlock()
levelEnabled, ok := cf.levels[level.ID]
if ok && levelEnabled.Name == "" {
levelEnabled.Name = level.Name
}
return levelEnabled, ok
}
// Add adds one or more levels to the list. Adding a level enables logging for
// that level on any targets using this CustomFilter.
func (cf *CustomFilter) Add(levels ...Level) {
cf.mux.Lock()
defer cf.mux.Unlock()
if cf.levels == nil {
cf.levels = make(map[LevelID]Level)
}
for _, s := range levels {
cf.levels[s.ID] = s
}
}

71
vendor/github.com/mattermost/logr/v2/filterstd.go generated vendored Normal file
View File

@ -0,0 +1,71 @@
package logr
// StdFilter allows targets to filter via classic log levels where any level
// beyond a certain verbosity/severity is enabled.
type StdFilter struct {
Lvl Level
Stacktrace Level
}
// GetEnabledLevel returns the Level with the specified Level.ID and whether the level
// is enabled for this filter.
func (lt StdFilter) GetEnabledLevel(level Level) (Level, bool) {
enabled := level.ID <= lt.Lvl.ID
stackTrace := level.ID <= lt.Stacktrace.ID
var levelEnabled Level
if enabled {
switch level.ID {
case Panic.ID:
levelEnabled = Panic
case Fatal.ID:
levelEnabled = Fatal
case Error.ID:
levelEnabled = Error
case Warn.ID:
levelEnabled = Warn
case Info.ID:
levelEnabled = Info
case Debug.ID:
levelEnabled = Debug
case Trace.ID:
levelEnabled = Trace
default:
levelEnabled = level
}
}
if stackTrace {
levelEnabled.Stacktrace = true
}
return levelEnabled, enabled
}
// IsEnabled returns true if the specified Level is at or above this verbosity. Also
// determines if a stack trace is required.
func (lt StdFilter) IsEnabled(level Level) bool {
return level.ID <= lt.Lvl.ID
}
// IsStacktraceEnabled returns true if the specified Level requires a stack trace.
func (lt StdFilter) IsStacktraceEnabled(level Level) bool {
return level.ID <= lt.Stacktrace.ID
}
var (
// Panic is the highest level of severity.
Panic = Level{ID: 0, Name: "panic", Color: Red}
// Fatal designates a catastrophic error.
Fatal = Level{ID: 1, Name: "fatal", Color: Red}
// Error designates a serious but possibly recoverable error.
Error = Level{ID: 2, Name: "error", Color: Red}
// Warn designates non-critical error.
Warn = Level{ID: 3, Name: "warn", Color: Yellow}
// Info designates information regarding application events.
Info = Level{ID: 4, Name: "info", Color: Cyan}
// Debug designates verbose information typically used for debugging.
Debug = Level{ID: 5, Name: "debug", Color: NoColor}
// Trace designates the highest verbosity of log output.
Trace = Level{ID: 6, Name: "trace", Color: NoColor}
)

210
vendor/github.com/mattermost/logr/v2/formatter.go generated vendored Normal file
View File

@ -0,0 +1,210 @@
package logr
import (
"bytes"
"fmt"
"io"
"runtime"
"strconv"
)
// Formatter turns a LogRec into a formatted string.
type Formatter interface {
// IsStacktraceNeeded returns true if this formatter requires a stacktrace to be
// generated for each LogRecord. Enabling features such as `Caller` field require
// a stacktrace.
IsStacktraceNeeded() bool
// Format converts a log record to bytes. If buf is not nil then it will be
// be filled with the formatted results, otherwise a new buffer will be allocated.
Format(rec *LogRec, level Level, buf *bytes.Buffer) (*bytes.Buffer, error)
}
const (
// DefTimestampFormat is the default time stamp format used by Plain formatter and others.
DefTimestampFormat = "2006-01-02 15:04:05.000 Z07:00"
// TimestampMillisFormat is the format for logging milliseconds UTC
TimestampMillisFormat = "Jan _2 15:04:05.000"
)
// LimitByteSlice discards the bytes from a slice that exceeds the limit
func LimitByteSlice(b []byte, limit int) []byte {
if limit > 0 && limit < len(b) {
lb := make([]byte, limit, limit+3)
copy(lb, b[:limit])
return append(lb, []byte("...")...)
}
return b
}
// LimitString discards the runes from a slice that exceeds the limit
func LimitString(b string, limit int) string {
return string(LimitByteSlice([]byte(b), limit))
}
type LimitedStringer struct {
fmt.Stringer
Limit int
}
func (ls *LimitedStringer) String() string {
return LimitString(ls.Stringer.String(), ls.Limit)
}
type Writer struct {
io.Writer
}
func (w Writer) Writes(elems ...[]byte) (int, error) {
var count int
for _, e := range elems {
if c, err := w.Write(e); err != nil {
return count + c, err
} else {
count += c
}
}
return count, nil
}
// DefaultFormatter is the default formatter, outputting only text with
// no colors and a space delimiter. Use `format.Plain` instead.
type DefaultFormatter struct {
}
// IsStacktraceNeeded always returns false for default formatter since the
// `Caller` field is not supported.
func (p *DefaultFormatter) IsStacktraceNeeded() bool {
return false
}
// Format converts a log record to bytes.
func (p *DefaultFormatter) Format(rec *LogRec, level Level, buf *bytes.Buffer) (*bytes.Buffer, error) {
if buf == nil {
buf = &bytes.Buffer{}
}
timestampFmt := DefTimestampFormat
buf.WriteString(rec.Time().Format(timestampFmt))
buf.Write(Space)
buf.WriteString(level.Name)
buf.Write(Space)
buf.WriteString(rec.Msg())
buf.Write(Space)
fields := rec.Fields()
if len(fields) > 0 {
if err := WriteFields(buf, fields, Space, NoColor); err != nil {
return nil, err
}
}
if level.Stacktrace {
frames := rec.StackFrames()
if len(frames) > 0 {
buf.Write(Newline)
if err := WriteStacktrace(buf, rec.StackFrames()); err != nil {
return nil, err
}
}
}
buf.Write(Newline)
return buf, nil
}
// WriteFields writes zero or more name value pairs to the io.Writer.
// The pairs output in key=value format with optional separator between fields.
func WriteFields(w io.Writer, fields []Field, separator []byte, color Color) error {
ws := Writer{w}
sep := []byte{}
for _, field := range fields {
if err := writeField(ws, field, sep, color); err != nil {
return err
}
sep = separator
}
return nil
}
func writeField(ws Writer, field Field, sep []byte, color Color) error {
if len(sep) != 0 {
if _, err := ws.Write(sep); err != nil {
return err
}
}
if err := WriteWithColor(ws, field.Key, color); err != nil {
return err
}
if _, err := ws.Write(Equals); err != nil {
return err
}
return field.ValueString(ws, shouldQuote)
}
// shouldQuote returns true if val contains any characters that might be unsafe
// when injecting log output into an aggregator, viewer or report.
func shouldQuote(val string) bool {
for _, c := range val {
if !((c >= '0' && c <= '9') ||
(c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
c == '-' || c == '.' || c == '_' || c == '/' || c == '@' || c == '^' || c == '+') {
return true
}
}
return false
}
// WriteStacktrace formats and outputs a stack trace to an io.Writer.
func WriteStacktrace(w io.Writer, frames []runtime.Frame) error {
ws := Writer{w}
for _, frame := range frames {
if frame.Function != "" {
if _, err := ws.Writes(Space, Space, []byte(frame.Function), Newline); err != nil {
return err
}
}
if frame.File != "" {
s := strconv.FormatInt(int64(frame.Line), 10)
if _, err := ws.Writes([]byte{' ', ' ', ' ', ' ', ' ', ' '}, []byte(frame.File), Colon, []byte(s), Newline); err != nil {
return err
}
}
}
return nil
}
// WriteWithColor outputs a string with the specified ANSI color.
func WriteWithColor(w io.Writer, s string, color Color) error {
var err error
writer := func(buf []byte) {
if err != nil {
return
}
_, err = w.Write(buf)
}
if color != NoColor {
writer(AnsiColorPrefix)
writer([]byte(strconv.FormatInt(int64(color), 10)))
writer(AnsiColorSuffix)
}
if err == nil {
_, err = io.WriteString(w, s)
}
if color != NoColor {
writer(AnsiColorPrefix)
writer([]byte(strconv.FormatInt(int64(NoColor), 10)))
writer(AnsiColorSuffix)
}
return err
}

Some files were not shown because too many files have changed in this diff Show More