package saml

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/crewjam/saml"
	"github.com/pkg/errors"
	"github.com/rancher/norman/httperror"
	"github.com/rancher/norman/types"
	apiv3 "github.com/rancher/rancher/pkg/apis/management.cattle.io/v3"
	"github.com/rancher/rancher/pkg/auth/accessor"
	"github.com/rancher/rancher/pkg/auth/api/secrets"
	"github.com/rancher/rancher/pkg/auth/providers/common"
	"github.com/rancher/rancher/pkg/auth/providers/ldap"
	"github.com/rancher/rancher/pkg/auth/tokens"
	client "github.com/rancher/rancher/pkg/client/generated/management/v3"
	publicclient "github.com/rancher/rancher/pkg/client/generated/management/v3public"
	v3 "github.com/rancher/rancher/pkg/generated/norman/management.cattle.io/v3"
	"github.com/rancher/rancher/pkg/types/config"
	"github.com/rancher/rancher/pkg/user"
	wcorev1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1"
	"github.com/sirupsen/logrus"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
)

// The exported *Name constants are the SAML provider names this package
// switches on (config type, redirect URL field, LDAP group search support).
// loginAction is the value stored under the "Rancher_Action" client state key
// when a login is started; testAndEnableAction is not referenced in this part
// of the file.
const (
	PingName            = "ping"
	ADFSName            = "adfs"
	KeyCloakName        = "keycloak"
	OKTAName            = "okta"
	ShibbolethName      = "shibboleth"
	loginAction         = "login"
	testAndEnableAction = "testAndEnable"
)

// Provider is the auth provider implementation shared by all SAML-based
// identity providers (see the *Name constants). One instance exists per
// provider name and is registered in SamlProviders by Configure.
type Provider struct {
	authConfigs     v3.AuthConfigInterface   // client used to read and update the stored auth config
	secrets         wcorev1.SecretController // secret controller handed to the common secret helpers
	samlTokens      v3.SamlTokenInterface    // SAML token client; not used in this part of the file
	userMGR         user.Manager             // user lookups, user attributes, membership and access checks
	tokenMGR        *tokens.Manager          // token manager; not used in this part of the file
	serviceProvider *saml.ServiceProvider    // SAML service provider; its AcsURL/SloURL paths are read here, it is set elsewhere
	name            string                   // provider name, also the key in SamlProviders
	userType        string                   // name + "_user": principal ID scheme for users
	groupType       string                   // name + "_group": principal ID scheme for groups
	clientState     ClientState              // per-request state store (SetPath/SetState); set elsewhere
	ldapProvider    common.AuthProvider      // LDAP provider, configured only for Shibboleth and OKTA
	sloEnabled      bool                     // LogoutAll is rejected unless this is true
	sloForced       bool                     // when true, Logout rejects regular logouts
}

// SamlProviders maps a provider name to its Provider. Entries are added by
// Configure and looked up by Logout, LogoutAll and PerformSamlLogin; no
// synchronization is applied to the map in this file.
var SamlProviders = make(map[string]*Provider)

// Configure builds the SAML Provider for the given provider name, attaches an
// LDAP provider when the name supports LDAP group search (Shibboleth, OKTA),
// registers the result in SamlProviders under name and returns it.
func Configure(mgmtCtx *config.ScaledContext, userMGR user.Manager, tokenMGR *tokens.Manager, name string) common.AuthProvider {
	p := &Provider{
		name:        name,
		userType:    name + "_user",
		groupType:   name + "_group",
		userMGR:     userMGR,
		tokenMGR:    tokenMGR,
		authConfigs: mgmtCtx.Management.AuthConfigs(""),
		secrets:     mgmtCtx.Wrangler.Core.Secret(),
		samlTokens:  mgmtCtx.Management.SamlTokens(""),
	}

	if p.hasLdapGroupSearch() {
		p.ldapProvider = ldap.Configure(mgmtCtx, userMGR, tokenMGR, name)
	}

	// A later Configure call with the same name replaces the earlier entry.
	SamlProviders[name] = p
	return p
}

