Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/javaTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ concurrency:
jobs:
java_tests:
runs-on: ubuntu-24.04
# Job cap kept above the per-fork surefire timeout (test-forkedProcessTimeout,
# 600s) so surefire can kill a hung fork before GitHub Actions cancels the job.
timeout-minutes: 30
strategy:
fail-fast: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ public static int runCommand(String [] command, String strCurDir, String strOutp
try {
exitValue = process.waitFor();
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
debugPrint(Constants.DEBUG_ERROR, "Program interrunpted: " + ie);
}
debugPrint(Constants.DEBUG_CODE, "Program '" + String.join(" ", command) + "' exited with exit status " + exitValue, strOutputFile);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.DefaultThreadFactory;

@SuppressWarnings("deprecation")
public class FederatedWorker {
Expand Down Expand Up @@ -99,9 +100,11 @@ private void run() {
LOG.info("Setting up Federated Worker on port " + _port);
int par_conn = ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.FEDERATED_PAR_CONN);
final int EVENT_LOOP_THREADS = (par_conn > 0) ? par_conn : InfrastructureAnalyzer.getLocalParallelism();
NioEventLoopGroup bossGroup = new NioEventLoopGroup(1);
// Daemon event loops so a leaked in-JVM (test) worker cannot block JVM exit.
NioEventLoopGroup bossGroup = new NioEventLoopGroup(1,
new DefaultThreadFactory("fed-worker-boss", true));
ThreadPoolExecutor workerTPE = new ThreadPoolExecutor(1, Integer.MAX_VALUE, 10, TimeUnit.SECONDS,
new SynchronousQueue<Runnable>(true));
new SynchronousQueue<Runnable>(true), new DefaultThreadFactory("fed-worker-pool", true));
NioEventLoopGroup workerGroup = new NioEventLoopGroup(EVENT_LOOP_THREADS, workerTPE);

final boolean ssl = ConfigurationManager.isFederatedSSL();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ public void shutdown() {
_q[i].close();
}
}
catch(InterruptedException ignored) {
catch(InterruptedException e) {
Thread.currentThread().interrupt();
}
}
_writeExec.getQueue().clear();
Expand All @@ -174,7 +175,8 @@ public CompletableFuture<Void> scheduleEviction(BlockEntry block) {
int i = (int)(q % WRITER_SIZE);
_q[i].enqueueIfOpen(new Tuple2<>(block, future));
}
catch(InterruptedException ignored) {
catch(InterruptedException e) {
Thread.currentThread().interrupt();
}

return future;
Expand Down
20 changes: 18 additions & 2 deletions src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

Expand Down Expand Up @@ -141,11 +142,26 @@ else if(mainThread || threadName.contains("PARFOR") || threadName.contains("FedE
incorrectPoolUse = true;
}

return Executors.newFixedThreadPool(k);
return Executors.newFixedThreadPool(k, daemonThreadFactory());

}
}

/**
 * Creates a {@link ThreadFactory} that hands out daemon threads only.
 *
 * <p>Threads are first obtained from {@link Executors#defaultThreadFactory()} and then flagged as
 * daemon before being returned. The ForkJoinPool-backed pools already use daemon threads, whereas
 * the fallback {@link Executors#newFixedThreadPool} and {@link Executors#newCachedThreadPool} pools
 * default to non-daemon threads, which can keep the JVM (e.g. a surefire test fork) alive if a
 * caller forgets to shut the pool down. Using this factory keeps that behavior uniform.
 *
 * @return a thread factory whose threads are all daemon threads
 */
private static ThreadFactory daemonThreadFactory() {
	final ThreadFactory delegate = Executors.defaultThreadFactory();
	return new ThreadFactory() {
		@Override
		public Thread newThread(Runnable task) {
			// Let the default factory build the thread, then mark it daemon before it is started.
			final Thread worker = delegate.newThread(task);
			worker.setDaemon(true);
			return worker;
		}
	};
}

