Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,31 @@
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.HttpResponse;
import org.apache.http.NameValuePair;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.message.BasicNameValuePair;
import org.wso2.carbon.base.MultitenantConstants;
import org.wso2.carbon.identity.base.IdentityConstants;
import org.wso2.carbon.identity.core.ThreadLocalAwareThreadPoolExecutor;
import org.wso2.carbon.identity.core.util.IdentityUtil;
import org.wso2.carbon.identity.oauth.common.exception.InvalidOAuthClientException;
import org.wso2.carbon.identity.oauth2.IdentityOAuth2Exception;
import org.wso2.carbon.identity.oidc.session.OIDCSessionConstants;
import org.wso2.carbon.identity.oidc.session.util.OIDCSessionManagementUtil;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.List;
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.net.http.HttpTimeoutException;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLParameters;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;

Expand All @@ -62,11 +57,12 @@ public class LogoutRequestSender {
private static final Log LOG = LogFactory.getLog(LogoutRequestSender.class);
private static LogoutRequestSender instance = null;

private static ExecutorService threadPool = null;
private static ThreadLocalAwareThreadPoolExecutor threadPool = null;
private boolean hostNameVerificationEnabled = true;
private static int httpConnectTimeout = 0;
private static int httpSocketTimeout = 0;
private static final String LOGOUT_TOKEN = "logout_token";
private final HttpClient sharedHttpClient;

private LogoutRequestSender() {

Expand Down Expand Up @@ -118,7 +114,7 @@ private LogoutRequestSender() {
workQueue = new ArrayBlockingQueue<Runnable>(workQueueSizeInt);
}

threadPool = new ThreadPoolExecutor(poolSizeInt, poolSizeInt, keepAliveTimeLong,
threadPool = new ThreadLocalAwareThreadPoolExecutor(poolSizeInt, poolSizeInt, keepAliveTimeLong,
TimeUnit.MILLISECONDS, workQueue);

httpConnectTimeout = Integer.parseInt(httpConnectTimeoutProperty);
Expand All @@ -134,6 +130,14 @@ private LogoutRequestSender() {
", httpSocketTimeout: " + httpSocketTimeout +
", hostNameVerificationEnabled: " + hostNameVerificationEnabled);
}
SSLParameters sslParams = new SSLParameters();
if (!hostNameVerificationEnabled) {
sslParams.setEndpointIdentificationAlgorithm(null); // disable hostname verification
}
sharedHttpClient = HttpClient.newBuilder()
.sslParameters(sslParams)
.connectTimeout(Duration.ofMillis(httpConnectTimeout))
.build();
}

/**
Expand Down Expand Up @@ -203,7 +207,7 @@ public void sendLogoutRequests(String opbsCookieId, String tenantDomain) {
String logoutToken = logoutTokenMap.getKey();
String bcLogoutUrl = logoutTokenMap.getValue();
LOG.debug("A LogoutReqSenderTask will be assigned to the thread pool.");
threadPool.submit(new LogoutReqSenderTask(logoutToken, bcLogoutUrl));
threadPool.submit(new LogoutReqSenderTask(logoutToken, bcLogoutUrl, sharedHttpClient));
}
}
}
Expand Down Expand Up @@ -240,11 +244,13 @@ private class LogoutReqSenderTask implements Runnable {

private String logoutToken;
private String backChannelLogouturl;
private HttpClient httpClient;

public LogoutReqSenderTask(String logoutToken, String backChannelLogouturl) {
public LogoutReqSenderTask(String logoutToken, String backChannelLogouturl, HttpClient httpClient) {

this.logoutToken = logoutToken;
this.backChannelLogouturl = backChannelLogouturl;
this.httpClient = httpClient;
}

@Override
Expand All @@ -254,44 +260,29 @@ public void run() {
LOG.debug("Starting backchannel logout request to: " + backChannelLogouturl);
}

List<NameValuePair> logoutReqParams = new ArrayList<NameValuePair>();
CloseableHttpClient httpClient = null;
try {
if (!hostNameVerificationEnabled) {
httpClient = HttpClients.custom()
.setHostnameVerifier(SSLConnectionSocketFactory.ALLOW_ALL_HOSTNAME_VERIFIER)
.build();
} else {
httpClient = HttpClients.createDefault();
}
logoutReqParams.add(new BasicNameValuePair(LOGOUT_TOKEN, logoutToken));
// Encode form parameters
String formParams = LOGOUT_TOKEN + "=" + URLEncoder.encode(logoutToken, StandardCharsets.UTF_8);

HttpPost httpPost = new HttpPost(backChannelLogouturl);
try {
httpPost.setEntity(new UrlEncodedFormEntity(logoutReqParams));
} catch (UnsupportedEncodingException e) {
LOG.error("Error while encoding logout request parameters.", e);
}
RequestConfig requestConfig = RequestConfig.custom().setConnectTimeout(httpConnectTimeout)
.setSocketTimeout(httpSocketTimeout).build();
httpPost.setConfig(requestConfig);
// Build POST request
HttpRequest request = HttpRequest.newBuilder()
.uri(URI.create(backChannelLogouturl))
.timeout(Duration.ofMillis(httpSocketTimeout))
.header("Content-Type", "application/x-www-form-urlencoded")
.POST(HttpRequest.BodyPublishers.ofString(formParams))
.build();

// Execute request
HttpResponse<String> response = httpClient.send(request, HttpResponse.BodyHandlers.ofString());

HttpResponse response = httpClient.execute(httpPost);
if (LOG.isDebugEnabled()) {
LOG.debug("Backchannel logout response: " + response.getStatusLine());
LOG.debug("Backchannel logout response: " + response.statusCode() + " " + response.body());
}
} catch (SocketTimeoutException e) {
LOG.error("Timeout occurred while sending logout requests to: " + backChannelLogouturl);
} catch (IOException e) {

} catch (HttpTimeoutException e) {
LOG.error("Timeout occurred while sending logout requests to: " + backChannelLogouturl, e);
} catch (IOException | InterruptedException e) {
LOG.error("Error sending logout requests to: " + backChannelLogouturl, e);
} finally {
if (httpClient != null) {
try {
httpClient.close();
} catch (IOException e) {
LOG.error("Error closing http client.", e);
}
}
}
}
}
Expand Down