Skip to content

Monitor keepalives with fewer threads #987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/itest/java/com/hierynomus/sshj/KeepAliveTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.hierynomus.sshj;

import net.schmizz.keepalive.BoundedKeepAliveProvider;
import net.schmizz.sshj.Config;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.testcontainers.junit.jupiter.Container;

import java.util.ArrayList;
import java.util.List;

public class KeepAliveTest {
@Container
SshdContainer sshd = new SshdContainer(SshdContainer.Builder
.defaultBuilder()
.withAllKeys()
.withPackages("iptables")
.withPrivileged(true));

@Test
void testKeepAlive() throws Exception {
sshd.start();

Config config = new DefaultConfig();
BoundedKeepAliveProvider p = new BoundedKeepAliveProvider(LoggerFactory.DEFAULT, 4);
p.setKeepAliveInterval(1);
p.setMaxKeepAliveCount(1);
config.setKeepAliveProvider(p);
List<SSHClient> clients = new ArrayList<>();
for (int i=0; i<10; i++) {
SSHClient c = new SSHClient(config);
c.addHostKeyVerifier(new PromiscuousVerifier());
c.connect("127.0.0.1", sshd.getFirstMappedPort());
c.authPassword("sshj", "ultrapassword");
var sess = c.startSession();
sess.allocateDefaultPTY();
clients.add(c);
}

for (SSHClient client : clients) {
Assertions.assertTrue(client.isConnected());
}

var res = sshd.execInContainer("iptables", "-A", "INPUT", "-p", "tcp", "--dport", "22", "-j", "DROP");
Assertions.assertEquals(0, res.getExitCode());
// wait for keepalive to take action
Thread.sleep(2000);

for (SSHClient client : clients) {
Assertions.assertFalse(client.isConnected());
}

p.shutdown();
}
}
19 changes: 18 additions & 1 deletion src/itest/java/com/hierynomus/sshj/SshdContainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,24 @@ public static class Builder implements Consumer<DockerfileBuilder> {
private List<String> hostKeys = new ArrayList<>();
private List<String> certificates = new ArrayList<>();
private @NotNull SshdConfigBuilder sshdConfig = SshdConfigBuilder.defaultBuilder();
private boolean privileged = false;
private List<String> packages = new ArrayList<>();

public static Builder defaultBuilder() {
Builder b = new Builder();

return b;
}

public @NotNull Builder withPrivileged(boolean privileged) {
this.privileged = privileged;
return this;
}

public @NotNull Builder withPackages(@NotNull String... packages) {
this.packages.addAll(List.of(packages));
return this;
}


public @NotNull Builder withSshdConfig(@NotNull SshdConfigBuilder sshdConfig) {
this.sshdConfig = sshdConfig;
Expand Down Expand Up @@ -153,6 +164,9 @@ public void accept(@NotNull DockerfileBuilder builder) {
builder.expose(22);
builder.copy("entrypoint.sh", "/entrypoint.sh");

if (!packages.isEmpty()) {
builder.run("apk add --no-cache " + String.join(" ", packages));
}
builder.add("authorized_keys", "/home/sshj/.ssh/authorized_keys");
builder.copy("test-container/trusted_ca_keys", "/etc/ssh/trusted_ca_keys");

Expand Down Expand Up @@ -201,6 +215,9 @@ public SshdContainer() {

public SshdContainer(SshdContainer.Builder builder) {
this(builder.buildInner());
if (builder.privileged) {
withPrivilegedMode(true);
}
}

public SshdContainer(@NotNull Future<String> future) {
Expand Down
201 changes: 201 additions & 0 deletions src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package net.schmizz.keepalive;

import net.schmizz.sshj.Config;
import net.schmizz.sshj.common.LoggerFactory;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
* This implementation manages all {@link KeepAlive}s using configured number of threads. It works like a
* thread pool, thus {@link BoundedKeepAliveProvider#shutdown()} must be called to clean up resources.
* <br>
* This provider uses {@link KeepAliveRunner#doKeepAlive()} as delegate, so it supports maxKeepAliveCount
* parameter. All instances provided by this provider have identical configuration.
*/
public class BoundedKeepAliveProvider extends KeepAliveProvider {

public int maxKeepAliveCount = 3;
public int keepAliveInterval = 5;

protected final KeepAliveMonitor monitor;


public BoundedKeepAliveProvider(LoggerFactory loggerFactory, int numberOfThreads) {
this.monitor = new KeepAliveMonitor(loggerFactory, numberOfThreads);
}

public void setKeepAliveInterval(int interval) {
keepAliveInterval = interval;
}

public void setMaxKeepAliveCount(int count) {
maxKeepAliveCount = count;
}

@Override
public KeepAlive provide(ConnectionImpl connection) {
return new Impl(connection, "bounded-keepalive-impl");
}

public void shutdown() throws InterruptedException {
monitor.shutdown();
}

class Impl extends KeepAlive {

private final KeepAliveRunner delegate;

protected Impl(ConnectionImpl conn, String name) {
super(conn, name);
this.delegate = new KeepAliveRunner(conn);

// take care here, some parameters are set to both delegate and this
this.delegate.setMaxAliveCount(BoundedKeepAliveProvider.this.maxKeepAliveCount);
super.keepAliveInterval = BoundedKeepAliveProvider.this.keepAliveInterval;
}

@Override
protected void doKeepAlive() throws TransportException, ConnectionException {
delegate.doKeepAlive();
}

@Override
public void startKeepAlive() {
monitor.register(this);
}

}

protected static class KeepAliveMonitor {
private final Logger logger;

private final PriorityBlockingQueue<Wrapper> q =
new PriorityBlockingQueue<>(32, Comparator.comparingLong(w -> w.nextTimeMillis));
private static final List<Thread> workerThreads = new ArrayList<>();

private volatile long idleSleepMillis = 100;
private final int numberOfThreads;

volatile boolean started = false;

private final ReentrantLock lock = new ReentrantLock();
private final Condition shutDown = lock.newCondition();
private final AtomicInteger shutDownCnt = new AtomicInteger(0);

public KeepAliveMonitor(LoggerFactory loggerFactory, int numberOfThreads) {
this.numberOfThreads = numberOfThreads;
logger = loggerFactory.getLogger(KeepAliveMonitor.class);
}

// made public for test
public void register(KeepAlive keepAlive) {
if (!started) {
start();
}
q.add(new Wrapper(keepAlive));
}

public void setIdleSleepMillis(long idleSleepMillis) {
this.idleSleepMillis = idleSleepMillis;
}

private void sleep() {
sleep(idleSleepMillis);
}

private void sleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}

private synchronized void start() {
if (started) {
return;
}

for (int i = 0; i < numberOfThreads; i++) {
Thread t = new Thread(this::doStart);
workerThreads.add(t);
}
workerThreads.forEach(Thread::start);
started = true;
}


private void doStart() {
while (!Thread.currentThread().isInterrupted()) {
Wrapper wrapper;

if (q.isEmpty() || (wrapper = q.poll()) == null) {
sleep();
continue;
}

long currentTimeMillis = System.currentTimeMillis();
if (wrapper.nextTimeMillis > currentTimeMillis) {
long sleepMillis = wrapper.nextTimeMillis - currentTimeMillis;
logger.debug("{} millis until next check, sleep", sleepMillis);
sleep(sleepMillis);
}

try {
wrapper.keepAlive.doKeepAlive();
q.add(wrapper.reschedule());
} catch (Exception e) {
// If we weren't interrupted, kill the transport, then this exception was unexpected.
// Else we're in shutdown-mode already, so don't forcibly kill the transport.
if (!Thread.currentThread().isInterrupted()) {
wrapper.keepAlive.conn.getTransport().die(e);
}
}
}
lock.lock();
try {
if (shutDownCnt.incrementAndGet() == numberOfThreads) {
shutDown.signal();
}
} finally {
lock.unlock();
}
}

private synchronized void shutdown() throws InterruptedException {
if (workerThreads.isEmpty()) {
return;
}
for (Thread t : workerThreads) {
t.interrupt();
}
lock.lock();
logger.info("waiting for all {} threads to finish", numberOfThreads);
shutDown.await();
}

private static class Wrapper {
private final KeepAlive keepAlive;
private final long nextTimeMillis;

private Wrapper(KeepAlive keepAlive) {
this.keepAlive = keepAlive;
this.nextTimeMillis = System.currentTimeMillis() + keepAlive.keepAliveInterval * 1000L;
}

private Wrapper reschedule() {
return new Wrapper(keepAlive);
}
}
}
}
7 changes: 7 additions & 0 deletions src/main/java/net/schmizz/keepalive/KeepAlive.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,11 @@ public void run() {
}

protected abstract void doKeepAlive() throws TransportException, ConnectionException;

/**
* Start keep-alive loop. Implementations MUST NOT block current thread.
*/
public void startKeepAlive() {
start();
}
}
2 changes: 1 addition & 1 deletion src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ protected void onConnect()
final KeepAlive keepAliveThread = conn.getKeepAlive();
if (keepAliveThread.isEnabled()) {
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start();
keepAliveThread.startKeepAlive();
}
}

Expand Down
Loading
Loading