package tigase.auth.impl;

import java.io.IOException;
import java.security.MessageDigest;
import java.security.cert.X509Certificate;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.SaslException;
import tigase.auth.AuthRepositoryAware;
import tigase.auth.CallbackHandlerFactory;
import tigase.auth.DomainAware;
import tigase.auth.MechanismNameAware;
import tigase.auth.SessionAware;
import tigase.auth.XmppSaslException;
import tigase.auth.callbacks.AuthorizationIdCallback;
import tigase.auth.callbacks.ChannelBindingCallback;
import tigase.auth.callbacks.PBKDIterationsCallback;
import tigase.auth.callbacks.SaltCallback;
import tigase.auth.callbacks.ServerKeyCallback;
import tigase.auth.callbacks.StoredKeyCallback;
import tigase.auth.callbacks.XMPPSessionCallback;
import tigase.auth.credentials.Credentials;
import tigase.auth.credentials.entries.PlainCredentialsEntry;
import tigase.auth.credentials.entries.ScramCredentialsEntry;
import tigase.auth.mechanisms.AbstractSasl;
import tigase.auth.mechanisms.AbstractSaslSCRAM;
import tigase.db.AuthRepository;
import tigase.util.Base64;
import tigase.xmpp.XMPPResourceConnection;
import tigase.xmpp.jid.BareJID;

/* loaded from: input_file:tigase/auth/impl/ScramCallbackHandler.class */
/**
 * {@link CallbackHandler} serving the SCRAM family of SASL mechanisms.
 *
 * <p>It resolves the authenticating JID from the {@link NameCallback}, lazily loads that user's
 * credentials from the configured {@link AuthRepository} (a completed lookup is cached), and
 * answers the salt, iteration-count, stored-key, server-key and channel-binding callbacks issued by
 * the mechanism. When no SCRAM entry exists for the negotiated mechanism, a PLAIN credential entry
 * is converted into a {@link ScramCredentialsEntry}.
 *
 * <p>Instances hold mutable authentication state (JID, credential id, cached credentials) and are
 * not synchronized.
 */
public class ScramCallbackHandler implements CallbackHandler, AuthRepositoryAware, SessionAware, DomainAware, MechanismNameAware {
    private static final Logger log = Logger.getLogger(ScramCallbackHandler.class.getCanonicalName());

    /** Suffix of channel-binding SCRAM variants (e.g. {@code SCRAM-SHA-1-PLUS}). */
    private static final String PLUS_SUFFIX = "-PLUS";
    /** Credential id used when the user authenticates as themselves. */
    private static final String DEFAULT_CREDENTIAL_ID = "default";

    /** SCRAM credentials of the current user; {@code null} until fetched or if none usable exist. */
    private ScramCredentialsEntry credentialsEntry;
    /** Whether a credential lookup has completed (successfully or with a logged failure). */
    private boolean credentialsFetched;
    /** Local domain; names without a local part or from another domain are resolved against it. */
    private String domain;
    /** Negotiated SASL mechanism name, possibly carrying the {@code -PLUS} suffix. */
    private String mechanismName;
    private AuthRepository repo;
    private XMPPResourceConnection session;
    /** Credential id passed to {@link AuthRepository#getCredentials}. */
    private String credentialId = null;
    /** JID whose credentials are looked up. */
    private BareJID jid = null;
    /** Set when the account is unknown, disabled, or its credentials could not be loaded. */
    private boolean loggingInForbidden = false;

    /**
     * Dispatches every callback in order to {@link #handleCallback(Callback)}.
     *
     * @throws UnsupportedCallbackException if a callback type is not supported
     * @throws IOException (including {@link SaslException}) if a callback cannot be satisfied
     */
    @Override
    public void handle(Callback[] callbackArr) throws IOException, UnsupportedCallbackException {
        for (Callback callback : callbackArr) {
            if (log.isLoggable(Level.FINEST)) {
                log.log(Level.FINEST, "Callback: {0}", callback.getClass().getSimpleName());
            }
            handleCallback(callback);
        }
    }

    @Override
    public void setAuthRepository(AuthRepository authRepository) {
        this.repo = authRepository;
    }

    @Override
    public void setDomain(String str) {
        this.domain = str;
    }

    @Override
    public void setMechanismName(String str) {
        this.mechanismName = str;
    }