/**
* Invoke the collection of tasks and shutdown the pool upon job termination.
*
Expand Down Expand Up @@ -180,7 +196,7 @@ public synchronized static ExecutorService getDynamicPool() {
// It is guaranteed not to be shut down because of the synchronized barrier
return asyncPool;
else {
asyncPool = Executors.newCachedThreadPool();
asyncPool = Executors.newCachedThreadPool(daemonThreadFactory());
return asyncPool;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public void generate(int N) throws InterruptedException {
}
}
catch(InterruptedException e) {
e.printStackTrace();
Thread.currentThread().interrupt();
}
});
}
Expand Down
6 changes: 6 additions & 0 deletions src/test/java/org/apache/sysds/test/AutomatedTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,9 @@ private static Thread spawnLocalFedWorkerThread(int port, String[] otherArgs) {
LOG.error("Exception in startup of federated worker", e);
}
});
// Daemon so a worker left running by a failed/forgetful test cannot keep the
// surefire fork JVM alive and stall CI until the job-level timeout.
t.setDaemon(true);
t.start();
return t;
}
Expand Down Expand Up @@ -1979,6 +1982,9 @@ public static Thread startLocalFedWorkerWithArgs(String[] args) {
LOG.error("Exception in startup of federated worker on port " + port, e);
}
});
// Daemon so a worker left running by a failed/forgetful test cannot keep the
// surefire fork JVM alive and stall CI until the job-level timeout.
t.setDaemon(true);
t.start();
FederatedWorkerUtils.waitForWorker(t, port, FED_WORKER_WAIT);
return t;
Expand Down
15 changes: 12 additions & 3 deletions src/test/java/org/apache/sysds/test/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -3489,15 +3489,23 @@ public static void shutdownThreads(Process... ts) {
}
}

/** Upper bound (ms) on how long {@link #shutdownThread(Thread)} waits for a worker to stop. */
private static final long THREAD_SHUTDOWN_JOIN_MS = 30_000;

/**
 * Interrupts the given worker thread and waits a bounded amount of time for it to terminate.
 *
 * <p>The wait is capped at {@link #THREAD_SHUTDOWN_JOIN_MS}. If the thread is still alive after
 * that, a warning is logged and the method returns anyway. If the calling thread is interrupted
 * while waiting, its interrupt flag is restored and the method returns.
 *
 * @param t the thread to stop; {@code null} is ignored
 */
