// Copyright (c) 2024 Mattermost Community Enterprise // Open source implementation of Mattermost Enterprise LDAP diagnostics package ldap import ( "crypto/tls" "fmt" "net/http" "time" "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/mlog" "github.com/mattermost/mattermost/server/public/shared/request" "github.com/mattermost/mattermost/server/v8/einterfaces" ldapv3 "github.com/go-ldap/ldap/v3" ) type LdapDiagnosticImpl struct { config func() *model.Config logger mlog.LoggerIFace } func NewLdapDiagnosticInterface(config func() *model.Config, logger mlog.LoggerIFace) einterfaces.LdapDiagnosticInterface { return &LdapDiagnosticImpl{ config: config, logger: logger, } } func (ld *LdapDiagnosticImpl) getSettings() *model.LdapSettings { return &ld.config().LdapSettings } // RunTest runs a basic LDAP connection test func (ld *LdapDiagnosticImpl) RunTest(rctx request.CTX) *model.AppError { return ld.RunTestConnection(rctx, *ld.getSettings()) } // GetVendorNameAndVendorVersion retrieves LDAP server vendor info func (ld *LdapDiagnosticImpl) GetVendorNameAndVendorVersion(rctx request.CTX) (string, string, error) { settings := ld.getSettings() conn, err := ld.connect(settings) if err != nil { return "", "", err } defer conn.Close() if err := conn.Bind(*settings.BindUsername, *settings.BindPassword); err != nil { return "", "", err } // Query root DSE for vendor info searchRequest := ldapv3.NewSearchRequest( "", ldapv3.ScopeBaseObject, ldapv3.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{"vendorName", "vendorVersion", "supportedLDAPVersion"}, nil, ) sr, err := conn.Search(searchRequest) if err != nil { return "Unknown", "Unknown", nil } if len(sr.Entries) == 0 { return "Unknown", "Unknown", nil } entry := sr.Entries[0] vendorName := entry.GetAttributeValue("vendorName") vendorVersion := entry.GetAttributeValue("vendorVersion") if vendorName == "" { vendorName = "Unknown" } if vendorVersion == "" { vendorVersion = "Unknown" } return vendorName, vendorVersion, nil } // RunTestConnection tests LDAP connectivity with given settings func (ld *LdapDiagnosticImpl) RunTestConnection(rctx request.CTX, settings model.LdapSettings) *model.AppError { conn, err := ld.connect(&settings) if err != nil { return model.NewAppError("LdapDiagnostic.RunTestConnection", "api.ldap.connection_error.app_error", nil, err.Error(), http.StatusInternalServerError) } defer conn.Close() // Test bind if err := conn.Bind(*settings.BindUsername, *settings.BindPassword); err != nil { return model.NewAppError("LdapDiagnostic.RunTestConnection", "api.ldap.bind_error.app_error", nil, err.Error(), http.StatusUnauthorized) } // Test search searchRequest := ldapv3.NewSearchRequest( *settings.BaseDN, ldapv3.ScopeBaseObject, ldapv3.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{"dn"}, nil, ) _, err = conn.Search(searchRequest) if err != nil { return model.NewAppError("LdapDiagnostic.RunTestConnection", "api.ldap.search_error.app_error", nil, err.Error(), http.StatusInternalServerError) } return nil } // RunTestDiagnostics runs detailed diagnostic tests func (ld *LdapDiagnosticImpl) RunTestDiagnostics(rctx request.CTX, testType model.LdapDiagnosticTestType, settings model.LdapSettings) ([]model.LdapDiagnosticResult, *model.AppError) { conn, err := ld.connect(&settings) if err != nil { return nil, model.NewAppError("LdapDiagnostic.RunTestDiagnostics", "api.ldap.connection_error.app_error", nil, err.Error(), http.StatusInternalServerError) } defer conn.Close() if err := conn.Bind(*settings.BindUsername, *settings.BindPassword); err != nil { return nil, model.NewAppError("LdapDiagnostic.RunTestDiagnostics", "api.ldap.bind_error.app_error", nil, err.Error(), http.StatusUnauthorized) } var results []model.LdapDiagnosticResult switch testType { case model.LdapDiagnosticTestTypeFilters: results = ld.testFilters(conn, &settings) case model.LdapDiagnosticTestTypeAttributes: results = ld.testAttributes(conn, &settings) case model.LdapDiagnosticTestTypeGroupAttributes: results = ld.testGroupAttributes(conn, &settings) default: return nil, model.NewAppError("LdapDiagnostic.RunTestDiagnostics", "api.ldap.invalid_test_type.app_error", nil, "", http.StatusBadRequest) } return results, nil } func (ld *LdapDiagnosticImpl) connect(settings *model.LdapSettings) (*ldapv3.Conn, error) { ldapServer := *settings.LdapServer ldapPort := *settings.LdapPort connectionSecurity := *settings.ConnectionSecurity var conn *ldapv3.Conn var err error address := fmt.Sprintf("%s:%d", ldapServer, ldapPort) switch connectionSecurity { case model.ConnSecurityTLS: tlsConfig := &tls.Config{ InsecureSkipVerify: *settings.SkipCertificateVerification, ServerName: ldapServer, } conn, err = ldapv3.DialTLS("tcp", address, tlsConfig) case model.ConnSecurityStarttls: conn, err = ldapv3.Dial("tcp", address) if err != nil { return nil, err } tlsConfig := &tls.Config{ InsecureSkipVerify: *settings.SkipCertificateVerification, ServerName: ldapServer, } err = conn.StartTLS(tlsConfig) default: conn, err = ldapv3.Dial("tcp", address) } if err != nil { return nil, err } if settings.QueryTimeout != nil && *settings.QueryTimeout > 0 { conn.SetTimeout(time.Duration(*settings.QueryTimeout) * time.Second) } return conn, nil } func (ld *LdapDiagnosticImpl) testFilters(conn *ldapv3.Conn, settings *model.LdapSettings) []model.LdapDiagnosticResult { var results []model.LdapDiagnosticResult // Test user filter userFilter := *settings.UserFilter if userFilter == "" { userFilter = "(objectClass=person)" } userResult := model.LdapDiagnosticResult{ TestName: "User Filter", TestValue: userFilter, } sr, err := conn.Search(ldapv3.NewSearchRequest( *settings.BaseDN, ldapv3.ScopeWholeSubtree, ldapv3.NeverDerefAliases, 100, 0, false, userFilter, []string{"dn"}, nil, )) if err != nil { userResult.Error = err.Error() } else { userResult.TotalCount = len(sr.Entries) userResult.Message = fmt.Sprintf("Found %d users", len(sr.Entries)) // Sample results maxSamples := 5 if len(sr.Entries) < maxSamples { maxSamples = len(sr.Entries) } for i := 0; i < maxSamples; i++ { userResult.SampleResults = append(userResult.SampleResults, model.LdapSampleEntry{ DN: sr.Entries[i].DN, }) } } results = append(results, userResult) // Test group filter if *settings.GroupFilter != "" { groupResult := model.LdapDiagnosticResult{ TestName: "Group Filter", TestValue: *settings.GroupFilter, } sr, err := conn.Search(ldapv3.NewSearchRequest( *settings.BaseDN, ldapv3.ScopeWholeSubtree, ldapv3.NeverDerefAliases, 100, 0, false, *settings.GroupFilter, []string{"dn"}, nil, )) if err != nil { groupResult.Error = err.Error() } else { groupResult.TotalCount = len(sr.Entries) groupResult.Message = fmt.Sprintf("Found %d groups", len(sr.Entries)) } results = append(results, groupResult) } return results } func (ld *LdapDiagnosticImpl) testAttributes(conn *ldapv3.Conn, settings *model.LdapSettings) []model.LdapDiagnosticResult { var results []model.LdapDiagnosticResult userFilter := *settings.UserFilter if userFilter == "" { userFilter = "(objectClass=person)" } // Get sample users sr, err := conn.Search(ldapv3.NewSearchRequest( *settings.BaseDN, ldapv3.ScopeWholeSubtree, ldapv3.NeverDerefAliases, 100, 0, false, userFilter, []string{"*"}, nil, )) if err != nil { return []model.LdapDiagnosticResult{{ TestName: "Attributes", Error: err.Error(), }} } totalUsers := len(sr.Entries) // Test each configured attribute attrs := map[string]string{ "ID Attribute": *settings.IdAttribute, "Username Attribute": *settings.UsernameAttribute, "Email Attribute": *settings.EmailAttribute, "First Name Attr": *settings.FirstNameAttribute, "Last Name Attr": *settings.LastNameAttribute, "Nickname Attribute": *settings.NicknameAttribute, "Position Attribute": *settings.PositionAttribute, } for name, attr := range attrs { if attr == "" { continue } result := model.LdapDiagnosticResult{ TestName: name, TestValue: attr, TotalCount: totalUsers, } count := 0 for _, entry := range sr.Entries { if entry.GetAttributeValue(attr) != "" { count++ } } result.EntriesWithValue = count result.Message = fmt.Sprintf("%d/%d users have this attribute", count, totalUsers) results = append(results, result) } return results } func (ld *LdapDiagnosticImpl) testGroupAttributes(conn *ldapv3.Conn, settings *model.LdapSettings) []model.LdapDiagnosticResult { var results []model.LdapDiagnosticResult groupFilter := *settings.GroupFilter if groupFilter == "" { groupFilter = "(objectClass=group)" } sr, err := conn.Search(ldapv3.NewSearchRequest( *settings.BaseDN, ldapv3.ScopeWholeSubtree, ldapv3.NeverDerefAliases, 100, 0, false, groupFilter, []string{"*"}, nil, )) if err != nil { return []model.LdapDiagnosticResult{{ TestName: "Group Attributes", Error: err.Error(), }} } totalGroups := len(sr.Entries) // Test group attributes attrs := map[string]string{ "Group ID Attribute": *settings.GroupIdAttribute, "Group Display Name Attr": *settings.GroupDisplayNameAttribute, } for name, attr := range attrs { if attr == "" { continue } result := model.LdapDiagnosticResult{ TestName: name, TestValue: attr, TotalCount: totalGroups, } count := 0 for _, entry := range sr.Entries { if entry.GetAttributeValue(attr) != "" { count++ } } result.EntriesWithValue = count result.Message = fmt.Sprintf("%d/%d groups have this attribute", count, totalGroups) results = append(results, result) } return results }