    @Override
    public void setSession(XMPPResourceConnection xMPPResourceConnection) {
        this.session = xMPPResourceConnection;
    }

    /**
     * Final authorization decision. Rejects when login is forbidden (unknown or disabled account,
     * or credential lookup failure); otherwise authorizes and removes the
     * {@link CallbackHandlerFactory#AUTH_JID} entry from the session data.
     *
     * @throws SaslException propagated from the credential lookup (e.g. disabled account)
     */
    protected void handleAuthorizeCallback(AuthorizeCallback authorizeCallback) throws SaslException {
        String authenticationID = authorizeCallback.getAuthenticationID();
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "AuthorizeCallback: authenId: {0}", authenticationID);
        }
        fetchCredentials();
        if (this.loggingInForbidden) {
            authorizeCallback.setAuthorized(false);
            if (log.isLoggable(Level.FINEST)) {
                log.log(Level.FINEST, "User {0} is disabled", this.jid);
            }
            return;
        }
        String authorizationID = authorizeCallback.getAuthorizationID();
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "AuthorizeCallback: authorId: {0}", authorizationID);
        }
        authorizeCallback.setAuthorized(true);
        this.session.removeSessionData(CallbackHandlerFactory.AUTH_JID);
    }

    /**
     * Routes a single callback to its type-specific handler.
     *
     * @throws UnsupportedCallbackException for callback types this handler does not know
     */
    protected void handleCallback(Callback callback) throws UnsupportedCallbackException, IOException {
        if (callback instanceof XMPPSessionCallback) {
            ((XMPPSessionCallback) callback).setSession(this.session);
        } else if (callback instanceof ChannelBindingCallback) {
            handleChannelBindingCallback((ChannelBindingCallback) callback);
        } else if (callback instanceof PBKDIterationsCallback) {
            handlePBKDIterationsCallback((PBKDIterationsCallback) callback);
        } else if (callback instanceof NameCallback) {
            handleNameCallback((NameCallback) callback);
        } else if (callback instanceof AuthorizationIdCallback) {
            handleAuthorizationIdCallback((AuthorizationIdCallback) callback);
        } else if (callback instanceof SaltCallback) {
            handleSaltCallback((SaltCallback) callback);
        } else if (callback instanceof ServerKeyCallback) {
            handleServerKeyCallback((ServerKeyCallback) callback);
        } else if (callback instanceof StoredKeyCallback) {
            handleStoredKeyCallback((StoredKeyCallback) callback);
        } else if (callback instanceof AuthorizeCallback) {
            handleAuthorizeCallback((AuthorizeCallback) callback);
        } else {
            throw new UnsupportedCallbackException(callback, "Unrecognized Callback " + String.valueOf(callback));
        }
    }

    /**
     * Resolves the authenticating JID from the callback's default name. Names without a local part,
     * or whose domain differs from this handler's domain, are re-parsed relative to that domain.
     * Resets the credential id to the default and publishes the JID in the session.
     */
    protected void handleNameCallback(NameCallback nameCallback) throws IOException {
        this.credentialId = DEFAULT_CREDENTIAL_ID;
        BareJID user = BareJID.bareJIDInstanceNS(nameCallback.getDefaultName());
        if (user.getLocalpart() == null || !this.domain.equalsIgnoreCase(user.getDomain())) {
            user = BareJID.bareJIDInstanceNS(nameCallback.getDefaultName(), this.domain);
        }
        setJid(user);
        nameCallback.setName(user.toString());
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "NameCallback: {0}", user);
        }
    }

    /** Supplies the PBKDF2 iteration count; leaves the callback untouched if no credentials exist. */
    protected void handlePBKDIterationsCallback(PBKDIterationsCallback pBKDIterationsCallback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "PBKDIterationsCallback: {0}", this.jid);
        }
        fetchCredentials();
        if (this.credentialsEntry != null) {
            pBKDIterationsCallback.setInterations(this.credentialsEntry.getIterations());
        }
    }

    /** Supplies the user's salt, or {@code null} if no usable credentials exist. */
    protected void handleSaltCallback(SaltCallback saltCallback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "SaltCallback: {0}", this.jid);
        }
        fetchCredentials();
        saltCallback.setSalt(this.credentialsEntry != null ? this.credentialsEntry.getSalt() : null);
    }

    /**
     * Loads and caches credentials for {@link #jid}/{@link #credentialId}.
     *
     * <p>Preference: the entry for the negotiated mechanism (without {@code -PLUS}), otherwise
     * the PLAIN entry converted to SCRAM. Unknown accounts and lookup failures mark login as
     * forbidden (fail closed).
     *
     * @throws SaslException when the account exists but is not allowed to log in
     */
    private void fetchCredentials() throws SaslException {
        if (this.credentialsFetched) {
            return;
        }
        try {
            Credentials credentials = this.repo.getCredentials(this.jid, this.credentialId);
            if (log.isLoggable(Level.FINE)) {
                log.log(Level.FINE, "Fetched credentials for: {0} with credentialsId: {1}, credentials: {2}",
                        new Object[]{this.jid, this.credentialId, credentials});
            }
            if (credentials == null) {
                this.loggingInForbidden = true;
            } else {
                String baseMechanism = stripPlusSuffix(this.mechanismName);
                Credentials.Entry entry = credentials.getEntryForMechanism(baseMechanism);
                if (entry == null) {
                    entry = credentials.getEntryForMechanism("PLAIN");
                }
                if (entry instanceof ScramCredentialsEntry) {
                    this.credentialsEntry = (ScramCredentialsEntry) entry;
                } else if (entry instanceof PlainCredentialsEntry) {
                    // "SCRAM-SHA-256" -> "SHA-256": hash algorithm for deriving SCRAM keys.
                    this.credentialsEntry = new ScramCredentialsEntry(baseMechanism.replace("SCRAM-", ""), (PlainCredentialsEntry) entry);
                }
                this.loggingInForbidden = !credentials.canLogin();
                if (this.loggingInForbidden) {
                    throw XmppSaslException.getExceptionFor(credentials.getAccountStatus());
                }
            }
        } catch (SaslException e) {
            // Must precede the generic clause: SASL failures are reported to the client as-is.
            throw e;
        } catch (Exception e) {
            // Repository failure: never allow authorization without verified credentials.
            this.loggingInForbidden = true;
            if (log.isLoggable(Level.FINE)) {
                log.log(Level.FINE, "Could not retrieve credentials for user " + String.valueOf(this.jid) + " with credentialId " + this.credentialId, (Throwable) e);
            }
        }
        this.credentialsFetched = true;
    }

    /** Removes the {@code -PLUS} channel-binding suffix from a mechanism name, if present. */
    private static String stripPlusSuffix(String mechanism) {
        return mechanism.endsWith(PLUS_SUFFIX)
                ? mechanism.substring(0, mechanism.length() - PLUS_SUFFIX.length())
                : mechanism;
    }

    /**
     * Applies the client-requested authorization identity.
     *
     * <p>When authzid handling is ignored, absent, or equal to the authenticated JID, the default
     * credential is used and the JID is echoed back as authzid. Otherwise the lookup target
     * becomes the requested authzid, and the credential id becomes the authenticating user's local
     * part.
     *
     * @throws SaslException if the requested authzid is not a valid JID
     */
    private void handleAuthorizationIdCallback(AuthorizationIdCallback authorizationIdCallback) throws SaslException {
        String authzId = authorizationIdCallback.getAuthzId();
        if (AbstractSasl.isAuthzIDIgnored() || authzId == null || authzId.equals(this.jid.toString())) {
            this.credentialId = DEFAULT_CREDENTIAL_ID;
            authorizationIdCallback.setAuthzId(this.jid.toString());
            return;
        }
        BareJID requested;
        try {
            requested = BareJID.bareJIDInstance(authzId);
        } catch (Exception e) {
            // Client-supplied input: report a SASL failure, not an unchecked error.
            throw new SaslException("Malformed authorization id: " + authzId, e);
        }
        this.credentialId = this.jid.getLocalpart();
        setJid(requested);
    }

    /** Supplies channel-binding data for the requested binding type from the session. */
    private void handleChannelBindingCallback(ChannelBindingCallback channelBindingCallback) {
        AbstractSaslSCRAM.BindType bindType = channelBindingCallback.getRequestedBindType();
        if (bindType == AbstractSaslSCRAM.BindType.tls_exporter) {
            channelBindingCallback.setBindingData((byte[]) this.session.getSessionData(AbstractSaslSCRAM.TLS_EXPORTER_KEY));
        } else if (bindType == AbstractSaslSCRAM.BindType.tls_unique) {
            channelBindingCallback.setBindingData((byte[]) this.session.getSessionData(AbstractSaslSCRAM.TLS_UNIQUE_ID_KEY));
        } else if (bindType == AbstractSaslSCRAM.BindType.tls_server_end_point) {
            channelBindingCallback.setBindingData(computeServerEndPointBinding());
        }
        if (log.isLoggable(Level.FINEST)) {
            byte[] data = channelBindingCallback.getBindingData();
            log.log(Level.FINEST, "Channel binding {0}: {1} in session-id {2}",
                    new Object[]{bindType, data == null ? "null" : Base64.encode(data), this.session});
        }
    }

    /**
     * Computes the tls-server-end-point binding: the hash of the DER-encoded local certificate.
     * The hash comes from the certificate's signature algorithm (see {@link #serverEndPointDigestName}).
     *
     * @throws RuntimeException if the certificate is missing, unparsable or cannot be hashed
     */
    private byte[] computeServerEndPointBinding() {
        try {
            X509Certificate certificate = (X509Certificate) this.session.getSessionData(AbstractSaslSCRAM.LOCAL_CERTIFICATE_KEY);
            if (certificate == null) {
                throw new IllegalStateException("No local certificate available for tls-server-end-point channel binding");
            }
            MessageDigest digest = MessageDigest.getInstance(serverEndPointDigestName(certificate.getSigAlgName()));
            return digest.digest(certificate.getEncoded());
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Maps a certificate signature algorithm name (e.g. {@code SHA256withRSA}) to the JCA
     * {@link MessageDigest} name for tls-server-end-point (RFC 5929, section 4.1). MD5 and SHA-1
     * are upgraded to SHA-256. The "with" separator is matched case-insensitively, because some
     * providers report names such as {@code SHA256WITHRSA}.
     */
    private static String serverEndPointDigestName(String sigAlgName) {
        int withIdx = -1;
        for (int i = 1; i + 4 <= sigAlgName.length(); i++) {
            if (sigAlgName.regionMatches(true, i, "with", 0, 4)) {
                withIdx = i;
                break;
            }
        }
        if (withIdx <= 0) {
            throw new IllegalArgumentException("Unable to parse SigAlgName: " + sigAlgName);
        }
        String hash = sigAlgName.substring(0, withIdx);
        if (hash.equalsIgnoreCase("MD5") || hash.equalsIgnoreCase("SHA1") || hash.equalsIgnoreCase("SHA-1")) {
            return "SHA-256";
        }
        // Signature names omit the hyphen; use the standard JCA digest names.
        if (hash.equalsIgnoreCase("SHA224")) {
            return "SHA-224";
        }
        if (hash.equalsIgnoreCase("SHA256")) {
            return "SHA-256";
        }
        if (hash.equalsIgnoreCase("SHA384")) {
            return "SHA-384";
        }
        if (hash.equalsIgnoreCase("SHA512")) {
            return "SHA-512";
        }
        return hash;
    }

    /** Supplies the SCRAM ServerKey, or {@code null} if no usable credentials exist. */
    private void handleServerKeyCallback(ServerKeyCallback serverKeyCallback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "ServerKeyCallback: {0}", this.jid);
        }
        fetchCredentials();
        serverKeyCallback.setServerKey(this.credentialsEntry != null ? this.credentialsEntry.getServerKey() : null);
    }

    /** Supplies the SCRAM StoredKey, or {@code null} if no usable credentials exist. */
    private void handleStoredKeyCallback(StoredKeyCallback storedKeyCallback) throws SaslException {
        if (log.isLoggable(Level.FINEST)) {
            log.log(Level.FINEST, "StoredKeyCallback: {0}", this.jid);
        }
        fetchCredentials();
        storedKeyCallback.setStoredKey(this.credentialsEntry != null ? this.credentialsEntry.getStoredKey() : null);
    }

    /** Sets the lookup JID and, when non-null, records it in the session under {@code AUTH_JID}. */
    private void setJid(BareJID bareJID) {
        this.jid = bareJID;
        if (bareJID != null) {
            this.session.putSessionData(CallbackHandlerFactory.AUTH_JID, bareJID);
        }
    }
}