public static void shutdownThread(Thread t) {
	if( t == null )
		return;

	// kill the worker
	t.interrupt();
	try {
		// Bounded join only: an unbounded join() before this call would make the timeout
		// unreachable and could block cleanup (and the JVM) indefinitely if the worker
		// ignores the interrupt.
		t.join(THREAD_SHUTDOWN_JOIN_MS);
		if( t.isAlive() )
			LOG.warn("Federated worker thread " + t.getName()
				+ " did not stop within " + THREAD_SHUTDOWN_JOIN_MS + "ms; leaving it as a daemon.");
	}
	catch (InterruptedException e) {
		// Restore the interrupt flag so the caller can observe the interruption.
		Thread.currentThread().interrupt();
	}
}
Expand All @@ -3514,7 +3522,8 @@ public static void shutdownThread(Process t) {
forciblyDestroyed.waitFor(); // Wait until it's definitely terminated
}
} catch (InterruptedException e) {
e.printStackTrace();
LOG.warn("Interrupted while shutting down federated worker process", e);
Thread.currentThread().interrupt();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,11 @@ public void testBackendPerformance() throws InterruptedException {
taskFutures.forEach(res -> {
try {
Assert.assertEquals("Stats parsed correctly", res.get().statusCode(), 200);
} catch (InterruptedException | ExecutionException e) {
e.printStackTrace();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
Assert.fail("Interrupted while fetching statistics: " + e.getMessage());
} catch (ExecutionException e) {
Assert.fail("Failed to fetch statistics: " + e.getMessage());
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ private void runGenericTest(String dmlFile, int scalar) {
compareResults();
}
catch(InterruptedException e) {
e.printStackTrace();
Thread.currentThread().interrupt();
assert (false);
}
finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,50 +105,52 @@ public void federatedReuse(String test) {
Lineage.resetInternalState();
Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT);

TestConfiguration config = availableTestConfigurations.get(test);
loadTestConfiguration(config);

// Run reference dml script with normal matrix. Reuse of ba+*.
fullDMLScriptName = HOME + test + "Reference.dml";
programArgs = new String[] {"-stats", "-lineage", "reuse_full",
"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
"Y2=" + input("Y2"), "Z=" + expected("Z")};
runTest(true, false, null, -1);
long mmCount = Statistics.getCPHeavyHitterCount(Opcodes.MMULT.toString());

// Run actual dml script with federated matrix
// The fed workers reuse ba+*
fullDMLScriptName = HOME + test + ".dml";
programArgs = new String[] {"-stats","-lineage", "reuse_full",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
runTest(true, false, null, -1);
long mmCount_fed = Statistics.getCPHeavyHitterCount(Opcodes.MMULT.toString());
long fedMMCount = Statistics.getCPHeavyHitterCount("fed_ba+*");

// compare results
compareResults(1e-9);
// compare matrix multiplication count
// #federated execution of ba+* = #threads times #non-federated execution of ba+* (after reuse)
Assert.assertTrue("Violated reuse count: "+mmCount_fed+" == "+mmCount*2,
mmCount_fed == mmCount * 2); // #threads = 2
switch(test) {
case TEST_NAME1:
// If the o/p is federated, fed_ba+* will be called everytime
// but the workers should be able to reuse ba+*
assertTrue(fedMMCount > mmCount_fed);
break;
case TEST_NAME2:
// If the o/p is non-federated, fed_ba+* will be called once
// and each worker will call ba+* once.
assertTrue(fedMMCount < mmCount_fed);
break;
try {
TestConfiguration config = availableTestConfigurations.get(test);
loadTestConfiguration(config);

// Run reference dml script with normal matrix. Reuse of ba+*.
fullDMLScriptName = HOME + test + "Reference.dml";
programArgs = new String[] {"-stats", "-lineage", "reuse_full",
"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
"Y2=" + input("Y2"), "Z=" + expected("Z")};
runTest(true, false, null, -1);
long mmCount = Statistics.getCPHeavyHitterCount(Opcodes.MMULT.toString());

// Run actual dml script with federated matrix
// The fed workers reuse ba+*
fullDMLScriptName = HOME + test + ".dml";
programArgs = new String[] {"-stats","-lineage", "reuse_full",
"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
runTest(true, false, null, -1);
long mmCount_fed = Statistics.getCPHeavyHitterCount(Opcodes.MMULT.toString());
long fedMMCount = Statistics.getCPHeavyHitterCount("fed_ba+*");

// compare results
compareResults(1e-9);
// compare matrix multiplication count
// #federated execution of ba+* = #threads times #non-federated execution of ba+* (after reuse)
Assert.assertTrue("Violated reuse count: "+mmCount_fed+" == "+mmCount*2,
mmCount_fed == mmCount * 2); // #threads = 2
switch(test) {
case TEST_NAME1:
// If the o/p is federated, fed_ba+* will be called everytime
// but the workers should be able to reuse ba+*
assertTrue(fedMMCount > mmCount_fed);
break;
case TEST_NAME2:
// If the o/p is non-federated, fed_ba+* will be called once
// and each worker will call ba+* once.
assertTrue(fedMMCount < mmCount_fed);
break;
}
}
finally {
TestUtils.shutdownThreads(workers);
}


TestUtils.shutdownThreads(workers);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,43 +110,45 @@ private void runTriUDFReuse(ExecMode execMode) {
Lineage.resetInternalState();
Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2, port3, port4}, otherargs, FED_WORKER_WAIT);

rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
System.out.println(7);
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
try {
rtplatform = execMode;
if(rtplatform == ExecMode.SPARK) {
System.out.println(7);
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);

// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
runTest(null);

// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};

runTest(null);

// compare via files
compareResults(1e-9);
// check if lowertri is federated
Assert.assertTrue(heavyHittersContainsString("fed_lowertri"));
// assert reuse count
Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
}
finally {
TestUtils.shutdownThreads(workers);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);

// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
runTest(null);

// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-lineage", "reuse_full", "-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
"rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};

runTest(null);

// compare via files
compareResults(1e-9);
// check if lowertri is federated
Assert.assertTrue(heavyHittersContainsString("fed_lowertri"));
// assert reuse count
Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);

TestUtils.shutdownThreads(workers);

rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;

Thread[] workers = null;
try {
// generated lm data
MatrixBlock X = MatrixBlock.randOperations(rows, cols, 1.0, 0, 1, "uniform", 7);
Expand All @@ -93,7 +94,7 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
int port3 = getRandomAvailablePort();
int port4 = getRandomAvailablePort();
String[] otherargs = new String[] {"-lineage", "reuse_full"};
Thread[] workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT);
workers = startLocalFedWorkerThreads(new int[] {port1, port2}, otherargs, FED_WORKER_WAIT);

TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
Expand Down Expand Up @@ -134,10 +135,9 @@ public void federatedLmPipeline(ExecMode execMode, boolean contSplits, String TE
assertTrue(fed_tsmmCount > fed_tsmmCount_reuse);
assertTrue(mmCount > mmCount_reuse);
assertTrue(fed_mmCount > fed_mmCount_reuse);

TestUtils.shutdownThreads(workers);
}
finally {
TestUtils.shutdownThreads(workers);
resetExecMode(oldExec);
ColumnEncoderRecode.SORT_RECODE_MAP = oldSort;
}
Expand Down
Loading
Loading