// GetName returns the provider name this Provider was configured with.
func (s *Provider) GetName() string {
	return s.name
}

// CustomizeSchema installs this provider's formatter and action handler on
// the given schema.
func (s *Provider) CustomizeSchema(schema *types.Schema) {
	schema.Formatter = s.formatter
	schema.ActionHandler = s.actionHandler
}

// TransformToAuthProvider converts the stored auth config map into its public
// provider form via common.TransformToAuthProvider and, for a known SAML
// provider name, adds the login redirect URL under that provider's redirect
// URL field. The returned error is always nil.
func (s *Provider) TransformToAuthProvider(authConfig map[string]any) (map[string]any, error) {
	p := common.TransformToAuthProvider(authConfig)

	// Pick the provider-specific field; unknown names get no redirect URL.
	var redirectField string
	switch s.name {
	case PingName:
		redirectField = publicclient.PingProviderFieldRedirectURL
	case ADFSName:
		redirectField = publicclient.ADFSProviderFieldRedirectURL
	case KeyCloakName:
		redirectField = publicclient.KeyCloakProviderFieldRedirectURL
	case OKTAName:
		redirectField = publicclient.OKTAProviderFieldRedirectURL
	case ShibbolethName:
		redirectField = publicclient.ShibbolethProviderFieldRedirectURL
	default:
		return p, nil
	}

	p[redirectField] = formSamlRedirectURLFromMap(authConfig, s.name)
	return p, nil
}

// AuthenticateUser always fails: it returns an empty principal, no group
// principals, an empty string and an error stating that SAML providers do not
// implement the Authenticate User API. All arguments are ignored.
func (s *Provider) AuthenticateUser(http.ResponseWriter, *http.Request, any) (apiv3.Principal, []apiv3.Principal, string, error) {
	return apiv3.Principal{}, nil, "", fmt.Errorf("SAML providers do not implement Authenticate User API")
}

// Logout rejects a regular logout when the provider named by the token is
// unknown or is configured for forced SLO (i.e. only LogoutAll is allowed).
// It returns nil when a regular logout is permitted. The provider is looked
// up in SamlProviders by the token's auth provider name, not taken from the
// receiver; w and r are not used.
func (s *Provider) Logout(w http.ResponseWriter, r *http.Request, token accessor.TokenAccessor) error {
	providerName := token.GetAuthProvider()

	logrus.Debugf("SAML [logout]: triggered by provider %s", providerName)

	provider, found := SamlProviders[providerName]
	if !found {
		err := fmt.Errorf("SAML [logout]: Rancher provider resource `%v` not configured at all", providerName)
		logrus.Debug(err.Error())
		return err
	}

	if !provider.sloForced {
		return nil
	}

	err := fmt.Errorf("SAML [logout]: Rancher provider resource `%v` configured for forced SLO, rejecting regular logout", providerName)
	logrus.Debug(err.Error())
	return err
}

// LogoutAll starts a SAML single logout (SLO) for the user making the request.
//
// The provider is looked up in SamlProviders by the token's auth provider
// name. The request body is decoded as an AuthConfigLogoutInput, the user's
// name at the identity provider is taken from the first "username" entry of
// the UserAttribute extras for that provider, the final redirect URL and the
// "logout-all" action are recorded in the client state, and a JSON document
// containing the identity provider redirect URL is written to w.
//
// An error is returned when the provider is unknown, not enabled for SLO, not
// initialized (no client state or service provider), when the body cannot be
// parsed, when no username is recorded for the provider, or when any of the
// underlying calls fail.
func (s *Provider) LogoutAll(w http.ResponseWriter, r *http.Request, token accessor.TokenAccessor) error {
	providerName := token.GetAuthProvider()

	logrus.Debugf("SAML [logout-all]: triggered by provider %s", providerName)

	provider, ok := SamlProviders[providerName]
	if !ok || provider == nil {
		logrus.Debugf("SAML [logout-all]: Rancher provider resource `%v` not configured at all", providerName)
		return fmt.Errorf("SAML [logout-all]: Rancher provider resource `%v` not configured at all", providerName)
	}

	if !provider.sloEnabled {
		logrus.Debugf("SAML [logout-all]: Rancher provider resource `%v` not configured for SLO", providerName)
		return fmt.Errorf("SAML [logout-all]: Rancher provider resource `%v` not configured for SLO", providerName)
	}

	// Both are dereferenced below; fail with an error instead of panicking
	// when the provider has not been initialized (mirrors PerformSamlLogin).
	if provider.clientState == nil || provider.serviceProvider == nil {
		logrus.Errorf("SAML [logout-all]: Provider %v clientState or serviceProvider not set", providerName)
		return fmt.Errorf("SAML [logout-all]: Provider %v clientState or serviceProvider not set", providerName)
	}

	authLogout := &apiv3.AuthConfigLogoutInput{}

	if err := json.NewDecoder(r.Body).Decode(authLogout); err != nil {
		return httperror.NewAPIError(httperror.InvalidBodyContent,
			fmt.Sprintf("SAML: Failed to parse body: %v", err))
	}

	userName := provider.userMGR.GetUser(r)
	userAttributes, _, err := provider.userMGR.EnsureAndGetUserAttribute(userName)
	if err != nil {
		return err
	}

	// Indexing a missing provider entry yields a nil map, so this is safe even
	// when no extras were recorded for the provider.
	usernames := userAttributes.ExtraByProvider[providerName]["username"]
	if len(usernames) == 0 {
		return fmt.Errorf("SAML [logout-all]: UserAttribute extras contains no username for provider %q", providerName)
	}
	userAtProvider := usernames[0]
	finalRedirectURL := authLogout.FinalRedirectURL

	provider.clientState.SetPath(provider.serviceProvider.SloURL.Path)
	provider.clientState.SetState(w, r, "Rancher_FinalRedirectURL", finalRedirectURL)
	provider.clientState.SetState(w, r, "Rancher_Action", "logout-all")

	idpRedirectURL, err := provider.HandleSamlLogout(userAtProvider, w, r)
	if err != nil {
		return err
	}

	logrus.Debugf("SAML [logout-all]: Redirecting to the identity provider logout page at %v", idpRedirectURL)

	data := map[string]any{
		"idpRedirectUrl": idpRedirectURL,
		"type":           "authConfigLogoutOutput",
		"baseType":       "authConfigLogoutOutput",
	}

	w.Header().Set("Content-Type", "application/json")
	return json.NewEncoder(w).Encode(data)
}

// PerformSamlLogin starts a SAML login against the provider registered under
// name.
//
// input must be a non-nil *apiv3.SamlLoginInput; it carries the FINAL redirect
// URL together with the public key, request ID and response type, all of
// which are recorded in the provider's client state. On success a JSON
// document containing the identity provider redirect URL is written to w.
//
// An error is returned when input has the wrong type, when the provider is
// unknown or not initialized (nil entry, no client state, no service
// provider), when HandleSamlLogin fails, or when the response cannot be
// encoded. Previously an unknown provider name returned nil without writing
// any response.
func PerformSamlLogin(r *http.Request, w http.ResponseWriter, name string, input any) error {
	login, ok := input.(*apiv3.SamlLoginInput)
	if !ok || login == nil {
		return errors.New("unexpected input type")
	}
	finalRedirectURL := login.FinalRedirectURL

	logrus.Debugf("SAML [PerformSamlLogin]: Id Provider            (%v)", name)

	provider, ok := SamlProviders[name]
	if !ok {
		logrus.Errorf("SAML: Rancher provider resource %v not configured", name)
		return fmt.Errorf("SAML: Rancher provider resource %v not configured", name)
	}
	if provider == nil {
		logrus.Errorf("SAML: Rancher provider resource %v not initialized", name)
		return fmt.Errorf("SAML: Rancher provider resource %v not initialized", name)
	}
	if provider.clientState == nil {
		logrus.Errorf("SAML: Provider %v clientState not set", name)
		return fmt.Errorf("SAML: Provider %v clientState not set", name)
	}
	// serviceProvider is dereferenced right below; guard it like clientState.
	if provider.serviceProvider == nil {
		logrus.Errorf("SAML: Provider %v serviceProvider not set", name)
		return fmt.Errorf("SAML: Provider %v serviceProvider not set", name)
	}

	provider.clientState.SetPath(provider.serviceProvider.AcsURL.Path)
	provider.clientState.SetState(w, r, "Rancher_FinalRedirectURL", finalRedirectURL)
	provider.clientState.SetState(w, r, "Rancher_Action", loginAction)
	provider.clientState.SetState(w, r, "Rancher_PublicKey", login.PublicKey)
	provider.clientState.SetState(w, r, "Rancher_RequestID", login.RequestID)
	provider.clientState.SetState(w, r, "Rancher_ResponseType", login.ResponseType)

	// userID is not needed for login. It's only needed for testAndEnable
	idpRedirectURL, err := provider.HandleSamlLogin(w, r, "")
	if err != nil {
		return err
	}

	logrus.Debugf("SAML [PerformSamlLogin]: Redirecting to the identity provider login page at %v", idpRedirectURL)
	data := map[string]any{
		"idpRedirectUrl": idpRedirectURL,
		"type":           "samlLoginOutput",
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(data); err != nil {
		return fmt.Errorf("SAML: Failed to encode samlLoginOutput: %w", err)
	}

	return nil
}

// getSamlConfig loads the auth config stored under the provider's name and
// decodes it into a SamlConfig.
//
// After decoding, Enabled is set explicitly from the raw "enabled" value when
// that value is a bool. When SpKey is non-empty it is replaced by the value
// returned from common.ReadFromSecret (presumably the key material the stored
// value refers to — confirm against that helper).
//
// Errors from the API lookup are wrapped with %w so callers can inspect the
// underlying error with errors.Is / errors.As.
func (s *Provider) getSamlConfig() (*apiv3.SamlConfig, error) {
	authConfigObj, err := s.authConfigs.ObjectClient().UnstructuredClient().Get(s.name, metav1.GetOptions{})
	if err != nil {
		return nil, fmt.Errorf("SAML: failed to retrieve SamlConfig, error: %w", err)
	}

	u, ok := authConfigObj.(runtime.Unstructured)
	if !ok {
		return nil, fmt.Errorf("SAML: failed to retrieve SamlConfig, cannot read k8s Unstructured data")
	}
	storedSamlConfigMap := u.UnstructuredContent()

	storedSamlConfig := &apiv3.SamlConfig{}
	if err := common.Decode(storedSamlConfigMap, storedSamlConfig); err != nil {
		return nil, fmt.Errorf("unable to decode Saml Config: %w", err)
	}

	// Take "enabled" straight from the raw object when it is a bool.
	if enabled, ok := storedSamlConfigMap["enabled"].(bool); ok {
		storedSamlConfig.Enabled = enabled
	}

	if storedSamlConfig.SpKey != "" {
		// The Ping field name is used for every provider here.
		value, err := common.ReadFromSecret(s.secrets, storedSamlConfig.SpKey,
			strings.ToLower(client.PingConfigFieldSpKey))
		if err != nil {
			return nil, err
		}
		storedSamlConfig.SpKey = value
	}

	return storedSamlConfig, nil
}

// saveSamlConfig persists config as the auth config of this provider.
//
// The currently stored config is loaded first; config takes over its
// ObjectMeta with config's own annotations applied. The SP key is passed
// through common.CreateOrUpdateSecrets and config.SpKey is replaced with the
// returned value. For providers with LDAP group search the SAML and LDAP
// configs are combined before the update; a failure to combine is only logged
// and whatever combineSamlAndLdapConfig returned is saved.
func (s *Provider) saveSamlConfig(config *apiv3.SamlConfig) error {
	storedSamlConfig, err := s.getSamlConfig()
	if err != nil {
		return err
	}

	// Unknown provider names leave the config type empty.
	var configType string
	switch s.name {
	case PingName:
		configType = client.PingConfigType
	case ADFSName:
		configType = client.ADFSConfigType
	case KeyCloakName:
		configType = client.KeyCloakConfigType
	case OKTAName:
		configType = client.OKTAConfigType
	case ShibbolethName:
		configType = client.ShibbolethConfigType
	}

	config.APIVersion = "management.cattle.io/v3"
	config.Kind = v3.AuthConfigGroupVersionKind.Kind
	config.Type = configType
	// Order matters: the annotations are copied onto the stored metadata
	// before that metadata replaces config's.
	storedSamlConfig.Annotations = config.Annotations
	config.ObjectMeta = storedSamlConfig.ObjectMeta

	// This assumes the provider needs to create only one secret. If there are new entries
	// in the secret collection, this code that creates the actual secrets would need to be updated.
	var secretField string
	if secretFields, known := secrets.TypeToFields[configType]; known && len(secretFields) > 0 {
		secretField = strings.ToLower(secretFields[0])
	}

	spKey, err := common.CreateOrUpdateSecrets(s.secrets, config.SpKey,
		secretField, strings.ToLower(config.Type))
	if err != nil {
		return err
	}
	config.SpKey = spKey

	if !s.hasLdapGroupSearch() {
		_, err = s.authConfigs.ObjectClient().Update(config.ObjectMeta.Name, config)
		return err
	}

	combinedConfig, combineErr := s.combineSamlAndLdapConfig(config)
	if combineErr != nil {
		logrus.Warnf("problem combining saml and ldap config, saving partial configuration %s", combineErr.Error())
	}

	if _, err := s.authConfigs.ObjectClient().Update(config.ObjectMeta.Name, combinedConfig); err != nil {
		return fmt.Errorf("unable to update authconfig: %w", err)
	}
	return nil
}

// toPrincipal fills in the type-dependent fields of princ and returns it.
//
// When principalType equals the provider's user type the principal becomes a
// user principal; with a token, Me is computed against the token's user
// principal and, when it matches, LoginName and DisplayName are taken from
// the token. Any other principalType yields a group principal whose MemberOf
// is computed from the token when one is given.
func (s *Provider) toPrincipal(principalType string, princ apiv3.Principal, token accessor.TokenAccessor) apiv3.Principal {
	if principalType != s.userType {
		princ.PrincipalType = common.GroupPrincipalType
		if token != nil {
			princ.MemberOf = s.userMGR.IsMemberOf(token, princ)
		}
		return princ
	}

	princ.PrincipalType = common.UserPrincipalType
	if token == nil {
		return princ
	}

	self := token.GetUserPrincipal()
	if princ.Me = s.isThisUserMe(self, princ); princ.Me {
		princ.LoginName = self.LoginName
		princ.DisplayName = self.DisplayName
	}
	return princ
}

// RefetchGroupPrincipals is not implemented for SAML providers: it ignores
// its arguments and always returns a nil slice and a "Not implemented" error.
func (s *Provider) RefetchGroupPrincipals(principalID string, secret string) ([]apiv3.Principal, error) {
	return nil, errors.New("Not implemented")
}

// SearchPrincipals looks up principals by name.
//
// For providers with LDAP group search the LDAP provider is asked first and
// its answer (result and error) is returned unless the error reports that
// LDAP is not configured. Otherwise synthetic principals named after
// searchKey are returned, because SAML on its own has no user/group lookup
// mechanism: a user principal unless principalType asks for groups only,
// followed by a group principal unless principalType asks for users only.
// An empty principalType therefore yields both.
func (s *Provider) SearchPrincipals(searchKey, principalType string, token accessor.TokenAccessor) ([]apiv3.Principal, error) {
	if s.hasLdapGroupSearch() {
		// only give response from ldap if it's configured
		if found, err := s.ldapProvider.SearchPrincipals(searchKey, principalType, token); !ldap.IsNotConfigured(err) {
			return found, err
		}
	}

	// Each kind is skipped only when the caller asked for the other kind.
	kinds := []struct {
		scheme    string
		kind      string
		skippedBy string
	}{
		{scheme: s.userType, kind: common.UserPrincipalType, skippedBy: common.GroupPrincipalType},
		{scheme: s.groupType, kind: common.GroupPrincipalType, skippedBy: common.UserPrincipalType},
	}

	var principals []apiv3.Principal
	for _, k := range kinds {
		if principalType == k.skippedBy {
			continue
		}
		principals = append(principals, apiv3.Principal{
			ObjectMeta:    metav1.ObjectMeta{Name: k.scheme + "://" + searchKey},
			DisplayName:   searchKey,
			LoginName:     searchKey,
			PrincipalType: k.kind,
			Provider:      s.name,
		})
	}

	return principals, nil
}

// GetPrincipal resolves principalID (of the form "<type>://<externalID>") to a
// principal.
//
// The ID is rejected when it cannot be split or when its type is neither this
// provider's user type nor its group type. For providers with LDAP group
// search the LDAP answer is returned unless LDAP reports that it is not
// configured. Otherwise a principal named after the external ID is built and
// completed by toPrincipal.
func (s *Provider) GetPrincipal(principalID string, token accessor.TokenAccessor) (apiv3.Principal, error) {
	externalID, principalType := splitPrincipalID(principalID)
	switch {
	case externalID == "" && principalType == "":
		return apiv3.Principal{}, fmt.Errorf("SAML: invalid id %v", principalID)
	case principalType != s.userType && principalType != s.groupType:
		return apiv3.Principal{}, fmt.Errorf("SAML: Invalid principal type")
	}

	if s.hasLdapGroupSearch() {
		// only give response from ldap if it's configured
		if found, err := s.ldapProvider.GetPrincipal(principalID, token); !ldap.IsNotConfigured(err) {
			return found, err
		}
	}

	synthetic := apiv3.Principal{
		ObjectMeta:  metav1.ObjectMeta{Name: principalType + "://" + externalID},
		DisplayName: externalID,
		LoginName:   externalID,
		Provider:    s.name,
	}
	return s.toPrincipal(principalType, synthetic, token), nil
}

// isThisUserMe reports whether me and other denote the same principal, i.e.
// they agree on both object name and principal type.
func (s *Provider) isThisUserMe(me, other apiv3.Principal) bool {
	if me.ObjectMeta.Name != other.ObjectMeta.Name {
		return false
	}
	return me.PrincipalType == other.PrincipalType
}

// CanAccessWithGroupProviders reports whether the user, together with the
// given group principals, passes the access check of the stored SAML config
// (access mode and allowed principal IDs). It returns false along with the
// error when the config cannot be loaded or the check itself fails.
func (s *Provider) CanAccessWithGroupProviders(userPrincipalID string, groupPrincipals []apiv3.Principal) (bool, error) {
	samlConfig, err := s.getSamlConfig()
	if err != nil {
		logrus.Errorf("Error fetching saml config: %v", err)
		return false, err
	}

	allowed, err := s.userMGR.CheckAccess(samlConfig.AccessMode, samlConfig.AllowedPrincipalIDs, userPrincipalID, groupPrincipals)
	if err == nil {
		return allowed, nil
	}
	return false, err
}

// formSamlRedirectURLFromMap builds the login URL
// "<host>/v1-saml/<name>/login" for the named provider. The host is read from
// the provider-specific Rancher API host field of config; it is empty when
// the name is unknown or the field is missing or not a string.
func formSamlRedirectURLFromMap(config map[string]any, name string) string {
	const prefix = "/v1-saml/"
	const suffix = "/login"

	var hostField string
	switch name {
	case PingName:
		hostField = client.PingConfigFieldRancherAPIHost
	case ADFSName:
		hostField = client.ADFSConfigFieldRancherAPIHost
	case KeyCloakName:
		hostField = client.KeyCloakConfigFieldRancherAPIHost
	case OKTAName:
		hostField = client.OKTAConfigFieldRancherAPIHost
	case ShibbolethName:
		hostField = client.ShibbolethConfigFieldRancherAPIHost
	default:
		// Unknown provider: no host is looked up at all.
		return prefix + name + suffix
	}

	// A failed assertion leaves host as the empty string.
	host, _ := config[hostField].(string)
	return host + prefix + name + suffix
}

// splitPrincipalID splits a principal ID at its first ":" and returns the
// external ID followed by the principal type. A single leading "//" is
// stripped from the external ID, so "type://id" yields ("id", "type"). Both
// results are empty when principalID contains no ":".
func splitPrincipalID(principalID string) (string, string) {
	principalType, rest, found := strings.Cut(principalID, ":")
	if !found {
		return "", ""
	}
	return strings.TrimPrefix(rest, "//"), principalType
}

// combineSamlAndLdapConfig merges the given SAML config with the LDAP config
// of this provider's LDAP provider into the provider-specific combined
// config object (ShibbolethConfig or OKTAConfig).
//
// When the LDAP config cannot be read, the problem is logged and the plain
// SAML config is returned with a nil error, unless an LDAP config was still
// returned and the error merely reports that LDAP is not configured; in that
// case the merge goes ahead. For Shibboleth the LDAP service account password
// is passed through common.SavePasswordSecret and replaced by the returned
// secret name; a failure there returns the plain SAML config together with
// the error. For any other provider name a nil object is returned.
func (s *Provider) combineSamlAndLdapConfig(config *apiv3.SamlConfig) (runtime.Object, error) {
	// if errors we might not want to turn on ldap
	ldapConfig, _, err := ldap.GetLDAPConfig(s.ldapProvider)
	if err != nil {
		// can be misconfigured but still want it saved
		logrus.Warnf("error pulling %s ldap configs: %s\n", s.name, err)

		// Fall back to the SAML config alone when the config subkey is not in
		// the CRD, or on any error other than "not configured" (a not
		// configured LDAP section might still hold data we want to keep).
		if ldapConfig == nil || !ldap.IsNotConfigured(err) {
			return config, nil
		}
	}

	var samlConfig apiv3.SamlConfig
	config.DeepCopyInto(&samlConfig)

	switch s.name {
	case OKTAName:
		return &apiv3.OKTAConfig{
			SamlConfig:     samlConfig,
			OpenLdapConfig: ldapConfig.LdapFields,
		}, nil
	case ShibbolethName:
		secretName, err := common.SavePasswordSecret(
			s.secrets,
			ldapConfig.LdapFields.ServiceAccountPassword,
			client.LdapConfigFieldServiceAccountPassword,
			samlConfig.Type,
		)
		if err != nil {
			return config, fmt.Errorf("unable to save ldap service account password: %w", err)
		}
		ldapConfig.LdapFields.ServiceAccountPassword = secretName

		// Set the status for SecretsMigrated to True so it doesn't get re-migrated
		apiv3.AuthConfigConditionSecretsMigrated.SetStatus(&samlConfig, "True")
		return &apiv3.ShibbolethConfig{
			SamlConfig:     samlConfig,
			OpenLdapConfig: ldapConfig.LdapFields,
		}, nil
	}

	return nil, nil
}

// hasLdapGroupSearch reports whether this provider is one of those that can
// be paired with LDAP for principal lookups (Shibboleth and OKTA).
func (s *Provider) hasLdapGroupSearch() bool {
	switch s.name {
	case ShibbolethName, OKTAName:
		return true
	}
	return false
}

// GetUserExtraAttributes returns the extra attributes for userPrincipal by
// delegating to common.GetCommonUserExtraAttributes.
func (s *Provider) GetUserExtraAttributes(userPrincipal apiv3.Principal) map[string][]string {
	return common.GetCommonUserExtraAttributes(userPrincipal)
}

// IsDisabledProvider reports whether the stored SAML config of this provider
// has Enabled set to false. When the config cannot be loaded it returns false
// together with the error.
func (s *Provider) IsDisabledProvider() (bool, error) {
	cfg, err := s.getSamlConfig()
	if err != nil {
		return false, err
	}
	disabled := !cfg.Enabled
	return disabled, nil
}
