diff --git a/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintRef.java b/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintRef.java index 6dac6edb8b2..9122f07fc3f 100644 --- a/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintRef.java +++ b/core/src/main/java/ai/timefold/solver/core/api/score/stream/ConstraintRef.java @@ -29,4 +29,9 @@ public int compareTo(ConstraintRef other) { return id.compareTo(other.id); } + @Override + public String toString() { + return id; + } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractBavetNodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractBavetNodeNetwork.java new file mode 100644 index 00000000000..fb860e957f1 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractBavetNodeNetwork.java @@ -0,0 +1,174 @@ +package ai.timefold.solver.core.impl.bavet; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Function; +import java.util.stream.Stream; + +import ai.timefold.solver.core.impl.bavet.common.AbstractNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractTwoInputNode; +import ai.timefold.solver.core.impl.bavet.common.Propagator; +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; + +/** + * Represents Bavet's network of nodes, specific to a particular session. + */ +@NullMarked +public abstract class AbstractBavetNodeNetwork { + + protected static AbstractNode[][] buildLayeredNodes(List nodeList) { + var layerMap = new TreeMap>(); + nodeList.forEach(node -> layerMap.computeIfAbsent(node.getLayerIndex(), unused -> new ArrayList<>()).add(node)); + var layerCount = layerMap.size(); + var layeredNodes = new AbstractNode[layerCount][]; + for (var i = 0; i < layerCount; i++) { + var layer = layerMap.get((long) i); + layeredNodes[i] = layer.toArray(new AbstractNode[0]); + } + return layeredNodes; + } + + private final Map, List>> declaredClassToNodeMap; + + private final AbstractNode[][] layeredNodes; + private final Function propagatorFunction; + /** + * A subset of {@code layeredNodes}. + * Once non-null, only contains propagators of nodes which are active. + * See {@link ActivitySupport#isActive()} for details. + */ + private Propagator @Nullable [][] layeredActivePropagators; + /** + * For testing only: the set of nodes that remained active after {@link #settle()}; null before settle. + */ + private @Nullable Set activeNodeSet; + + /** + * @param declaredClassToNodeMap starting nodes, one for each class used in the constraints; + * root nodes, layer index 0. + * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; + * propagation needs to happen in this order. + */ + protected AbstractBavetNodeNetwork(Map, List>> declaredClassToNodeMap, + AbstractNode[][] layeredNodes, Function propagatorFunction) { + this.declaredClassToNodeMap = declaredClassToNodeMap; + this.layeredNodes = layeredNodes; + this.propagatorFunction = propagatorFunction; + } + + public int forEachNodeCount() { + return declaredClassToNodeMap.size(); + } + + /** + * + * @param factClass + * @return if {@link #isActivationCheckComplete()} is true, only returns active root nodes; + * otherwise returns all root nodes. + * This means that if this information was ever read before activation checks were complete, + * it should be re-read after to make sure no inactive nodes are included. + */ + public Stream> getRootNodesAcceptingType(Class factClass) { + return declaredClassToNodeMap.entrySet().stream() + .flatMap(entry -> entry.getValue().stream()) + .filter(tupleSourceRoot -> tupleSourceRoot.allowsInstancesOf(factClass)) + .filter(node -> !isActivationCheckComplete() || activeNodeSet.contains(node)); + } + + public void settle() { + if (layeredActivePropagators == null) { + // Remove inactive nodes and settle the layers in one go. + var initializedRootNodes = Collections.newSetFromMap(new IdentityHashMap<>()); + declaredClassToNodeMap.forEach((declaredClass, rootNodes) -> rootNodes.forEach(rootNode -> { + if (initializedRootNodes.add(rootNode)) { + // Ensure one initialization per node. + // Root nodes are filled from a session, which can always produce. + rootNode.afterAllFactsInserted(true); + } + })); + + var activeNodes = Collections. newSetFromMap(new IdentityHashMap<>()); + layeredActivePropagators = Arrays.stream(layeredNodes) + .map(layer -> Arrays.stream(layer) + .filter(s -> switch (s) { + case ActivitySupport activityEnabled -> activityEnabled.isActive(); + case AbstractTwoInputNode twoInputNode -> twoInputNode.isActive(); + }) + .peek(activeNodes::add) + .map(propagatorFunction).toArray(Propagator[]::new)) + .filter(layer -> layer.length > 0).peek(AbstractBavetNodeNetwork::settleLayer).toArray(Propagator[][]::new); + this.activeNodeSet = activeNodes; + return; + } + // Simplified loop when the layers were already trimmed. + for (var layer : layeredActivePropagators) { + settleLayer(layer); + } + } + + public boolean isActivationCheckComplete() { + return layeredActivePropagators != null; + } + + Set getActiveNodes() { + if (activeNodeSet == null) { + throw new IllegalStateException("Impossible state: getActiveNodes() called before settle()."); + } + return activeNodeSet; + } + + /** + * For testing only. All nodes in the network, regardless of activity. + */ + List getNodes() { + return Arrays.stream(layeredNodes).flatMap(Arrays::stream).toList(); + } + + private static void settleLayer(Propagator[] nodesInLayer) { + if (nodesInLayer.length == 1) { // Avoid iteration. + nodesInLayer[0].propagateEverything(); + } else { + for (var node : nodesInLayer) { + node.propagateRetracts(); + } + for (var node : nodesInLayer) { + node.propagateUpdates(); + } + for (var node : nodesInLayer) { + node.propagateInserts(); + } + } + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof AbstractBavetNodeNetwork that)) + return false; + return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap) + && Objects.deepEquals(layeredNodes, that.layeredNodes); + } + + @Override + public int hashCode() { + return Objects.hash(declaredClassToNodeMap, Arrays.deepHashCode(layeredNodes)); + } + + @Override + public String toString() { + return "%s with %d forEach nodes.".formatted(getClass().getSimpleName(), forEachNodeCount()); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java index f97d42bffcb..e2a52346fa9 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/AbstractSession.java @@ -1,36 +1,38 @@ package ai.timefold.solver.core.impl.bavet; -import java.util.IdentityHashMap; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode.LifecycleOperation; -public abstract class AbstractSession { +public abstract class AbstractSession { - private final NodeNetwork nodeNetwork; - private final Map, BavetRootNode[]> insertEffectiveClassToNodeArrayMap; - private final Map, BavetRootNode[]> updateEffectiveClassToNodeArrayMap; - private final Map, BavetRootNode[]> retractEffectiveClassToNodeArrayMap; - private boolean settled = true; + protected final Network_ nodeNetwork; + private final Map, AbstractRootNode[]> insertEffectiveClassToNodeArrayMap; + private final Map, AbstractRootNode[]> updateEffectiveClassToNodeArrayMap; + private final Map, AbstractRootNode[]> retractEffectiveClassToNodeArrayMap; + private boolean initialized = false; + private boolean settled = false; - protected AbstractSession(NodeNetwork nodeNetwork) { + protected AbstractSession(Network_ nodeNetwork) { this.nodeNetwork = nodeNetwork; - this.insertEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); - this.updateEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); - this.retractEffectiveClassToNodeArrayMap = new IdentityHashMap<>(nodeNetwork.forEachNodeCount()); + this.insertEffectiveClassToNodeArrayMap = HashMap.newHashMap(nodeNetwork.forEachNodeCount()); + this.updateEffectiveClassToNodeArrayMap = HashMap.newHashMap(nodeNetwork.forEachNodeCount()); + this.retractEffectiveClassToNodeArrayMap = HashMap.newHashMap(nodeNetwork.forEachNodeCount()); } public final void insert(Object fact) { settled = false; var factClass = fact.getClass(); - for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.INSERT)) { + for (var node : findNodes(factClass, AbstractRootNode.LifecycleOperation.INSERT)) { node.insert(fact); } } @SuppressWarnings("unchecked") - private BavetRootNode[] findNodes(Class factClass, LifecycleOperation lifecycleOperation) { + private AbstractRootNode[] findNodes(Class factClass, LifecycleOperation lifecycleOperation) { var effectiveClassToNodeArrayMap = switch (lifecycleOperation) { case INSERT -> insertEffectiveClassToNodeArrayMap; case UPDATE -> updateEffectiveClassToNodeArrayMap; @@ -41,7 +43,7 @@ private BavetRootNode[] findNodes(Class factClass, LifecycleOperation if (nodeArray == null) { nodeArray = nodeNetwork.getRootNodesAcceptingType(factClass) .filter(node -> node.supports(lifecycleOperation)) - .toArray(BavetRootNode[]::new); + .toArray(AbstractRootNode[]::new); effectiveClassToNodeArrayMap.put(factClass, nodeArray); } return nodeArray; @@ -50,7 +52,7 @@ private BavetRootNode[] findNodes(Class factClass, LifecycleOperation public final void update(Object fact) { settled = false; var factClass = fact.getClass(); - for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.UPDATE)) { + for (var node : findNodes(factClass, AbstractRootNode.LifecycleOperation.UPDATE)) { node.update(fact); } } @@ -58,7 +60,7 @@ public final void update(Object fact) { public final void retract(Object fact) { settled = false; var factClass = fact.getClass(); - for (var node : findNodes(factClass, BavetRootNode.LifecycleOperation.RETRACT)) { + for (var node : findNodes(factClass, AbstractRootNode.LifecycleOperation.RETRACT)) { node.retract(fact); } } @@ -68,11 +70,23 @@ public final void settle() { return; } nodeNetwork.settle(); + if (!initialized && nodeNetwork.isActivationCheckComplete()) { + removeInactiveRootNodes(insertEffectiveClassToNodeArrayMap); + removeInactiveRootNodes(updateEffectiveClassToNodeArrayMap); + removeInactiveRootNodes(retractEffectiveClassToNodeArrayMap); + initialized = true; + } settled = true; } - public final void summarizeProfileIfPresent() { - nodeNetwork.summarizeProfileIfPresent(); + private void removeInactiveRootNodes(Map, AbstractRootNode[]> effectiveClassToNodeArrayMap) { + // Use getActiveNodes() for this, to not rerun the activity checking logic again. + effectiveClassToNodeArrayMap.replaceAll((k, v) -> Arrays.stream(v) + .filter(nodeNetwork.getActiveNodes()::contains) + .toArray(AbstractRootNode[]::new)); } + public Network_ getNodeNetwork() { + return nodeNetwork; + } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java deleted file mode 100644 index de5b3adfeaa..00000000000 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/NodeNetwork.java +++ /dev/null @@ -1,105 +0,0 @@ -package ai.timefold.solver.core.impl.bavet; - -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Stream; - -import ai.timefold.solver.core.api.domain.solution.PlanningSolution; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; -import ai.timefold.solver.core.impl.bavet.common.InnerConstraintProfiler; -import ai.timefold.solver.core.impl.bavet.common.Propagator; - -import org.jspecify.annotations.NullMarked; -import org.jspecify.annotations.Nullable; - -/** - * Represents Bavet's network of nodes, specific to a particular session. - * Nodes only used by disabled constraints have already been removed. - * - * @param declaredClassToNodeMap starting nodes, one for each class used in the constraints; - * root nodes, layer index 0. - * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; - * propagation needs to happen in this order. - */ -@NullMarked -public record NodeNetwork(Map, List>> declaredClassToNodeMap, - Propagator[][] layeredNodes, @Nullable InnerConstraintProfiler constraintProfiler) { - - public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0], null); - - public int forEachNodeCount() { - return declaredClassToNodeMap.size(); - } - - public int layerCount() { - return layeredNodes.length; - } - - public Stream> getRootNodes() { - return declaredClassToNodeMap.entrySet() - .stream() - .flatMap(entry -> entry.getValue().stream()); - } - - public Stream> getRootNodesAcceptingType(Class factClass) { - // The node needs to match the fact, or the node needs to be applicable to the entire solution. - // The latter is for FromSolution nodes. - return declaredClassToNodeMap.entrySet() - .stream() - .flatMap(entry -> entry.getValue().stream()) - .filter(tupleSourceRoot -> factClass == PlanningSolution.class || tupleSourceRoot.allowsInstancesOf(factClass)); - } - - public void settle() { - for (var layerIndex = 0; layerIndex < layerCount(); layerIndex++) { - settleLayer(layeredNodes[layerIndex]); - } - } - - private static void settleLayer(Propagator[] nodesInLayer) { - var nodeCount = nodesInLayer.length; - if (nodeCount == 1) { - nodesInLayer[0].propagateEverything(); - } else { - for (var node : nodesInLayer) { - node.propagateRetracts(); - } - for (var node : nodesInLayer) { - node.propagateUpdates(); - } - for (var node : nodesInLayer) { - node.propagateInserts(); - } - } - } - - public void summarizeProfileIfPresent() { - if (constraintProfiler != null) { - constraintProfiler.summarize(); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof NodeNetwork that)) - return false; - return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap) - && Objects.deepEquals(layeredNodes, that.layeredNodes); - } - - @Override - public int hashCode() { - return Objects.hash(declaredClassToNodeMap, Arrays.deepHashCode(layeredNodes)); - } - - @Override - public String toString() { - return "%s with %d forEach nodes." - .formatted(getClass().getSimpleName(), forEachNodeCount()); - } - -} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractConcatNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractConcatNode.java index b60c374feaa..31946fe0b50 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractConcatNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractConcatNode.java @@ -36,6 +36,7 @@ protected AbstractConcatNode(TupleLifecycle nextNodesTupleLifecycle, int leftSourceTupleCloneStoreIndex, int rightSourceTupleCloneStoreIndex, int outputStoreSize) { + super(nextNodesTupleLifecycle); this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle); this.leftSourceTupleCloneStoreIndex = leftSourceTupleCloneStoreIndex; this.rightSourceTupleCloneStoreIndex = rightSourceTupleCloneStoreIndex; @@ -55,6 +56,13 @@ public StreamKind getStreamKind() { protected abstract void updateOutTupleFromRight(RightTuple_ rightTuple, OutTuple_ outTuple); + @Override + protected boolean canProduceTuples() { + // Unlike other two-input nodes, + // this node will produce tuples even if one of its inputs won't. + return leftCanProduceTuples || rightCanProduceTuples; + } + @Override public final void insertLeft(LeftTuple_ tuple) { var outTuple = getOutTupleFromLeft(tuple); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractFlattenNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractFlattenNode.java index aac78956af5..3302129f5a9 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractFlattenNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractFlattenNode.java @@ -14,13 +14,13 @@ import ai.timefold.solver.core.impl.bavet.common.tuple.TupleState; public abstract class AbstractFlattenNode - extends AbstractNode - implements TupleLifecycle { + extends AbstractSingleInputNode { private final int flattenStoreIndex; private final StaticPropagationQueue propagationQueue; protected AbstractFlattenNode(int flattenStoreIndex, TupleLifecycle nextNodesTupleLifecycle) { + super(nextNodesTupleLifecycle); this.flattenStoreIndex = flattenStoreIndex; this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractGroupNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractGroupNode.java index 3148738dbac..f0b272992af 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractGroupNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractGroupNode.java @@ -12,8 +12,7 @@ import ai.timefold.solver.core.impl.bavet.common.tuple.TupleState; public abstract class AbstractGroupNode - extends AbstractNode - implements TupleLifecycle { + extends AbstractSingleInputNode { private final int groupStoreIndex; /** @@ -56,6 +55,7 @@ protected AbstractGroupNode(int groupStoreIndex, Function g Supplier supplier, Function finisher, TupleLifecycle nextNodesTupleLifecycle, EnvironmentMode environmentMode) { + super(nextNodesTupleLifecycle); this.groupStoreIndex = groupStoreIndex; this.groupKeyFunction = groupKeyFunction; this.supplier = supplier; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractIfExistsNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractIfExistsNode.java index 498bc8c6467..7d9069c7dcd 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractIfExistsNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractIfExistsNode.java @@ -31,6 +31,7 @@ public abstract class AbstractIfExistsNode protected AbstractIfExistsNode(boolean shouldExist, TupleLifecycle nextNodesTupleLifecycle, boolean isFiltering, InTupleStorePositionTracker tupleStorePositionTracker) { + super(nextNodesTupleLifecycle); this.shouldExist = shouldExist; this.inputStoreIndexLeftTrackerList = isFiltering ? tupleStorePositionTracker.reserveNextLeft() : -1; this.inputStoreIndexRightTrackerList = isFiltering ? tupleStorePositionTracker.reserveNextRight() : -1; @@ -187,6 +188,24 @@ private void doRetractCounter(ExistsCounter counter) { } } + @Override + protected boolean canProduceTuples() { + // The left input must produce tuples no matter what, + // otherwise ifExists has nothing to join with. + if (!leftCanProduceTuples) { + return false; + } else if (shouldExist) { + // For the ifExists case, the right input must produce tuples as well, + // otherwise no left tuple can ever match. + return rightCanProduceTuples; + } else { + // For the ifNotExists case, if the right can not produce tuples, this node will. + // But even if right can produce tuples, it is not guaranteed to do so + // and therefore the node needs to stay active. + return true; + } + } + @Override public Propagator getPropagator() { return propagationQueue; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractJoinNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractJoinNode.java index 96ebe207c3e..f398fe68f4d 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractJoinNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractJoinNode.java @@ -32,6 +32,7 @@ public abstract class AbstractJoinNode nextNodesTupleLifecycle, boolean isFiltering, InOutTupleStorePositionTracker tupleStorePositionTracker) { + super(nextNodesTupleLifecycle); this.inputStoreIndexLeftOutTupleList = tupleStorePositionTracker.reserveNextLeft(); this.inputStoreIndexRightOutTupleList = tupleStorePositionTracker.reserveNextRight(); this.isFiltering = isFiltering; @@ -248,6 +249,11 @@ void retractOutTupleByRight(OutTuple_ outTuple) { propagateRetract(outTuple); } + @Override + protected boolean canProduceTuples() { + return leftCanProduceTuples && rightCanProduceTuples; + } + @Override public Propagator getPropagator() { return propagationQueue; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractMapNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractMapNode.java index 290225502c9..6fbc4ed2a12 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractMapNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractMapNode.java @@ -5,14 +5,14 @@ import ai.timefold.solver.core.impl.bavet.common.tuple.TupleState; public abstract class AbstractMapNode - extends AbstractNode - implements TupleLifecycle { + extends AbstractSingleInputNode { private final int inputStoreIndex; protected final int outputStoreSize; private final StaticPropagationQueue propagationQueue; protected AbstractMapNode(int inputStoreIndex, TupleLifecycle nextNodesTupleLifecycle, int outputStoreSize) { + super(nextNodesTupleLifecycle); this.inputStoreIndex = inputStoreIndex; this.outputStoreSize = outputStoreSize; this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNode.java index 76e776481cc..fafe6686158 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNode.java @@ -10,10 +10,15 @@ import org.jspecify.annotations.Nullable; /** + * Every node must be a child of one of the permitted extensions of this class. + * Direct extensions of this class may fail in unexpected places, + * as the specific type will be switched over. + * * @see PropagationQueue Description of the propagation mechanism. */ @NullMarked -public abstract class AbstractNode { +public abstract sealed class AbstractNode + permits AbstractRootNode, AbstractSingleInputNode, AbstractTwoInputNode { private long id; private long layerIndex = -1; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java index 97e78df8466..4543fc0d978 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractNodeBuildHelper.java @@ -1,29 +1,20 @@ package ai.timefold.solver.core.impl.bavet.common; -import static ai.timefold.solver.core.impl.bavet.common.ConstraintNodeProfileId.Qualifier; - -import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.TreeMap; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.UnaryOperator; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; -import ai.timefold.solver.core.impl.bavet.common.tuple.AggregatedTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.InOutTupleStorePositionTracker; import ai.timefold.solver.core.impl.bavet.common.tuple.LeftTupleLifecycle; -import ai.timefold.solver.core.impl.bavet.common.tuple.ProfilingTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.RightTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; -import ai.timefold.solver.core.impl.score.stream.bavet.common.Scorer; import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; @@ -34,26 +25,18 @@ public abstract class AbstractNodeBuildHelper { private final Set activeStreamSet; private final Map nodeCreatorMap; private final Map> tupleLifecycleMap; - private final Map>> streamToProfileIdSets; private final Map storeIndexMap; - @Nullable - private final InnerConstraintProfiler constraintProfiler; - @Nullable private List reversedNodeList; - private long nextLifecycleProfilingId = 0; - protected AbstractNodeBuildHelper(Set activeStreamSet, - @Nullable InnerConstraintProfiler constraintProfiler) { + protected AbstractNodeBuildHelper(Set activeStreamSet) { this.activeStreamSet = activeStreamSet; int activeStreamSetSize = activeStreamSet.size(); - this.nodeCreatorMap = new HashMap<>(Math.max(16, activeStreamSetSize)); - this.tupleLifecycleMap = new HashMap<>(Math.max(16, activeStreamSetSize)); - this.storeIndexMap = new HashMap<>(Math.max(16, activeStreamSetSize / 2)); - this.streamToProfileIdSets = new HashMap<>(Math.max(16, activeStreamSetSize / 2)); + this.nodeCreatorMap = HashMap.newHashMap(Math.max(16, activeStreamSetSize)); + this.tupleLifecycleMap = HashMap.newHashMap(Math.max(16, activeStreamSetSize)); + this.storeIndexMap = HashMap.newHashMap(Math.max(16, activeStreamSetSize / 2)); this.reversedNodeList = new ArrayList<>(activeStreamSetSize); - this.constraintProfiler = constraintProfiler; } public boolean isStreamActive(Stream_ stream) { @@ -68,10 +51,9 @@ public void addNode(AbstractNode node, Stream_ creator, @Nullable Stream_ parent reversedNodeList.add(node); node.addLocationSet(creator.getLocationSet()); nodeCreatorMap.put(node, creator); - if (!(node instanceof BavetRootNode)) { + if (!(node instanceof AbstractRootNode)) { if (parent == null) { - throw new IllegalStateException("Impossible state: The node (%s) has no parent (%s)." - .formatted(node, parent)); + throw new IllegalStateException("Impossible state: The node (%s) has no parent (%s).".formatted(node, parent)); } putInsertUpdateRetract(parent, (TupleLifecycle) node); } @@ -85,55 +67,17 @@ public void addNode(AbstractNode node, Stream_ creator, Stream_ leftParent, Stre putInsertUpdateRetract(rightParent, TupleLifecycle.ofRight((RightTupleLifecycle) node)); } - private void updateConstraintProfileIdSet(Stream_ stream, TupleLifecycle tupleLifecycle) { - if (tupleLifecycle instanceof ProfilingTupleLifecycle profilingTupleLifecycle) { - var affectedSets = streamToProfileIdSets.getOrDefault(stream, Collections.emptyList()); - for (var affectedSet : affectedSets) { - affectedSet.add(profilingTupleLifecycle.profileId()); - } - } else if (tupleLifecycle instanceof AggregatedTupleLifecycle aggregated) { - for (var innerLifecycle : aggregated.lifecycles()) { - updateConstraintProfileIdSet(stream, innerLifecycle); - } - } + protected Stream_ getNodeCreator(AbstractNode node) { + return nodeCreatorMap.get(node); } - public void putInsertUpdateRetract(Stream_ stream, - TupleLifecycle tupleLifecycle) { - if (constraintProfiler != null) { - var out = TupleLifecycle.profiling(constraintProfiler, nextLifecycleProfilingId, - stream, tupleLifecycle); - tupleLifecycleMap.put(stream, out); - updateConstraintProfileIdSet(stream, out); + @SuppressWarnings("unchecked") + protected TupleLifecycle getTupleLifecycle(Stream_ stream) { + return (TupleLifecycle) tupleLifecycleMap.get(stream); + } - if (tupleLifecycle instanceof Scorer scorer) { - // This is a scorer, so we can navigate up its parents - // to find all locations corresponding to this constraint - var queue = new ArrayDeque(); - var constraintSet = new LinkedHashSet(); - queue.add(stream); - while (!queue.isEmpty()) { - var currentStream = queue.poll(); - var streamSets = - streamToProfileIdSets.computeIfAbsent((Stream_) currentStream, ignored -> new ArrayList<>()); - streamSets.add(constraintSet); - var lifecycle = tupleLifecycleMap.get(currentStream); - if (lifecycle instanceof ProfilingTupleLifecycle profilingTupleLifecycle) { - constraintSet.add(profilingTupleLifecycle.profileId()); - } - if (currentStream instanceof BavetStreamBinaryOperation binaryOperation) { - queue.add(binaryOperation.getLeftParent()); - queue.add(binaryOperation.getRightParent()); - } else if (currentStream.getParent() != null) { - queue.add(currentStream.getParent()); - } - } - constraintProfiler.registerConstraint(scorer.getConstraintRef(), constraintSet); - } - nextLifecycleProfilingId++; - } else { - tupleLifecycleMap.put(stream, tupleLifecycle); - } + public void putInsertUpdateRetract(Stream_ stream, TupleLifecycle tupleLifecycle) { + tupleLifecycleMap.put(stream, tupleLifecycle); } public void putInsertUpdateRetract(Stream_ stream, List childStreamList, @@ -143,16 +87,12 @@ public void putInsertUpdateRetract(Stream_ stream, List TupleLifecycle - getAggregatedTupleLifecycle(List streamList) { - var tupleLifecycles = streamList.stream() - .filter(this::isStreamActive) - .map(s -> getTupleLifecycle(s, tupleLifecycleMap)) + public TupleLifecycle getAggregatedTupleLifecycle(List streamList) { + var tupleLifecycles = streamList.stream().filter(this::isStreamActive).map(s -> getTupleLifecycle(s, tupleLifecycleMap)) .toArray(TupleLifecycle[]::new); return switch (tupleLifecycles.length) { - case 0 -> - throw new IllegalStateException("Impossible state: None of the streamList (%s) are active." - .formatted(streamList)); + case 0 -> throw new IllegalStateException( + "Impossible state: None of the streamList (%s) are active.".formatted(streamList)); case 1 -> tupleLifecycles[0]; default -> TupleLifecycle.aggregate(tupleLifecycles); }; @@ -163,8 +103,7 @@ private static TupleLifecycle getTupleLi Map> tupleLifecycleMap) { var tupleLifecycle = (TupleLifecycle) tupleLifecycleMap.get(stream); if (tupleLifecycle == null) { - throw new IllegalStateException("Impossible state: the stream (%s) hasn't built a node yet." - .formatted(stream)); + throw new IllegalStateException("Impossible state: the stream (%s) hasn't built a node yet.".formatted(stream)); } return tupleLifecycle; } @@ -206,8 +145,8 @@ public Stream_ getNodeCreatingStream(AbstractNode node) { public AbstractNode findParentNode(Stream_ childNodeCreator) { if (childNodeCreator == null) { // We've recursed to the bottom without finding a parent node. - throw new IllegalStateException("Impossible state: node-creating stream (%s) has no parent node." - .formatted(childNodeCreator)); + throw new IllegalStateException( + "Impossible state: node-creating stream (%s) has no parent node.".formatted(childNodeCreator)); } // Look the stream up among node creators and if found, the node is the parent node. for (Map.Entry entry : this.nodeCreatorMap.entrySet()) { @@ -220,38 +159,6 @@ public AbstractNode findParentNode(Stream_ childNodeCreator) { return findParentNode(childNodeCreator.getParent()); } - public static NodeNetwork buildNodeNetwork(List nodeList, - Map, List>> declaredClassToNodeMap, - AbstractNodeBuildHelper nodeBuildHelper) { - var layerMap = new TreeMap>(); - var profiler = nodeBuildHelper.constraintProfiler; - for (var node : nodeList) { - var layer = node.getLayerIndex(); - var propagator = node.getPropagator(); - if (profiler != null) { - var profileKey = nodeBuildHelper.nextLifecycleProfilingId; - nodeBuildHelper.nextLifecycleProfilingId++; - var profileId = - new ConstraintNodeProfileId(profileKey, node.getStreamKind(), Qualifier.NODE, node.getLocationSet()); - nodeBuildHelper.constraintProfiler.register(profileId); - propagator = new ProfilingPropagator(profiler, profileId, propagator); - var stream = nodeBuildHelper.nodeCreatorMap.get(node); - for (var affectedSet : nodeBuildHelper.streamToProfileIdSets.get(stream)) { - affectedSet.add(profileId); - } - } - layerMap.computeIfAbsent(layer, k -> new ArrayList<>()) - .add(propagator); - } - var layerCount = layerMap.size(); - var layeredNodes = new Propagator[layerCount][]; - for (var i = 0; i < layerCount; i++) { - var layer = layerMap.get((long) i); - layeredNodes[i] = layer.toArray(new Propagator[0]); - } - return new NodeNetwork(declaredClassToNodeMap, layeredNodes, nodeBuildHelper.constraintProfiler); - } - public > List buildNodeList(Set streamSet, BuildHelper_ buildHelper, BiConsumer nodeBuilder, Consumer nodeProcessor) { /* @@ -295,20 +202,22 @@ public > List long determineLayerIndex(AbstractNode node, AbstractNodeBuildHelper buildHelper) { - if (node instanceof BavetRootNode) { // Root nodes, and only they, are in layer 0. - return 0; - } else if (node instanceof AbstractTwoInputNode joinNode) { - var nodeCreator = (BavetStreamBinaryOperation) buildHelper.getNodeCreatingStream(joinNode); - var leftParent = (Stream_) nodeCreator.getLeftParent(); - var rightParent = (Stream_) nodeCreator.getRightParent(); - var leftParentNode = buildHelper.findParentNode(leftParent); - var rightParentNode = buildHelper.findParentNode(rightParent); - return Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1; - } else { - var nodeCreator = buildHelper.getNodeCreatingStream(node); - var parentNode = buildHelper.findParentNode(nodeCreator.getParent()); - return parentNode.getLayerIndex() + 1; - } + return switch (node) { + case AbstractRootNode ignored -> 0L; // Root nodes, and only they, are in layer 0. + case AbstractTwoInputNode twoInputNode -> { // Two-input nodes must sit above both inputs. + var nodeCreator = (BavetStreamBinaryOperation) buildHelper.getNodeCreatingStream(twoInputNode); + var leftParent = (Stream_) nodeCreator.getLeftParent(); + var rightParent = (Stream_) nodeCreator.getRightParent(); + var leftParentNode = buildHelper.findParentNode(leftParent); + var rightParentNode = buildHelper.findParentNode(rightParent); + yield Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1; + } + default -> { // Every other node sits above its parent. + var nodeCreator = buildHelper.getNodeCreatingStream(node); + var parentNode = buildHelper.findParentNode(nodeCreator.getParent()); + yield parentNode.getLayerIndex() + 1; + } + }; } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java index 8f4b94b87ff..078fadde2c7 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractPrecomputeNode.java @@ -10,20 +10,33 @@ import org.jspecify.annotations.Nullable; @NullMarked -public abstract class AbstractPrecomputeNode extends AbstractNode - implements BavetRootNode { +public abstract class AbstractPrecomputeNode + extends AbstractRootNode { + + private final TupleLifecycle downstreamTupleLifecycle; private final RecordAndReplayPropagator recordAndReplayPropagator; private final Class[] sourceClasses; protected AbstractPrecomputeNode(Supplier> precomputeBuildHelperSupplier, TupleLifecycle nextNodesTupleLifecycle, Class[] sourceClasses) { + this.downstreamTupleLifecycle = nextNodesTupleLifecycle; this.recordAndReplayPropagator = new RecordAndReplayPropagator<>(precomputeBuildHelperSupplier, this::remapTuple, nextNodesTupleLifecycle); this.sourceClasses = sourceClasses; } + @Override + public void afterAllFactsInserted(boolean unused) { + downstreamTupleLifecycle.afterAllFactsInserted(recordAndReplayPropagator.canProduceTuples()); + } + + @Override + public boolean isActive() { + return recordAndReplayPropagator.canProduceTuples() && downstreamTupleLifecycle.isActive(); + } + @Override public StreamKind getStreamKind() { return StreamKind.PRECOMPUTE; @@ -50,7 +63,7 @@ public final Class[] getSourceClasses() { } @Override - public final boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { + public final boolean supports(AbstractRootNode.LifecycleOperation lifecycleOperation) { return true; } @@ -79,4 +92,5 @@ public final void retract(@Nullable Object a) { } protected abstract Tuple_ remapTuple(Tuple_ tuple); + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractRootNode.java similarity index 73% rename from core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java rename to core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractRootNode.java index 13aad93bf7d..edf56b2fe46 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/BavetRootNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractRootNode.java @@ -1,19 +1,24 @@ package ai.timefold.solver.core.impl.bavet.common; +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; + import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; @NullMarked -public interface BavetRootNode { - void insert(@Nullable A a); +public abstract non-sealed class AbstractRootNode + extends AbstractNode + implements ActivitySupport { + + public abstract void insert(@Nullable A a); - void update(@Nullable A a); + public abstract void update(@Nullable A a); - void retract(@Nullable A a); + public abstract void retract(@Nullable A a); - boolean allowsInstancesOf(Class clazz); + public abstract boolean allowsInstancesOf(Class clazz); - Class[] getSourceClasses(); + public abstract Class[] getSourceClasses(); /** * Determines if this node supports the given lifecycle operation. @@ -22,13 +27,13 @@ public interface BavetRootNode { * @param lifecycleOperation the lifecycle operation to check * @return {@code true} if the given lifecycle operation is supported; otherwise, {@code false}. */ - boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation); + public abstract boolean supports(LifecycleOperation lifecycleOperation); /** * Represents the various lifecycle operations that can be performed * on tuples within a node in Bavet. */ - enum LifecycleOperation { + public enum LifecycleOperation { /** * Represents the operation of inserting a new tuple into the node. * This operation is typically performed when a new fact is added to the working solution diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractSingleInputNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractSingleInputNode.java new file mode 100644 index 00000000000..76f714fdea9 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractSingleInputNode.java @@ -0,0 +1,29 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +public abstract non-sealed class AbstractSingleInputNode + extends AbstractNode + implements TupleLifecycle { + + private final TupleLifecycle downstreamTupleLifecycle; + + protected AbstractSingleInputNode(TupleLifecycle downstreamTupleLifecycle) { + this.downstreamTupleLifecycle = downstreamTupleLifecycle; + } + + private boolean upstreamCanProduceTuples; + + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { // We only delegate; implementations can override. + this.upstreamCanProduceTuples = upstreamCanProduceTuples; + downstreamTupleLifecycle.afterAllFactsInserted(upstreamCanProduceTuples); + } + + @Override + public boolean isActive() { + return upstreamCanProduceTuples && downstreamTupleLifecycle.isActive(); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractTwoInputNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractTwoInputNode.java index 5e3af556991..8bf3e7a6fde 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractTwoInputNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/AbstractTwoInputNode.java @@ -3,9 +3,53 @@ import ai.timefold.solver.core.impl.bavet.common.tuple.LeftTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.RightTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; -public abstract class AbstractTwoInputNode +public abstract non-sealed class AbstractTwoInputNode extends AbstractNode implements LeftTupleLifecycle, RightTupleLifecycle { + private final TupleLifecycle downstreamTupleLifecycle; + + private boolean leftInitialized; + private boolean rightInitialized; + protected boolean fullyInitialized; + protected boolean leftCanProduceTuples; + protected boolean rightCanProduceTuples; + + protected AbstractTwoInputNode(TupleLifecycle downstreamTupleLifecycle) { + this.downstreamTupleLifecycle = downstreamTupleLifecycle; + + } + + @Override + public final void afterAllFactsInsertedLeft(boolean upstreamCanProduceTuples) { + leftCanProduceTuples = upstreamCanProduceTuples; + if (!fullyInitialized && rightInitialized) { + // Only initialize downstream nodes when we have received initialization from both parents. + // Avoid initializing twice. + downstreamTupleLifecycle.afterAllFactsInserted(canProduceTuples()); + fullyInitialized = true; + } + leftInitialized = true; + } + + protected abstract boolean canProduceTuples(); + + @Override + public final void afterAllFactsInsertedRight(boolean upstreamCanProduceTuples) { + rightCanProduceTuples = upstreamCanProduceTuples; + if (!fullyInitialized && leftInitialized) { + // Only initialize downstream nodes when we have received initialization from both parents. + // Avoid initializing twice. + downstreamTupleLifecycle.afterAllFactsInserted(canProduceTuples()); + fullyInitialized = true; + } + rightInitialized = true; + } + + @Override + public final boolean isActive() { + return canProduceTuples() && downstreamTupleLifecycle.isActive(); + } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/InnerConstraintProfiler.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/InnerConstraintProfiler.java index a3f33c5592b..d2bb7847312 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/InnerConstraintProfiler.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/InnerConstraintProfiler.java @@ -1,7 +1,10 @@ package ai.timefold.solver.core.impl.bavet.common; +import java.util.List; import java.util.Set; +import java.util.function.Function; +import ai.timefold.solver.core.api.score.stream.Constraint; import ai.timefold.solver.core.api.score.stream.ConstraintRef; import org.jspecify.annotations.NullMarked; @@ -11,6 +14,10 @@ public interface InnerConstraintProfiler { void register(ConstraintNodeProfileId profileId); + void registerNodeGraph(Solution_ solution, List nodeList, + Set constraintSet, Function nodeToStreamFunction, + Function streamToParentNodeFunction); + void registerConstraint(ConstraintRef constraintRef, Set profileIdSet); void measure(ConstraintNodeProfileId profileId, Operation operation, Runnable measurable); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/ProfilingPropagator.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/ProfilingPropagator.java index 685e8638995..4c82d73804e 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/ProfilingPropagator.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/ProfilingPropagator.java @@ -2,6 +2,7 @@ public record ProfilingPropagator(InnerConstraintProfiler profiler, ConstraintNodeProfileId profileId, Propagator delegate) implements Propagator { + @Override public void propagateRetracts() { profiler.measure(profileId, diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java index 96c0b3f04dc..a4287723ff3 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/Propagator.java @@ -38,7 +38,8 @@ * * @see PropagationQueue More information about propagation. */ -public sealed interface Propagator permits ProfilingPropagator, PropagationQueue, RecordAndReplayPropagator { +public sealed interface Propagator + permits ProfilingPropagator, PropagationQueue, RecordAndReplayPropagator { /** * Starts the propagation event. Must be followed by {@link #propagateUpdates()}. diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java index 51e436b51ed..fe76abd7421 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/RecordAndReplayPropagator.java @@ -10,7 +10,7 @@ import java.util.function.Supplier; import java.util.function.UnaryOperator; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; +import ai.timefold.solver.core.impl.bavet.AbstractBavetNodeNetwork; import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; @@ -22,13 +22,14 @@ /** * The implementation records the tuples each object affects inside - * an internal {@link NodeNetwork} and replays them on update. + * an internal {@link AbstractBavetNodeNetwork} and replays them on update. * Used by {@link AbstractPrecomputeNode} to precompute constraint streams. * * @param */ @NullMarked -public final class RecordAndReplayPropagator implements Propagator { +public final class RecordAndReplayPropagator + implements Propagator { private final Set retractQueue; private final Set insertQueue; @@ -75,6 +76,15 @@ public RecordAndReplayPropagator(Supplier> pr this(precomputeBuildHelperSupplier, internalTupleToOutputTupleMapper, nextNodesTupleLifecycle, 1000); } + public boolean canProduceTuples() { + // This is correct, but not optimal. + // These conditions guarantee that deactivation will only happen when safe, + // but it will not deactivate in all cases; + // for that, the activity check would have to happen after all inserts, updates and retracts were propagated once. + return !objectToOutputTuplesMap.isEmpty() // Tuples were produced. + || !insertQueue.isEmpty(); // Tuples will be produced, unless retract removes them from the insert queue. + } + public void insert(Object object) { // do not remove a retract of the same fact (a fact was updated) insertQueue.add(object); @@ -108,7 +118,7 @@ public void propagateRetracts() { if (!retractQueue.isEmpty() || !insertQueue.isEmpty()) { var precomputeBuildHelper = precomputeBuildHelperSupplier.get(); var internalNodeNetwork = precomputeBuildHelper.getNodeNetwork(); - var objectClassToRootNodes = new HashMap, List>>(); + var objectClassToRootNodes = new HashMap, List>>(); var recordingTupleLifecycle = precomputeBuildHelper.getRecordingTupleLifecycle(); invalidateCache(); @@ -154,10 +164,10 @@ public void propagateRetracts() { } @SuppressWarnings({ "unchecked", "rawtypes" }) - private static List> getRootNodes(Object object, NodeNetwork internalNodeNetwork, - Map, List>> objectClassToRootNodes) { + private static List> getRootNodes(Object object, AbstractBavetNodeNetwork internalNodeNetwork, + Map, List>> objectClassToRootNodes) { return (List) objectClassToRootNodes.computeIfAbsent(object.getClass(), clazz -> { - var out = new ArrayList>(); + var out = new ArrayList>(); internalNodeNetwork.getRootNodesAcceptingType(object.getClass()).forEach(out::add); return out; }); @@ -206,7 +216,8 @@ private void invalidateCache() { factOutputTupleList.clear(); } - private void recalculateTuples(NodeNetwork internalNodeNetwork, Map, List>> classToRootNodeList, + private void recalculateTuples(AbstractBavetNodeNetwork internalNodeNetwork, + Map, List>> classToRootNodeList, RecordingTupleLifecycle recordingTupleLifecycle) { var internalTupleToOutputTupleMap = new IdentityHashMap(seenEntitySet.size() + seenFactSet.size()); @@ -217,7 +228,7 @@ private void recalculateTuples(NodeNetwork internalNodeNetwork, Map, Li // Do a fake update on the object and settle the network; this will update precisely the // tuples mapped to this node, which will then be recorded classToRootNodeList.get(invalidated.getClass()) - .forEach(node -> ((BavetRootNode) node).update(invalidated)); + .forEach(node -> ((AbstractRootNode) node).update(invalidated)); internalNodeNetwork.settle(); } if (mappedTuples.isEmpty()) { @@ -242,7 +253,7 @@ private void recalculateTuples(NodeNetwork internalNodeNetwork, Map, Li internalTupleToOutputTupleMapper, internalTupleToOutputTupleMap))) { for (var fact : seenFactSet) { classToRootNodeList.get(fact.getClass()) - .forEach(node -> ((BavetRootNode) node).update(fact)); + .forEach(node -> ((AbstractRootNode) node).update(fact)); } internalNodeNetwork.settle(); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ActivitySupport.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ActivitySupport.java new file mode 100644 index 00000000000..4271acc9363 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ActivitySupport.java @@ -0,0 +1,108 @@ +package ai.timefold.solver.core.impl.bavet.common.tuple; + +import ai.timefold.solver.core.api.solver.change.ProblemChange; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; +import ai.timefold.solver.core.impl.bavet.common.Propagator; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintSession; +import ai.timefold.solver.core.impl.score.stream.bavet.common.Scorer; + +/** + * When a Bavet session is created, some nodes can become inactive. + * An inactive node cannot impact the score, as it cannot produce tuples. + * Yet processing inserts, updates and retracts in inactive nodes still consumes CPU cycles. + * This interface establishes a protocol through which nodes can signal at runtime + * that they are inactive, so that the session can ignore them. + *

+ * The pattern of interaction is as follows: + *

    + *
  • The session will first call {@link #afterAllFactsInserted(boolean)} on every {@link AbstractRootNode}. + * Each lifecycle must propagate the call further downstream, until a {@link Scorer} is reached. + * (In case of Constraint Streams.)
  • + *
  • Then the session will call {@link #isActive()} on every single lifecycle, + * and entirely remove deactivated lifecycles from propagation. + * Each lifecycle makes its activity decision independently, + * but may ask its downstream lifecycles.
  • + *
  • Once these two steps have been executed, + * no method of this interface will ever be called again + * for the duration of the session.
  • + *
+ * + * @see TupleLifecycle + * @see LeftTupleLifecycle + * @see RightTupleLifecycle + */ +public interface ActivitySupport { + + /** + * Triggered after all facts which will ever be inserted have been inserted to the session; + * it doesn't guarantee they were propagated as far as this lifecycle, + * but the session now carries all the facts it will ever carry. + * (The only way to insert or retract a fact is through a {@link ProblemChange}, + * and that will nuke the score director.) + *

+ * It is the responsibility of the lifecycle to trigger initialization + * of all of its downstream lifecycles, should there be any. + * It must first decide for itself if it can produce tuples based on what it learned from upstream, + * and then propagate that information downstream so that they can make their own activation decisions. + *

+ * When deciding whether a lifecycle can produce tuples, consider the following: + *

    + *
  • + * Typically, when upstream cannot produce tuples, neither can downstream. + * (Unless downstream has the capability to fabricate tuples out of nowhere.) + *
  • + *
  • + * Do not make decisions based on whether upstream actually produced any tuples by this point. + * If upstream produced no tuples so far, it doesn't mean it will never produce any. + * Filters on variables which previously did not match can easily create tuples during tuple updates. + *
  • + *
+ * + * @param upstreamCanProduceTuples True if the upstream lifecycle(s) will produce any tuples. + * If false, this lifecycle will never receive any tuples and can most likely deactivate itself. + */ + void afterAllFactsInserted(boolean upstreamCanProduceTuples); + + /** + * Lifecycles which are considered inactive will never be called by their {@link Propagator}. + * A lifecycle is considered inactive if: + * + *
    + *
  • It will not produce a tuple. + * (Such as forEach(MyClass), when no MyClass instances were inserted.)
  • + *
  • Its downstream tuples are inactive. + * (A forEach() itself may be able to produce a tuple, + * but a join() downstream cannot, because its other side cannot produce tuples. + * In this case, the left side must be deactivated as well, + * unless it also outputs to another active downstream lifecycle.) + *
  • + *
+ * + * A decision on whether a lifecycle is active can only be made when the following information is available: + * + *
    + *
  • Upstream has informed us whether they can produce tuples. + * This happens through {@link #afterAllFactsInserted(boolean) initialization}
  • + *
  • Downstream has completed its initialization. + * It is a lifecycle's responsibility to trigger downstream initialization.
  • + *
+ * + * This is a one-time decision; + * for each lifecycle, this method will only be called at the start of {@link BavetConstraintSession}, + * and never again. + * It may be called multiple times by different upstream lifecycles, + * and must always return the same value. + * + *

+ * Typically this decision will be {@code upstreamCanProduceTuples && downstreamIsActive}, + * but some specialized lifecycles may have to use different logic. + * + * @return true if this lifecycle can produce tuples + */ + default boolean isActive() { + throw new IllegalStateException( + "Impossible state: lifecycle (%s) not yet initialized (afterAllFactsInserted not called)." + .formatted(this)); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/AggregatedTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/AggregatedTupleLifecycle.java index ef58d7a0094..ba8374c0600 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/AggregatedTupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/AggregatedTupleLifecycle.java @@ -1,51 +1,90 @@ package ai.timefold.solver.core.impl.bavet.common.tuple; import java.util.Arrays; -import java.util.Objects; -public record AggregatedTupleLifecycle(TupleLifecycle... lifecycles) - implements - TupleLifecycle { +public final class AggregatedTupleLifecycle + implements TupleLifecycle { + + // Iterating a list in update() was measurably slower in micro benchmarks, so we deal with arrays. + private TupleLifecycle[] downstream; + private boolean upstreamCanProduceTuples; + private boolean downstreamFinal; @SafeVarargs - public AggregatedTupleLifecycle { - // Exists so that we have something to put the @SafeVarargs annotation on. + public AggregatedTupleLifecycle(TupleLifecycle... downstream) { + this.downstream = downstream; + } + + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + for (var lifecycle : downstream) { // First initialize all downstream lifecycles. + lifecycle.afterAllFactsInserted(upstreamCanProduceTuples); + } + this.upstreamCanProduceTuples = upstreamCanProduceTuples; + } + + @SuppressWarnings("unchecked") + @Override + public boolean isActive() { + if (downstreamFinal) { + return downstream.length > 0; + } + if (upstreamCanProduceTuples) { + downstream = Arrays.stream(downstream) + .distinct() + .filter(TupleLifecycle::isActive) + .toArray(TupleLifecycle[]::new); + } else { + // No upstream facts, so downstream lifecycles will never be active. + downstream = new TupleLifecycle[0]; + } + downstreamFinal = true; + return downstream.length > 0; } @Override public void insert(Tuple_ tuple) { - for (var lifecycle : lifecycles) { + for (var lifecycle : downstream) { lifecycle.insert(tuple); } } @Override public void update(Tuple_ tuple) { - for (var lifecycle : lifecycles) { + for (var lifecycle : downstream) { lifecycle.update(tuple); } } @Override public void retract(Tuple_ tuple) { - for (var lifecycle : lifecycles) { + for (var lifecycle : downstream) { lifecycle.retract(tuple); } } + /** + * Users must not modify this array. (Defensive copy avoided for performance reasons.) + * + * @return active downstream lifecycles + */ + public TupleLifecycle[] downstream() { + return downstream; + } + @Override public boolean equals(Object o) { return o instanceof AggregatedTupleLifecycle that && - Objects.deepEquals(lifecycles, that.lifecycles); + Arrays.deepEquals(downstream, that.downstream); } @Override public int hashCode() { - return Arrays.hashCode(lifecycles); + return Arrays.deepHashCode(downstream); } @Override public String toString() { - return "size = " + lifecycles.length; + return "size = " + downstream.length; } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ConditionalTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ConditionalTupleLifecycle.java index 87c79f6f14d..59ea3e2133f 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ConditionalTupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ConditionalTupleLifecycle.java @@ -3,14 +3,32 @@ import java.util.Objects; import java.util.function.Predicate; -public record ConditionalTupleLifecycle(TupleLifecycle downstreamLifecycle, - TuplePredicate predicate) - implements - TupleLifecycle { +import org.jspecify.annotations.NullMarked; - public ConditionalTupleLifecycle { - Objects.requireNonNull(downstreamLifecycle); - Objects.requireNonNull(predicate); +@NullMarked +public final class ConditionalTupleLifecycle + implements TupleLifecycle { + + private final TupleLifecycle downstreamLifecycle; + private final TuplePredicate predicate; + private boolean isActive; + + public ConditionalTupleLifecycle(TupleLifecycle downstreamLifecycle, TuplePredicate predicate) { + this.downstreamLifecycle = Objects.requireNonNull(downstreamLifecycle); + this.predicate = Objects.requireNonNull(predicate); + } + + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + // It is possible the predicate will always filter everything out, but we cannot know that for certain. + // We must pass the upstream information downstream, and be active if upstream can send anything to us. + this.isActive = upstreamCanProduceTuples; + downstreamLifecycle.afterAllFactsInserted(upstreamCanProduceTuples); + } + + @Override + public boolean isActive() { + return isActive && downstreamLifecycle.isActive(); } @Override @@ -34,13 +52,29 @@ public void retract(Tuple_ tuple) { downstreamLifecycle.retract(tuple); } + public TuplePredicate predicate() { + return predicate; + } + @Override public String toString() { return "Conditional %s".formatted(downstreamLifecycle); } + @Override + public boolean equals(Object obj) { + return obj instanceof ConditionalTupleLifecycle other + && Objects.equals(this.downstreamLifecycle, other.downstreamLifecycle) + && Objects.equals(this.predicate, other.predicate); + } + + @Override + public int hashCode() { + return Objects.hash(downstreamLifecycle, predicate); + } + @FunctionalInterface - interface TuplePredicate extends Predicate { + public interface TuplePredicate extends Predicate { } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftBridgeTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftBridgeTupleLifecycle.java new file mode 100644 index 00000000000..008fee39270 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftBridgeTupleLifecycle.java @@ -0,0 +1,64 @@ +package ai.timefold.solver.core.impl.bavet.common.tuple; + +import java.util.Objects; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +final class LeftBridgeTupleLifecycle + implements TupleLifecycle { + + private final LeftTupleLifecycle leftTupleLifecycle; + private boolean isActive; + + LeftBridgeTupleLifecycle(LeftTupleLifecycle leftTupleLifecycle) { + this.leftTupleLifecycle = Objects.requireNonNull(leftTupleLifecycle); + } + + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + this.isActive = upstreamCanProduceTuples; // We're just delegating. + leftTupleLifecycle.afterAllFactsInsertedLeft(upstreamCanProduceTuples); + } + + @Override + public boolean isActive() { + return isActive && leftTupleLifecycle.isActive(); + } + + @Override + public void insert(Tuple_ tuple) { + leftTupleLifecycle.insertLeft(tuple); + } + + @Override + public void update(Tuple_ tuple) { + leftTupleLifecycle.updateLeft(tuple); + } + + @Override + public void retract(Tuple_ tuple) { + leftTupleLifecycle.retractLeft(tuple); + } + + public LeftTupleLifecycle leftTupleLifecycle() { + return leftTupleLifecycle; + } + + @Override + public String toString() { + return "left %s".formatted(leftTupleLifecycle); + } + + @Override + public boolean equals(Object o) { + return o instanceof LeftBridgeTupleLifecycle other + && Objects.equals(this.leftTupleLifecycle, other.leftTupleLifecycle); + } + + @Override + public int hashCode() { + return Objects.hash(leftTupleLifecycle); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycle.java index 33ca38918b9..dc99affd672 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycle.java @@ -2,6 +2,16 @@ public interface LeftTupleLifecycle { + /** + * As defined by {@link ActivitySupport#afterAllFactsInserted}. + */ + void afterAllFactsInsertedLeft(boolean upstreamCanProduceTuples); + + /** + * As defined by {@link ActivitySupport#isActive()}. + */ + boolean isActive(); + void insertLeft(Tuple_ tuple); void updateLeft(Tuple_ tuple); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycleImpl.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycleImpl.java deleted file mode 100644 index efe6db29fe5..00000000000 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/LeftTupleLifecycleImpl.java +++ /dev/null @@ -1,33 +0,0 @@ -package ai.timefold.solver.core.impl.bavet.common.tuple; - -import java.util.Objects; - -record LeftTupleLifecycleImpl(LeftTupleLifecycle leftTupleLifecycle) - implements - TupleLifecycle { - - LeftTupleLifecycleImpl { - Objects.requireNonNull(leftTupleLifecycle); - } - - @Override - public void insert(Tuple_ tuple) { - leftTupleLifecycle.insertLeft(tuple); - } - - @Override - public void update(Tuple_ tuple) { - leftTupleLifecycle.updateLeft(tuple); - } - - @Override - public void retract(Tuple_ tuple) { - leftTupleLifecycle.retractLeft(tuple); - } - - @Override - public String toString() { - return "left " + leftTupleLifecycle; - } - -} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ProfilingTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ProfilingTupleLifecycle.java index 78951e99fbf..e8597bbbd2b 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ProfilingTupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/ProfilingTupleLifecycle.java @@ -1,14 +1,36 @@ package ai.timefold.solver.core.impl.bavet.common.tuple; +import java.util.Objects; + import ai.timefold.solver.core.impl.bavet.common.ConstraintNodeProfileId; import ai.timefold.solver.core.impl.bavet.common.InnerConstraintProfiler; -public record ProfilingTupleLifecycle( - InnerConstraintProfiler constraintProfiler, - ConstraintNodeProfileId profileId, - TupleLifecycle delegate) implements TupleLifecycle { - public ProfilingTupleLifecycle { +import org.jspecify.annotations.NullMarked; + +@NullMarked +public final class ProfilingTupleLifecycle + implements TupleLifecycle { + + private final InnerConstraintProfiler constraintProfiler; + private final ConstraintNodeProfileId profileId; + private final TupleLifecycle delegate; + + public ProfilingTupleLifecycle(InnerConstraintProfiler constraintProfiler, ConstraintNodeProfileId profileId, + TupleLifecycle delegate) { constraintProfiler.register(profileId); + this.constraintProfiler = constraintProfiler; + this.profileId = profileId; + this.delegate = delegate; + } + + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + this.delegate.afterAllFactsInserted(upstreamCanProduceTuples); + } + + @Override + public boolean isActive() { + return delegate.isActive(); } @Override @@ -28,4 +50,35 @@ public void retract(Tuple_ tuple) { constraintProfiler.measure(profileId, InnerConstraintProfiler.Operation.RETRACT, () -> delegate.retract(tuple)); } + + public ConstraintNodeProfileId profileId() { + return profileId; + } + + public TupleLifecycle delegate() { + return delegate; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + return obj instanceof ProfilingTupleLifecycle other + && Objects.equals(constraintProfiler, other.constraintProfiler) + && Objects.equals(profileId, other.profileId) + && Objects.equals(delegate, other.delegate); + } + + @Override + public int hashCode() { + return Objects.hash(constraintProfiler, profileId, delegate); + } + + @Override + public String toString() { + return "ProfilingTupleLifecycle[%s, %s, %s]" + .formatted(constraintProfiler, profileId, delegate); + } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java index 23c9b34e65e..0ab2d33812a 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RecordingTupleLifecycle.java @@ -6,7 +6,9 @@ import org.jspecify.annotations.Nullable; @NullMarked -public class RecordingTupleLifecycle implements TupleLifecycle, AutoCloseable { +public class RecordingTupleLifecycle + implements TupleLifecycle, AutoCloseable { + @Nullable TupleRecorder tupleRecorder; @@ -20,12 +22,20 @@ public void close() { this.tupleRecorder = null; } + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + // Nothing to propagate to. + } + + @Override + public boolean isActive() { + return true; // Always active. + } + @Override public void insert(Tuple_ tuple) { if (tupleRecorder != null) { - throw new IllegalStateException(""" - Impossible state: tuple %s was inserted during recording. - """.formatted(tuple)); + throw new IllegalStateException("Impossible state: tuple %s was inserted during recording.".formatted(tuple)); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightBridgeTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightBridgeTupleLifecycle.java new file mode 100644 index 00000000000..fd2e18f4f57 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightBridgeTupleLifecycle.java @@ -0,0 +1,64 @@ +package ai.timefold.solver.core.impl.bavet.common.tuple; + +import java.util.Objects; + +import org.jspecify.annotations.NullMarked; + +@NullMarked +final class RightBridgeTupleLifecycle + implements TupleLifecycle { + + private final RightTupleLifecycle rightTupleLifecycle; + private boolean isActive; + + RightBridgeTupleLifecycle(RightTupleLifecycle rightTupleLifecycle) { + this.rightTupleLifecycle = Objects.requireNonNull(rightTupleLifecycle); + } + + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + this.isActive = upstreamCanProduceTuples; // We're just delegating. + rightTupleLifecycle.afterAllFactsInsertedRight(upstreamCanProduceTuples); + } + + @Override + public boolean isActive() { + return isActive && rightTupleLifecycle.isActive(); + } + + @Override + public void insert(Tuple_ tuple) { + rightTupleLifecycle.insertRight(tuple); + } + + @Override + public void update(Tuple_ tuple) { + rightTupleLifecycle.updateRight(tuple); + } + + @Override + public void retract(Tuple_ tuple) { + rightTupleLifecycle.retractRight(tuple); + } + + public RightTupleLifecycle rightTupleLifecycle() { + return rightTupleLifecycle; + } + + @Override + public String toString() { + return "right %s".formatted(rightTupleLifecycle); + } + + @Override + public boolean equals(Object o) { + return o instanceof RightBridgeTupleLifecycle other + && rightTupleLifecycle.equals(other.rightTupleLifecycle); + } + + @Override + public int hashCode() { + return Objects.hash(rightTupleLifecycle); + } + +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycle.java index 3467a9ace77..ea9a79dcfad 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycle.java @@ -2,6 +2,16 @@ public interface RightTupleLifecycle { + /** + * As defined by {@link ActivitySupport#afterAllFactsInserted}. + */ + void afterAllFactsInsertedRight(boolean upstreamCanProduceTuples); + + /** + * As defined by {@link ActivitySupport#isActive()}. + */ + boolean isActive(); + void insertRight(Tuple_ tuple); void updateRight(Tuple_ tuple); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycleImpl.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycleImpl.java deleted file mode 100644 index b207db9218f..00000000000 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/RightTupleLifecycleImpl.java +++ /dev/null @@ -1,33 +0,0 @@ -package ai.timefold.solver.core.impl.bavet.common.tuple; - -import java.util.Objects; - -record RightTupleLifecycleImpl(RightTupleLifecycle rightTupleLifecycle) - implements - TupleLifecycle { - - RightTupleLifecycleImpl { - Objects.requireNonNull(rightTupleLifecycle); - } - - @Override - public void insert(Tuple_ tuple) { - rightTupleLifecycle.insertRight(tuple); - } - - @Override - public void update(Tuple_ tuple) { - rightTupleLifecycle.updateRight(tuple); - } - - @Override - public void retract(Tuple_ tuple) { - rightTupleLifecycle.retractRight(tuple); - } - - @Override - public String toString() { - return "right " + rightTupleLifecycle; - } - -} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java index f08526a7e46..00299d4a1b7 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/common/tuple/TupleLifecycle.java @@ -16,14 +16,15 @@ import ai.timefold.solver.core.impl.bavet.common.StreamKind; import ai.timefold.solver.core.impl.score.stream.bavet.common.Scorer; -public interface TupleLifecycle { +public interface TupleLifecycle + extends ActivitySupport { static TupleLifecycle ofLeft(LeftTupleLifecycle leftTupleLifecycle) { - return new LeftTupleLifecycleImpl<>(leftTupleLifecycle); + return new LeftBridgeTupleLifecycle<>(leftTupleLifecycle); } static TupleLifecycle ofRight(RightTupleLifecycle rightTupleLifecycle) { - return new RightTupleLifecycleImpl<>(rightTupleLifecycle); + return new RightBridgeTupleLifecycle<>(rightTupleLifecycle); } @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -79,12 +80,12 @@ static TupleLifecycle(LeftTupleLifecycle lifecycle) - && lifecycle instanceof AbstractNode node) { + } else if (delegate instanceof LeftBridgeTupleLifecycle parent + && parent.leftTupleLifecycle() instanceof AbstractNode node) { streamKind = node.getStreamKind(); qualifier = Qualifier.LEFT_INPUT; - } else if (delegate instanceof RightTupleLifecycleImpl(RightTupleLifecycle tupleLifecycle) - && tupleLifecycle instanceof AbstractNode node) { + } else if (delegate instanceof RightBridgeTupleLifecycle parent + && parent.rightTupleLifecycle() instanceof AbstractNode node) { streamKind = node.getStreamKind(); qualifier = Qualifier.RIGHT_INPUT; } else if (delegate instanceof RecordingTupleLifecycle) { diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java index bf14e3112ad..a2843c24ea9 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/AbstractForEachUniNode.java @@ -4,8 +4,7 @@ import java.util.Map; import ai.timefold.solver.core.api.domain.solution.PlanningSolution; -import ai.timefold.solver.core.impl.bavet.common.AbstractNode; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.Propagator; import ai.timefold.solver.core.impl.bavet.common.StaticPropagationQueue; import ai.timefold.solver.core.impl.bavet.common.StreamKind; @@ -26,8 +25,7 @@ */ @NullMarked public abstract sealed class AbstractForEachUniNode - extends AbstractNode - implements BavetRootNode + extends AbstractRootNode permits ForEachFilteredUniNode, ForEachUnfilteredUniNode { private final Class forEachClass; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java index 9f7b7fdc3c1..30515598fce 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachFilteredUniNode.java @@ -3,7 +3,7 @@ import java.util.Objects; import java.util.function.Predicate; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; @@ -14,16 +14,34 @@ public final class ForEachFilteredUniNode extends AbstractForEachUniNode { + private final TupleLifecycle> nextNodesTupleLifecycle; private final Predicate filter; + private int tupleCountWithoutFiltering = 0; public ForEachFilteredUniNode(Class forEachClass, Predicate filter, TupleLifecycle> nextNodesTupleLifecycle, int outputStoreSize) { super(forEachClass, nextNodesTupleLifecycle, outputStoreSize); + this.nextNodesTupleLifecycle = Objects.requireNonNull(nextNodesTupleLifecycle); this.filter = Objects.requireNonNull(filter); } + @Override + public void afterAllFactsInserted(boolean unused) { + nextNodesTupleLifecycle.afterAllFactsInserted(tupleCountWithoutFiltering > 0); + } + + @Override + public boolean isActive() { + // The input may change during update, + // and therefore the filter may let things propagate which it previously did not. + // For this reason, this node must be considered active if it saw at least one input; + // only with zero tuples can it be considered inactive, as the filter has nothing to propagate. + return tupleCountWithoutFiltering > 0 && nextNodesTupleLifecycle.isActive(); + } + @Override public void insert(@Nullable A a) { + tupleCountWithoutFiltering++; // This is safe; each element is only inserted once, guaranteed by a fail-fast in the parent. if (!filter.test(a)) { // Skip inserting the tuple as it does not pass the filter. return; } @@ -44,6 +62,7 @@ public void update(@Nullable A a) { @Override public void retract(@Nullable A a) { + tupleCountWithoutFiltering--; var tuple = tupleMap.remove(a); if (tuple == null) { // The tuple was never inserted because it did not pass the filter. return; @@ -52,7 +71,7 @@ public void retract(@Nullable A a) { } @Override - public boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { + public boolean supports(AbstractRootNode.LifecycleOperation lifecycleOperation) { return true; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java index 88ffc03821e..fb40cb58be0 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUnfilteredUniNode.java @@ -1,6 +1,6 @@ package ai.timefold.solver.core.impl.bavet.uni; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; @@ -11,9 +11,24 @@ public final class ForEachUnfilteredUniNode extends AbstractForEachUniNode { + private final TupleLifecycle> nextNodesTupleLifecycle; + private boolean isActive; + public ForEachUnfilteredUniNode(Class forEachClass, TupleLifecycle> nextNodesTupleLifecycle, int outputStoreSize) { super(forEachClass, nextNodesTupleLifecycle, outputStoreSize); + this.nextNodesTupleLifecycle = nextNodesTupleLifecycle; + } + + @Override + public void afterAllFactsInserted(boolean unused) { + isActive = !tupleMap.isEmpty(); + nextNodesTupleLifecycle.afterAllFactsInserted(isActive); + } + + @Override + public boolean isActive() { + return isActive && nextNodesTupleLifecycle.isActive(); } @Override @@ -27,7 +42,7 @@ public void update(@Nullable A a) { } @Override - public boolean supports(BavetRootNode.LifecycleOperation lifecycleOperation) { + public boolean supports(AbstractRootNode.LifecycleOperation lifecycleOperation) { return true; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java index 1d0fc4cfc66..92dfee56180 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/bavet/uni/PrecomputeUniNode.java @@ -25,4 +25,5 @@ public PrecomputeUniNode(Supplier>> preco protected UniTuple remapTuple(UniTuple tuple) { return UniTuple.of(tuple.getA(), outputStoreSize); } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/GraphEdge.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/GraphEdge.java deleted file mode 100644 index 65ad50153d9..00000000000 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/GraphEdge.java +++ /dev/null @@ -1,20 +0,0 @@ -package ai.timefold.solver.core.impl.bavet.visual; - -import java.util.Objects; - -import ai.timefold.solver.core.impl.bavet.common.AbstractNode; - -record GraphEdge(AbstractNode from, AbstractNode to) { - - @Override - public boolean equals(Object o) { - if (!(o instanceof GraphEdge graphEdge)) - return false; - return Objects.equals(to.getId(), graphEdge.to.getId()) && Objects.equals(from.getId(), graphEdge.from.getId()); - } - - @Override - public int hashCode() { - return Objects.hash(from.getId(), to.getId()); - } -} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/GraphSink.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/GraphSink.java deleted file mode 100644 index fe79ef7a863..00000000000 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/GraphSink.java +++ /dev/null @@ -1,7 +0,0 @@ -package ai.timefold.solver.core.impl.bavet.visual; - -import ai.timefold.solver.core.impl.bavet.common.AbstractNode; -import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraint; - -record GraphSink(AbstractNode node, BavetConstraint constraint) { -} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/NodeGraph.java b/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/NodeGraph.java deleted file mode 100644 index 99691257457..00000000000 --- a/core/src/main/java/ai/timefold/solver/core/impl/bavet/visual/NodeGraph.java +++ /dev/null @@ -1,181 +0,0 @@ -package ai.timefold.solver.core.impl.bavet.visual; - -import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.TreeMap; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import ai.timefold.solver.core.api.score.stream.Constraint; -import ai.timefold.solver.core.impl.bavet.common.AbstractNode; -import ai.timefold.solver.core.impl.bavet.common.AbstractTwoInputNode; -import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; -import ai.timefold.solver.core.impl.bavet.common.BavetStream; -import ai.timefold.solver.core.impl.bavet.common.BavetStreamBinaryOperation; -import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; -import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraint; -import ai.timefold.solver.core.impl.score.stream.bavet.uni.BavetForEachUniConstraintStream; - -public record NodeGraph(Solution_ solution, List sources, - List edges, List> sinks) { - - @SuppressWarnings({ "unchecked", "rawtypes" }) - public static NodeGraph of(Solution_ solution, - List nodeList, Set constraintSet, Function nodeToStreamFunction, - Function streamToParentNodeFunction) { - var sourceList = new ArrayList(); - var edgeList = new ArrayList(); - for (var node : nodeList) { - var nodeCreator = nodeToStreamFunction.apply(node); - if (nodeCreator instanceof BavetForEachUniConstraintStream) { - sourceList.add(node); - } else if (nodeCreator instanceof BavetStreamBinaryOperation binaryOperation) { - var castBinaryOperation = (BavetStreamBinaryOperation) binaryOperation; - var leftParent = streamToParentNodeFunction.apply(castBinaryOperation.getLeftParent()); - edgeList.add(new GraphEdge(leftParent, node)); - var rightParent = streamToParentNodeFunction.apply(castBinaryOperation.getRightParent()); - edgeList.add(new GraphEdge(rightParent, node)); - } else { - var parent = streamToParentNodeFunction.apply(nodeCreator.getParent()); - edgeList.add(new GraphEdge(parent, node)); - } - } - var sinkList = new ArrayList>(); - for (var constraint : constraintSet) { - var castConstraint = (BavetConstraint) constraint; - var stream = (BavetAbstractConstraintStream) castConstraint.getScoringConstraintStream(); - var node = streamToParentNodeFunction.apply((Stream_) stream); - sinkList.add(new GraphSink<>(node, castConstraint)); - } - return new NodeGraph<>(solution, sourceList.stream().distinct().toList(), - edgeList.stream().distinct().toList(), - sinkList.stream().distinct().toList()); - } - - public String buildGraphvizDOT() { - var stringBuilder = new StringBuilder(); - var sourceStream = sources.stream(); - var edgeStream = edges.stream().flatMap(edge -> Stream.of(edge.from(), edge.to())); - // Gather all known nodes and order them by their ID. - var allNodes = Stream.concat(sourceStream, edgeStream) - .distinct() - .sorted(Comparator.comparingLong(AbstractNode::getId)) - .toList(); - stringBuilder.append( - " label=<Bavet Node Network for '%s'
%d constraints, %d nodes>;%n" - .formatted(solution.toString(), sinks.size(), allNodes.size())); - // Specify the edges. - for (var node : allNodes) { - for (var edge : edges) { - if (edge.from().equals(node)) { - var line = " %s -> %s;%n".formatted(nodeId(node), nodeId(edge.to())); - stringBuilder.append(line); - } - } - } - for (var i = 0; i < sinks.size(); i++) { - var sink = sinks.get(i); - var line = " %s -> %s;%n".formatted(nodeId(sink.node()), constraintId(i)); - stringBuilder.append(line); - } - // Specify visual attributes of the nodes. - for (var node : allNodes) { - var line = " %s %s;%n".formatted(nodeId(node), getMetadata(node)); - stringBuilder.append(line); - } - for (var i = 0; i < sinks.size(); i++) { - var sink = sinks.get(i); - var line = " %s %s;%n".formatted(constraintId(i), getMetadata(sink, solution)); - stringBuilder.append(line); - } - // Put nodes in the same layer to appear in the same rank. - var layerMap = new TreeMap>(); - for (var node : allNodes) { - var layer = node.getLayerIndex(); - layerMap.computeIfAbsent(layer, k -> new LinkedHashSet<>()).add(node); - } - for (var entry : layerMap.entrySet()) { - var line = entry.getValue().stream() - .map(NodeGraph::nodeId) - .collect(Collectors.joining("; ", " { rank=same; ", "; }" + System.lineSeparator())); - stringBuilder.append(line); - } - return """ - digraph { - rankdir=LR; - %s}""" - .formatted(stringBuilder.toString()); - } - - private static String getMetadata(AbstractNode node) { - var metadata = getBaseDOTProperties("lightgrey", false); - if (node instanceof AbstractForEachUniNode) { - metadata.put("style", "filled"); - metadata.put("fillcolor", "#3e00ff"); - metadata.put("fontcolor", "white"); - } else if (node instanceof AbstractTwoInputNode) { - // Nodes that join get a different color. - metadata.put("style", "filled"); - metadata.put("fillcolor", "#ff7700"); - metadata.put("fontcolor", "white"); - } - metadata.put("label", nodeLabel(node)); - return mergeMetadata(metadata); - } - - private static String mergeMetadata(Map metadata) { - return metadata.entrySet().stream() - .map(entry -> { - if (entry.getKey().equals("label")) { // Labels are HTML-formatted. - return "%s=<%s>".formatted(entry.getKey(), entry.getValue()); - } else { - return "%s=\"%s\"".formatted(entry.getKey(), entry.getValue()); - } - }) - .collect(Collectors.joining(", ", "[", "]")); - } - - private static String getMetadata(GraphSink sink, Solution_ solution) { - var constraint = sink.constraint(); - var metadata = getBaseDOTProperties("#3423a6", true); - metadata.put("label", "%s
(Weight: %s)" - .formatted(constraint.getConstraintRef().id(), constraint.extractConstraintWeight(solution))); - return mergeMetadata(metadata); - } - - private static Map getBaseDOTProperties(String fillcolor, boolean whiteText) { - var metadata = new HashMap(); - metadata.put("shape", "plaintext"); - metadata.put("pad", "0.2"); - metadata.put("style", "filled"); - metadata.put("fillcolor", fillcolor); - metadata.put("fontname", "Courier New"); - metadata.put("fontcolor", whiteText ? "white" : "black"); - return metadata; - } - - private static String nodeId(AbstractNode node) { - return "node" + node.getId(); - } - - private static String constraintId(int id) { - return "impact" + id; - } - - private static String nodeLabel(AbstractNode node) { - var className = node.getClass().getSimpleName() - .replace("Node", ""); - if (node instanceof AbstractForEachUniNode forEachNode) { - return "%s
(%s)".formatted(className, forEachNode.getForEachClass().getSimpleName()); - } else { - return "%s".formatted(className); - } - } - -} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/NeighborhoodsBavetNodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/NeighborhoodsBavetNodeNetwork.java new file mode 100644 index 00000000000..6424b55d97a --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/NeighborhoodsBavetNodeNetwork.java @@ -0,0 +1,46 @@ +package ai.timefold.solver.core.impl.neighborhood; + +import java.util.List; +import java.util.Map; + +import ai.timefold.solver.core.impl.bavet.AbstractBavetNodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; + +import org.jspecify.annotations.NullMarked; + +/** + * Represents Neighborhoods' network of nodes, specific to a particular session. + */ +@NullMarked +public final class NeighborhoodsBavetNodeNetwork extends AbstractBavetNodeNetwork { + + public static NeighborhoodsBavetNodeNetwork of(List nodeList, + Map, List>> declaredClassToNodeMap) { + var layeredNodes = AbstractBavetNodeNetwork.buildLayeredNodes(nodeList); + return new NeighborhoodsBavetNodeNetwork(declaredClassToNodeMap, layeredNodes); + } + + /** + * @param declaredClassToNodeMap starting nodes, one for each class used in the constraints; + * root nodes, layer index 0. + * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; + * propagation needs to happen in this order. + */ + private NeighborhoodsBavetNodeNetwork(Map, List>> declaredClassToNodeMap, + AbstractNode[][] layeredNodes) { + super(declaredClassToNodeMap, layeredNodes, AbstractNode::getPropagator); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof NeighborhoodsBavetNodeNetwork)) + return false; + return super.equals(o); + } + + @Override + public int hashCode() { + return super.hashCode(); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSession.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSession.java index 104f3206a1a..8b4ad04971b 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSession.java @@ -5,8 +5,8 @@ import java.util.Objects; import ai.timefold.solver.core.impl.bavet.AbstractSession; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.neighborhood.NeighborhoodsBavetNodeNetwork; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractDataset; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractDatasetInstance; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractLeftDataset; @@ -17,14 +17,15 @@ import org.jspecify.annotations.NullMarked; @NullMarked -public final class DatasetSession extends AbstractSession { +public final class DatasetSession + extends AbstractSession { private final Map, AbstractDatasetInstance> leftDatasetInstanceMap = new IdentityHashMap<>(); private final Map, AbstractDatasetInstance> rightDatasetInstanceMap = new IdentityHashMap<>(); - DatasetSession(NodeNetwork nodeNetwork) { + DatasetSession(NeighborhoodsBavetNodeNetwork nodeNetwork) { super(nodeNetwork); } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java index c8aae79d4b2..a0d22ee9880 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/DatasetSessionFactory.java @@ -5,18 +5,15 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Set; -import java.util.function.Consumer; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; -import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; +import ai.timefold.solver.core.impl.neighborhood.NeighborhoodsBavetNodeNetwork; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.AbstractEnumeratingStream; import ai.timefold.solver.core.impl.neighborhood.stream.enumerating.common.DataNodeBuildHelper; import ai.timefold.solver.core.impl.score.director.SessionContext; import org.jspecify.annotations.NullMarked; -import org.jspecify.annotations.Nullable; @NullMarked public final class DatasetSessionFactory { @@ -34,16 +31,16 @@ public DatasetSession buildSession(SessionContext context) dataset.collectActiveEnumeratingStreams(activeEnumeratingStreamSet); } var buildHelper = new DataNodeBuildHelper<>(context, activeEnumeratingStreamSet); - var session = new DatasetSession(buildNodeNetwork(activeEnumeratingStreamSet, buildHelper, null)); + var session = new DatasetSession(buildNodeNetwork(activeEnumeratingStreamSet, buildHelper)); for (var datasetInstance : buildHelper.getDatasetInstanceList()) { session.registerDatasetInstance(datasetInstance.getParent(), datasetInstance); } return session; } - private NodeNetwork buildNodeNetwork(Set> enumeratingStreamSet, - DataNodeBuildHelper buildHelper, @Nullable Consumer nodeNetworkVisualizationConsumer) { - var declaredClassToNodeMap = new LinkedHashMap, List>>(); + private NeighborhoodsBavetNodeNetwork buildNodeNetwork(Set> enumeratingStreamSet, + DataNodeBuildHelper buildHelper) { + var declaredClassToNodeMap = new LinkedHashMap, List>>(); var nodeList = buildHelper.buildNodeList(enumeratingStreamSet, buildHelper, AbstractEnumeratingStream::buildNode, node -> { if (!(node instanceof AbstractForEachUniNode forEachUniNode)) { @@ -60,11 +57,7 @@ private NodeNetwork buildNodeNetwork(Set> e } forEachUniNodeList.add(forEachUniNode); }); - if (nodeNetworkVisualizationConsumer != null) { - // TODO implement node network visualization - throw new UnsupportedOperationException("Not implemented yet"); - } - return AbstractNodeBuildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap, buildHelper); + return DataNodeBuildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractDatasetInstance.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractDatasetInstance.java index 2d1c2e1fce7..fb9ff7a0535 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractDatasetInstance.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/AbstractDatasetInstance.java @@ -13,12 +13,23 @@ public abstract class AbstractDatasetInstance private final AbstractDataset parent; protected final int entryStoreIndex; + private boolean upstreamCanProduceTuples; protected AbstractDatasetInstance(AbstractDataset parent, int entryStoreIndex) { this.parent = Objects.requireNonNull(parent); this.entryStoreIndex = entryStoreIndex; } + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + this.upstreamCanProduceTuples = upstreamCanProduceTuples; + } + + @Override + public boolean isActive() { + return upstreamCanProduceTuples; + } + public AbstractDataset getParent() { return parent; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/DataNodeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/DataNodeBuildHelper.java index 978f14bb9df..cebe21d3d4c 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/DataNodeBuildHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/common/DataNodeBuildHelper.java @@ -3,25 +3,30 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; +import ai.timefold.solver.core.impl.bavet.common.AbstractNode; import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.neighborhood.NeighborhoodsBavetNodeNetwork; import ai.timefold.solver.core.impl.score.director.SessionContext; import org.jspecify.annotations.NullMarked; @NullMarked -public final class DataNodeBuildHelper extends AbstractNodeBuildHelper> { +public final class DataNodeBuildHelper + extends AbstractNodeBuildHelper> { private final SessionContext sessionContext; private final List> datasetInstanceList = new ArrayList<>(); public DataNodeBuildHelper(SessionContext sessionContext, Set> activeStreamSet) { - super(activeStreamSet, null); + super(activeStreamSet); this.sessionContext = Objects.requireNonNull(sessionContext); } @@ -42,4 +47,10 @@ public SessionContext getSessionContext() { public List> getDatasetInstanceList() { return Collections.unmodifiableList(datasetInstanceList); } + + public static NeighborhoodsBavetNodeNetwork buildNodeNetwork(List nodeList, + Map, List>> declaredClassToNodeMap) { + return NeighborhoodsBavetNodeNetwork.of(nodeList, declaredClassToNodeMap); + } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java index a3f46ef2d60..ebb8cf59945 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/neighborhood/stream/enumerating/uni/AbstractForEachEnumeratingStream.java @@ -1,6 +1,6 @@ package ai.timefold.solver.core.impl.neighborhood.stream.enumerating.uni; -import static ai.timefold.solver.core.impl.bavet.common.BavetRootNode.LifecycleOperation; +import static ai.timefold.solver.core.impl.bavet.common.AbstractRootNode.LifecycleOperation; import java.util.Objects; import java.util.Set; diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirectorFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirectorFactory.java index bd6239823f4..f596146042e 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirectorFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/director/stream/BavetConstraintStreamScoreDirectorFactory.java @@ -2,7 +2,6 @@ import java.util.Arrays; import java.util.Objects; -import java.util.function.Consumer; import ai.timefold.solver.core.api.score.Score; import ai.timefold.solver.core.api.score.stream.ConstraintMetaModel; @@ -78,19 +77,10 @@ public BavetConstraintStreamScoreDirectorFactory(SolutionDescriptor s } public BavetConstraintSession newSession(Solution_ workingSolution, - ConsistencyTracker consistencyTracker, - ConstraintMatchPolicy constraintMatchPolicy, + ConsistencyTracker consistencyTracker, ConstraintMatchPolicy constraintMatchPolicy, boolean scoreDirectorDerived) { - return newSession(workingSolution, consistencyTracker, constraintMatchPolicy, scoreDirectorDerived, null); - } - - public BavetConstraintSession newSession(Solution_ workingSolution, - ConsistencyTracker consistencyTracker, - ConstraintMatchPolicy constraintMatchPolicy, - boolean scoreDirectorDerived, Consumer nodeNetworkVisualizationConsumer) { return constraintSessionFactory.buildSession(workingSolution, consistencyTracker, constraintMatchPolicy, - scoreDirectorDerived, - nodeNetworkVisualizationConsumer); + scoreDirectorDerived); } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java index 0482879418b..84a871593ec 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintFactory.java @@ -63,7 +63,7 @@ public > Stream_ share( * If a constraint already exists in this factory, it replaces it with the old copy. * {@link BavetAbstractConstraintStream} implement equals/hashcode ignoring child streams. *

- * {@link BavetConstraintSessionFactory#buildSession(Object, ConsistencyTracker, ConstraintMatchPolicy, boolean, Consumer)} + * {@link BavetConstraintSessionFactory#buildSession(Object, ConsistencyTracker, ConstraintMatchPolicy, boolean)} * needs * this to happen * for all streams. diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java index ce443f01200..8d68ea0e9dd 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSession.java @@ -5,7 +5,6 @@ import ai.timefold.solver.core.api.score.Score; import ai.timefold.solver.core.api.score.stream.ConstraintRef; import ai.timefold.solver.core.impl.bavet.AbstractSession; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; import ai.timefold.solver.core.impl.bavet.common.PropagationQueue; import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; @@ -22,15 +21,16 @@ * * @param */ -public final class BavetConstraintSession> extends AbstractSession { +public final class BavetConstraintSession> + extends AbstractSession { private final AbstractScoreInliner scoreInliner; BavetConstraintSession(AbstractScoreInliner scoreInliner) { - this(scoreInliner, NodeNetwork.EMPTY); + this(scoreInliner, ConstraintStreamsBavetNodeNetwork.EMPTY); } - BavetConstraintSession(AbstractScoreInliner scoreInliner, NodeNetwork nodeNetwork) { + BavetConstraintSession(AbstractScoreInliner scoreInliner, ConstraintStreamsBavetNodeNetwork nodeNetwork) { super(nodeNetwork); this.scoreInliner = scoreInliner; } @@ -48,4 +48,8 @@ public Map> getConstraintMatchTotalM return scoreInliner.getConstraintMatchTotalMap(); } + public void summarizeProfileIfPresent() { + nodeNetwork.summarizeProfileIfPresent(); + } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java index 2ed21c9ba71..7629a033f5e 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/BavetConstraintSessionFactory.java @@ -5,22 +5,20 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.function.Consumer; +import java.util.function.Function; import java.util.stream.Collectors; import ai.timefold.solver.core.api.score.Score; import ai.timefold.solver.core.api.score.stream.Constraint; import ai.timefold.solver.core.api.score.stream.ConstraintMetaModel; import ai.timefold.solver.core.enterprise.TimefoldSolverEnterpriseService; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; -import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.InnerConstraintProfiler; import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; -import ai.timefold.solver.core.impl.bavet.visual.NodeGraph; import ai.timefold.solver.core.impl.domain.solution.descriptor.SolutionDescriptor; import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; @@ -35,7 +33,7 @@ public final class BavetConstraintSessionFactory> { private static final Logger LOGGER = LoggerFactory.getLogger(BavetConstraintSessionFactory.class); - private static final Level CONSTRAINT_WEIGHT_LOGGING_LEVEL = Level.DEBUG; + public static final Level CONSTRAINT_WEIGHT_LOGGING_LEVEL = Level.DEBUG; private final SolutionDescriptor solutionDescriptor; private final ConstraintMetaModel constraintMetaModel; @@ -57,15 +55,12 @@ public BavetConstraintSessionFactory(SolutionDescriptor solutionDescr @SuppressWarnings("unchecked") public BavetConstraintSession buildSession(Solution_ workingSolution, - ConsistencyTracker consistencyTracker, - ConstraintMatchPolicy constraintMatchPolicy, - boolean scoreDirectorDerived, Consumer nodeNetworkVisualizationConsumer) { + ConsistencyTracker consistencyTracker, ConstraintMatchPolicy constraintMatchPolicy, + boolean scoreDirectorDerived) { var constraintWeightSupplier = solutionDescriptor.getConstraintWeightSupplier(); var constraints = constraintMetaModel.getConstraints(); if (constraintWeightSupplier != null) { // Fail fast on unknown constraints. - var knownConstraints = constraints.stream() - .map(Constraint::getConstraintRef) - .collect(Collectors.toSet()); + var knownConstraints = constraints.stream().map(Constraint::getConstraintRef).collect(Collectors.toSet()); constraintWeightSupplier.validate(workingSolution, knownConstraints); } var scoreDefinition = solutionDescriptor. getScoreDefinition(); @@ -76,8 +71,7 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, // Only log constraint weights if logging is enabled; otherwise we don't need to build the string. var constraintWeightLoggingEnabled = !scoreDirectorDerived && LOGGER.isEnabledForLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL); var constraintWeightString = constraintWeightLoggingEnabled - ? new StringBuilder("Constraint weights for solution (%s):%n" - .formatted(workingSolution)) + ? new StringBuilder("Constraint weights for solution (%s):%n".formatted(workingSolution)) : null; for (var constraint : constraints) { @@ -91,8 +85,8 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, constraintWeightString.append(" Constraint (%s) weight overridden to (%s) from (%s).%n" .formatted(constraintRef, constraintWeight, defaultConstraintWeight)); } else { - constraintWeightString.append(" Constraint (%s) weight set to (%s).%n" - .formatted(constraintRef, constraintWeight)); + constraintWeightString + .append(" Constraint (%s) weight set to (%s).%n".formatted(constraintRef, constraintWeight)); } } /* @@ -107,8 +101,7 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, * Filter out nodes that only lead to constraints with zero weight. * Note: Node sharing happens earlier, in BavetConstraintFactory#share(Stream_). */ - constraintWeightString.append(" Constraint (%s) disabled.%n" - .formatted(constraintRef)); + constraintWeightString.append(" Constraint (%s) disabled.%n".formatted(constraintRef)); } } } @@ -120,34 +113,27 @@ public BavetConstraintSession buildSession(Solution_ workingSolution, } if (constraintWeightLoggingEnabled) { - LOGGER.atLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL) - .log(constraintWeightString.toString().trim()); + LOGGER.atLevel(CONSTRAINT_WEIGHT_LOGGING_LEVEL).log(constraintWeightString.toString().trim()); } return new BavetConstraintSession<>(scoreInliner, - buildNodeNetwork(workingSolution, consistencyTracker, constraintStreamSet, scoreInliner, - constraintProfiler, - nodeNetworkVisualizationConsumer)); + buildNodeNetwork(workingSolution, consistencyTracker, constraintStreamSet, scoreInliner, constraintProfiler, + scoreDirectorDerived)); } - private static > NodeNetwork buildNodeNetwork(Solution_ workingSolution, + private ConstraintStreamsBavetNodeNetwork buildNodeNetwork(Solution_ workingSolution, ConsistencyTracker consistencyTracker, Set> constraintStreamSet, - AbstractScoreInliner scoreInliner, - InnerConstraintProfiler profiler, - Consumer nodeNetworkVisualizationConsumer) { - var buildHelper = - new ConstraintNodeBuildHelper<>(consistencyTracker, constraintStreamSet, scoreInliner, profiler); - var declaredClassToNodeMap = new LinkedHashMap, List>>(); - var nodeList = buildHelper.buildNodeList(constraintStreamSet, buildHelper, - BavetAbstractConstraintStream::buildNode, - node -> { - if (!(node instanceof BavetRootNode tupleSourceRoot)) { + AbstractScoreInliner scoreInliner, InnerConstraintProfiler profiler, boolean scoreDirectorDerived) { + var buildHelper = new ConstraintNodeBuildHelper<>(consistencyTracker, constraintStreamSet, scoreInliner, profiler); + var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var nodeList = + buildHelper.buildNodeList(constraintStreamSet, buildHelper, BavetAbstractConstraintStream::buildNode, node -> { + if (!(node instanceof AbstractRootNode tupleSourceRoot)) { return; } if (tupleSourceRoot instanceof AbstractForEachUniNode forEachUniNode) { var forEachClass = forEachUniNode.getForEachClass(); - var forEachUniNodeList = - declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2)); + var forEachUniNodeList = declaredClassToNodeMap.computeIfAbsent(forEachClass, k -> new ArrayList<>(2)); if (forEachUniNodeList.stream().filter(sourceNode -> sourceNode instanceof AbstractForEachUniNode) .count() == 3) { // Each class can have at most three forEach nodes: one including everything, one including consistent + null vars, the last consistent + no null vars. @@ -164,15 +150,20 @@ private static > NodeNetwork buildNodeNe } } }); - if (nodeNetworkVisualizationConsumer != null) { - var constraintSet = scoreInliner.getConstraints(); - var visualisation = NodeGraph - .of(workingSolution, nodeList, constraintSet, buildHelper::getNodeCreatingStream, - buildHelper::findParentNode) - .buildGraphvizDOT(); - nodeNetworkVisualizationConsumer.accept(visualisation); + var constraintToScorerMap = scoreInliner.getConstraints() + .stream() + .map(constraint -> (BavetConstraint) constraint) + .collect(Collectors.toMap(Function.identity(), + constraint -> buildHelper.getScorer(constraint.getScoringConstraintStream()), (a, b) -> a, + LinkedHashMap::new)); + + if (constraintProfiler != null) { + constraintProfiler.registerNodeGraph(workingSolution, nodeList, scoreInliner.getConstraints(), + buildHelper::getNodeCreatingStream, buildHelper::findParentNode); } - return AbstractNodeBuildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap, buildHelper); + + return buildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap, (Map) constraintToScorerMap, + scoreDirectorDerived); } } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/ConstraintStreamsBavetNodeNetwork.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/ConstraintStreamsBavetNodeNetwork.java new file mode 100644 index 00000000000..a8538064414 --- /dev/null +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/ConstraintStreamsBavetNodeNetwork.java @@ -0,0 +1,110 @@ +package ai.timefold.solver.core.impl.score.stream.bavet; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import ai.timefold.solver.core.impl.bavet.AbstractBavetNodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractNode; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; +import ai.timefold.solver.core.impl.bavet.common.InnerConstraintProfiler; +import ai.timefold.solver.core.impl.bavet.common.Propagator; +import ai.timefold.solver.core.impl.score.stream.bavet.common.Scorer; + +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.event.Level; + +/** + * Represents Constraint Streams' network of nodes, specific to a particular session. + * Nodes only used by disabled constraints have already been removed. + * + */ +@NullMarked +public final class ConstraintStreamsBavetNodeNetwork extends AbstractBavetNodeNetwork { + + private static final Logger LOGGER = LoggerFactory.getLogger(ConstraintStreamsBavetNodeNetwork.class); + + public static ConstraintStreamsBavetNodeNetwork of(List nodeList, + Map, List>> declaredClassToNodeMap, + Map, Scorer> constraintToScorerMap, Function propagatorFunction, + @Nullable InnerConstraintProfiler constraintProfiler, boolean scoreDirectorDerived) { + var layeredNodes = AbstractBavetNodeNetwork.buildLayeredNodes(nodeList); + return new ConstraintStreamsBavetNodeNetwork(declaredClassToNodeMap, constraintToScorerMap, layeredNodes, + propagatorFunction, constraintProfiler, scoreDirectorDerived); + } + + public static final ConstraintStreamsBavetNodeNetwork EMPTY = + new ConstraintStreamsBavetNodeNetwork(Map.of(), Map.of(), new AbstractNode[0][0], AbstractNode::getPropagator, null, + true); + + private final Map, Scorer> constraintToScorerMap; + private final @Nullable InnerConstraintProfiler constraintProfiler; + private final boolean scoreDirectorDerived; + private boolean printedInactiveConstraints = false; + + /** + * @param declaredClassToNodeMap starting nodes, one for each class used in the constraints; + * root nodes, layer index 0. + * @param layeredNodes nodes grouped first by their layer, then by their index within the layer; + * propagation needs to happen in this order. + * @param propagatorFunction function to get the propagator for a given node + */ + private ConstraintStreamsBavetNodeNetwork(Map, List>> declaredClassToNodeMap, + Map, Scorer> constraintToScorerMap, AbstractNode[][] layeredNodes, + Function propagatorFunction, @Nullable InnerConstraintProfiler constraintProfiler, + boolean scoreDirectorDerived) { + super(declaredClassToNodeMap, layeredNodes, propagatorFunction); + this.constraintToScorerMap = constraintToScorerMap; + this.constraintProfiler = constraintProfiler; + this.scoreDirectorDerived = scoreDirectorDerived; + } + + @Override + public void settle() { + super.settle(); + var loggingLevel = Level.DEBUG; // Makes sure the check and the logging always operate on the same level. + if (!LOGGER.isEnabledForLevel(loggingLevel)) { + return; + } + if (!scoreDirectorDerived && !printedInactiveConstraints && isActivationCheckComplete()) { + printedInactiveConstraints = true; + var substring = constraintToScorerMap.entrySet().stream() + .filter(entry -> !entry.getValue().isActive()) + .map(entry -> " Constraint (%s) with weight set to (%s).".formatted(entry.getKey().getConstraintRef(), + entry.getValue().getWeight())) + .collect(Collectors.joining(System.lineSeparator())); + if (substring.isEmpty()) { + return; + } + LOGGER.atLevel(loggingLevel).log(""" + Constraints deactivated due to being useless in the given working solution: + %s""".formatted(substring)); + } + } + + public @Nullable InnerConstraintProfiler getConstraintProfiler() { + return constraintProfiler; + } + + public void summarizeProfileIfPresent() { + if (constraintProfiler != null) { + constraintProfiler.summarize(); + } + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ConstraintStreamsBavetNodeNetwork)) + return false; + return super.equals(o); + } + + @Override + public int hashCode() { + return super.hashCode(); + } +} diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java index afbe879c09e..3183c638d56 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/BavetPrecomputeBuildHelper.java @@ -11,10 +11,9 @@ import ai.timefold.solver.core.api.score.stream.ConstraintFactory; import ai.timefold.solver.core.api.score.stream.ConstraintStream; import ai.timefold.solver.core.api.score.stream.PrecomputeFactory; -import ai.timefold.solver.core.impl.bavet.NodeNetwork; -import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.AbstractBavetNodeNetwork; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; -import ai.timefold.solver.core.impl.bavet.common.BavetRootNode; import ai.timefold.solver.core.impl.bavet.common.tuple.RecordingTupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; @@ -24,7 +23,7 @@ import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner; public final class BavetPrecomputeBuildHelper { - private final NodeNetwork nodeNetwork; + private final AbstractBavetNodeNetwork nodeNetwork; private final RecordingTupleLifecycle recordingTupleLifecycle; private final Class[] sourceClasses; private final Set> entityClassSet; @@ -66,11 +65,11 @@ public BavetPrecomputeBuildHelper( ConstraintMatchPolicy.DISABLED), null); - var declaredClassToNodeMap = new LinkedHashMap, List>>(); + var declaredClassToNodeMap = new LinkedHashMap, List>>(); var nodeList = buildHelper.buildNodeList(streamSet, buildHelper, BavetAbstractConstraintStream::buildNode, node -> { - if (!(node instanceof BavetRootNode sourceRootNode)) { + if (!(node instanceof AbstractRootNode sourceRootNode)) { return; } var nodeSourceClasses = sourceRootNode.getSourceClasses(); @@ -80,14 +79,14 @@ public BavetPrecomputeBuildHelper( } }); - this.nodeNetwork = AbstractNodeBuildHelper.buildNodeNetwork(nodeList, declaredClassToNodeMap, buildHelper); + this.nodeNetwork = buildHelper.buildPrecomputeNodeNetwork(nodeList, declaredClassToNodeMap); this.recordingTupleLifecycle = (RecordingTupleLifecycle) buildHelper .getAggregatedTupleLifecycle(List.of(recordingPrecomputeConstraintStream)); this.sourceClasses = declaredClassToNodeMap.keySet().toArray(new Class[0]); } - public NodeNetwork getNodeNetwork() { + public AbstractBavetNodeNetwork getNodeNetwork() { return nodeNetwork; } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/ConstraintNodeBuildHelper.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/ConstraintNodeBuildHelper.java index 08ac1b3e30b..0179f24fd5d 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/ConstraintNodeBuildHelper.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/ConstraintNodeBuildHelper.java @@ -1,16 +1,33 @@ package ai.timefold.solver.core.impl.score.stream.bavet.common; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Predicate; import ai.timefold.solver.core.api.score.Score; +import ai.timefold.solver.core.impl.bavet.common.AbstractNode; import ai.timefold.solver.core.impl.bavet.common.AbstractNodeBuildHelper; +import ai.timefold.solver.core.impl.bavet.common.AbstractRootNode; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.core.impl.bavet.common.BavetStream; +import ai.timefold.solver.core.impl.bavet.common.BavetStreamBinaryOperation; +import ai.timefold.solver.core.impl.bavet.common.ConstraintNodeProfileId; import ai.timefold.solver.core.impl.bavet.common.InnerConstraintProfiler; +import ai.timefold.solver.core.impl.bavet.common.ProfilingPropagator; +import ai.timefold.solver.core.impl.bavet.common.tuple.AggregatedTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.ProfilingTupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.domain.entity.descriptor.EntityDescriptor; import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraint; +import ai.timefold.solver.core.impl.score.stream.bavet.ConstraintStreamsBavetNodeNetwork; import ai.timefold.solver.core.impl.score.stream.common.ForEachFilteringCriteria; import ai.timefold.solver.core.impl.score.stream.common.inliner.AbstractScoreInliner; @@ -23,22 +40,86 @@ public final class ConstraintNodeBuildHelper scoreInliner; private final ConsistencyTracker consistencyTracker; + private final @Nullable InnerConstraintProfiler constraintProfiler; private final Map, Map>> entityDescriptorToForEachCriteriaToPredicateMap; + private final Map, List>> streamToProfileIdSets; + private final Map, Scorer> streamToScorers = new HashMap<>(); + + private long nextLifecycleProfilingId = 0; public ConstraintNodeBuildHelper(ConsistencyTracker consistencyTracker, - Set> activeStreamSet, - AbstractScoreInliner scoreInliner, + Set> activeStreamSet, AbstractScoreInliner scoreInliner, @Nullable InnerConstraintProfiler profiler) { - super(activeStreamSet, profiler); + super(activeStreamSet); this.consistencyTracker = consistencyTracker; this.scoreInliner = scoreInliner; + this.constraintProfiler = profiler; this.entityDescriptorToForEachCriteriaToPredicateMap = new HashMap<>(); + this.streamToProfileIdSets = HashMap.newHashMap(Math.max(16, activeStreamSet.size() / 2)); + } + + @Override + public void putInsertUpdateRetract(BavetAbstractConstraintStream stream, + TupleLifecycle tupleLifecycle) { + if (constraintProfiler != null) { + var out = TupleLifecycle.profiling(constraintProfiler, nextLifecycleProfilingId, stream, tupleLifecycle); + super.putInsertUpdateRetract(stream, out); + updateConstraintProfileIdSet(stream, out); + + if (tupleLifecycle instanceof Scorer scorer) { + // This is a scorer, so we can navigate up its parents + // to find all locations corresponding to this constraint + var queue = new ArrayDeque(); + var constraintSet = new LinkedHashSet(); + queue.add(stream); + while (!queue.isEmpty()) { + var currentStream = (BavetAbstractConstraintStream) queue.poll(); + var streamSets = streamToProfileIdSets.computeIfAbsent(currentStream, ignored -> new ArrayList<>()); + streamSets.add(constraintSet); + var lifecycle = getTupleLifecycle(currentStream); + if (lifecycle instanceof ProfilingTupleLifecycle profilingTupleLifecycle) { + constraintSet.add(profilingTupleLifecycle.profileId()); + } + if (currentStream instanceof BavetStreamBinaryOperation binaryOperation) { + queue.add(binaryOperation.getLeftParent()); + queue.add(binaryOperation.getRightParent()); + } else if (currentStream.getParent() != null) { + queue.add(currentStream.getParent()); + } + } + constraintProfiler.registerConstraint(scorer.getConstraintRef(), constraintSet); + } + nextLifecycleProfilingId++; + } else { + super.putInsertUpdateRetract(stream, tupleLifecycle); + } + if (tupleLifecycle instanceof Scorer scorer) { + streamToScorers.put((BavetScoringConstraintStream) stream, scorer); + } + } + + private void updateConstraintProfileIdSet(BavetAbstractConstraintStream stream, + TupleLifecycle tupleLifecycle) { + if (tupleLifecycle instanceof ProfilingTupleLifecycle profilingTupleLifecycle) { + var affectedSets = streamToProfileIdSets.getOrDefault(stream, Collections.emptyList()); + for (var affectedSet : affectedSets) { + affectedSet.add(profilingTupleLifecycle.profileId()); + } + } else if (tupleLifecycle instanceof AggregatedTupleLifecycle aggregated) { + for (var innerLifecycle : aggregated.downstream()) { + updateConstraintProfileIdSet(stream, innerLifecycle); + } + } } public AbstractScoreInliner getScoreInliner() { return scoreInliner; } + public Scorer getScorer(BavetScoringConstraintStream stream) { + return streamToScorers.get(stream); + } + @SuppressWarnings("unchecked") public @Nullable Predicate getForEachPredicateForEntityDescriptorAndCriteria( EntityDescriptor entityDescriptor, ForEachFilteringCriteria criteria) { @@ -47,4 +128,29 @@ public AbstractScoreInliner getScoreInliner() { return (Predicate) predicateMap.computeIfAbsent(criteria, ignored -> criteria.getFilterForEntityDescriptor(consistencyTracker, entityDescriptor)); } + + public ConstraintStreamsBavetNodeNetwork buildNodeNetwork(List nodeList, + Map, List>> declaredClassToNodeMap, + Map, Scorer> constraintToScorerMap, boolean scoreDirectorDerived) { + return ConstraintStreamsBavetNodeNetwork.of(nodeList, declaredClassToNodeMap, (Map) constraintToScorerMap, node -> { + if (constraintProfiler == null) { + return node.getPropagator(); + } + var profileKey = nextLifecycleProfilingId++; + var profileId = new ConstraintNodeProfileId(profileKey, node.getStreamKind(), + ConstraintNodeProfileId.Qualifier.NODE, node.getLocationSet()); + constraintProfiler.register(profileId); + var stream = getNodeCreator(node); + for (var affectedSet : streamToProfileIdSets.getOrDefault(stream, Collections.emptyList())) { + affectedSet.add(profileId); + } + return new ProfilingPropagator(constraintProfiler, profileId, node.getPropagator()); + }, constraintProfiler, scoreDirectorDerived); + } + + public ConstraintStreamsBavetNodeNetwork buildPrecomputeNodeNetwork(List nodeList, + Map, List>> declaredClassToNodeMap) { + return buildNodeNetwork(nodeList, declaredClassToNodeMap, Collections.emptyMap(), true); // Reduces logging. + } + } diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/Scorer.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/Scorer.java index f477897552b..df121f529b0 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/Scorer.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/common/Scorer.java @@ -2,6 +2,7 @@ import java.util.Objects; +import ai.timefold.solver.core.api.score.Score; import ai.timefold.solver.core.api.score.stream.ConstraintRef; import ai.timefold.solver.core.impl.bavet.common.tuple.Tuple; import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; @@ -16,6 +17,7 @@ public final class Scorer implements TupleLifecycle scoreImpacter; private final WeightedScoreImpacter weightedScoreImpacter; private final int inputStoreIndex; + private boolean isActive = true; public Scorer(ScoreImpacter scoreImpacter, WeightedScoreImpacter weightedScoreImpacter, int inputStoreIndex) { this.scoreImpacter = Objects.requireNonNull(scoreImpacter); @@ -23,6 +25,18 @@ public Scorer(ScoreImpacter scoreImpacter, WeightedScoreImpacter w this.inputStoreIndex = inputStoreIndex; } + @Override + public void afterAllFactsInserted(boolean upstreamCanProduceTuples) { + if (!upstreamCanProduceTuples) { + isActive = false; + } + } + + @Override + public boolean isActive() { + return isActive; + } + @Override public void insert(Tuple_ tuple) { if (tuple.getStore(inputStoreIndex) != null) { @@ -69,6 +83,10 @@ public ConstraintRef getConstraintRef() { return context.getConstraint().getConstraintRef(); } + public Score getWeight() { + return weightedScoreImpacter.getContext().getConstraintWeight(); + } + @Override public String toString() { var context = weightedScoreImpacter.getContext(); diff --git a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java index c15aba0344a..f0458c90a6b 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/uni/BavetForEachUniConstraintStream.java @@ -8,7 +8,6 @@ import ai.timefold.solver.core.api.score.Score; import ai.timefold.solver.core.impl.bavet.common.BavetAbstractConstraintStream; import ai.timefold.solver.core.impl.bavet.common.TupleSource; -import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; import ai.timefold.solver.core.impl.bavet.uni.ForEachFilteredUniNode; import ai.timefold.solver.core.impl.bavet.uni.ForEachUnfilteredUniNode; @@ -52,7 +51,7 @@ public void collectActiveConstraintStreams(Set> void buildNode(ConstraintNodeBuildHelper buildHelper) { - TupleLifecycle> tupleLifecycle = buildHelper.getAggregatedTupleLifecycle(childStreamList); + var tupleLifecycle = buildHelper.> getAggregatedTupleLifecycle(childStreamList); int outputStoreSize = buildHelper.extractTupleStoreSize(this); var filter = filterFunction != null ? filterFunction.apply(buildHelper) : null; var node = filter == null ? new ForEachUnfilteredUniNode<>(forEachClass, tupleLifecycle, outputStoreSize) @@ -66,16 +65,11 @@ public > void buildNode(ConstraintNodeBuildHelper that = (BavetForEachUniConstraintStream) other; - return Objects.equals(forEachClass, that.forEachClass) && Objects.equals(filterFunction, that.filterFunction) - && getRetrievalSemantics().equals(that.getRetrievalSemantics()); + public boolean equals(Object o) { + return o instanceof BavetForEachUniConstraintStream otherStream + && forEachClass.equals(otherStream.forEachClass) + && Objects.equals(filterFunction, otherStream.filterFunction) + && getRetrievalSemantics().equals(otherStream.getRetrievalSemantics()); } @Override diff --git a/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolutionManager.java b/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolutionManager.java index 34b3c0049f0..74d8e95bf02 100644 --- a/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolutionManager.java +++ b/core/src/main/java/ai/timefold/solver/core/impl/solver/DefaultSolutionManager.java @@ -16,13 +16,10 @@ import ai.timefold.solver.core.config.solver.EnvironmentMode; import ai.timefold.solver.core.config.solver.PreviewFeature; import ai.timefold.solver.core.enterprise.TimefoldSolverEnterpriseService; -import ai.timefold.solver.core.impl.domain.variable.declarative.ConsistencyTracker; import ai.timefold.solver.core.impl.domain.variable.listener.support.violation.VariableSnapshotTotal; import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; import ai.timefold.solver.core.impl.score.director.InnerScoreDirector; import ai.timefold.solver.core.impl.score.director.ScoreDirectorFactory; -import ai.timefold.solver.core.impl.score.director.stream.BavetConstraintStreamScoreDirectorFactory; -import ai.timefold.solver.core.impl.util.MutableReference; import ai.timefold.solver.core.preview.api.domain.solution.diff.PlanningSolutionDiff; import org.jspecify.annotations.NullMarked; @@ -143,42 +140,4 @@ public List> recommendAssignment ConstraintMatchPolicy.match(fetchPolicy), true); } - /** - * Generates a Bavet node network visualization for the given solution. - * It uses a Graphviz DOT language representation. - * The string returned by this method can be converted to an image using {@code dot}: - * - *

-     * $ dot -Tsvg input.dot > output.svg
-     * 
- * - * This assumes the string returned by this method is saved to a file named {@code input.dot}. - * - *

- * The node network itself is an internal implementation detail of Constraint Streams. - * Do not rely on any particular node network structure in production code, - * and do not micro-optimize your constraints to match the node network. - * Such optimizations are destined to become obsolete and possibly harmful as the node network evolves. - * - *

- * This method is only provided for debugging purposes - * and is deliberately not part of the public API. - * Its signature or behavior may change without notice, - * and it may be removed in future versions. - * - * @see Graphviz DOT language - * - * @param solution Will be used to read constraint weights, which determine the final node network. - * @return A string representing the node network in Graphviz DOT language. - */ - public @Nullable String visualizeNodeNetwork(Solution_ solution) { - if (scoreDirectorFactory instanceof BavetConstraintStreamScoreDirectorFactory bavetScoreDirectorFactory) { - var result = new MutableReference(null); - bavetScoreDirectorFactory.newSession(solution, new ConsistencyTracker<>(), ConstraintMatchPolicy.ENABLED, false, - result::setValue); - return result.getValue(); - } - throw new UnsupportedOperationException("Node network visualization is only supported when using Constraint Streams."); - } - } diff --git a/core/src/main/java/module-info.java b/core/src/main/java/module-info.java index 4c6d9e1cc37..6ab448f37a6 100644 --- a/core/src/main/java/module-info.java +++ b/core/src/main/java/module-info.java @@ -182,6 +182,7 @@ // enterprise-specific exports exports ai.timefold.solver.core.impl.bavet.common to ai.timefold.solver.enterprise.core; + exports ai.timefold.solver.core.impl.bavet.uni to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.constructionheuristic to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.constructionheuristic.decider to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.constructionheuristic.decider.forager to ai.timefold.solver.enterprise.core; @@ -202,6 +203,8 @@ exports ai.timefold.solver.core.impl.neighborhood to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.partitionedsearch to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.phase to ai.timefold.solver.enterprise.core; + exports ai.timefold.solver.core.impl.score.stream.bavet to ai.timefold.solver.enterprise.core; + exports ai.timefold.solver.core.impl.score.stream.bavet.uni to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.solver.random to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.solver.recaller to ai.timefold.solver.enterprise.core; exports ai.timefold.solver.core.impl.solver.event to ai.timefold.solver.enterprise.core; diff --git a/core/src/test/java/ai/timefold/solver/core/api/solver/SolutionManagerTest.java b/core/src/test/java/ai/timefold/solver/core/api/solver/SolutionManagerTest.java index 43bca61bdc7..d2405c4b11d 100644 --- a/core/src/test/java/ai/timefold/solver/core/api/solver/SolutionManagerTest.java +++ b/core/src/test/java/ai/timefold/solver/core/api/solver/SolutionManagerTest.java @@ -2,7 +2,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.SoftAssertions.assertSoftly; import java.util.Arrays; @@ -11,14 +10,8 @@ import ai.timefold.solver.core.api.score.HardSoftScore; import ai.timefold.solver.core.api.score.Score; -import ai.timefold.solver.core.api.score.SimpleScore; import ai.timefold.solver.core.config.score.director.ScoreDirectorFactoryConfig; import ai.timefold.solver.core.config.solver.SolverConfig; -import ai.timefold.solver.core.impl.solver.DefaultSolutionManager; -import ai.timefold.solver.core.testdomain.TestdataConstraintProvider; -import ai.timefold.solver.core.testdomain.TestdataEasyScoreCalculator; -import ai.timefold.solver.core.testdomain.TestdataEntity; -import ai.timefold.solver.core.testdomain.TestdataSolution; import ai.timefold.solver.core.testdomain.list.shadowhistory.TestdataListEntityWithShadowHistory; import ai.timefold.solver.core.testdomain.list.shadowhistory.TestdataListSolutionWithShadowHistory; import ai.timefold.solver.core.testdomain.list.shadowhistory.TestdataListValueWithShadowHistory; @@ -57,17 +50,6 @@ public class SolutionManagerTest { .withScoreDirectorFactory( new ScoreDirectorFactoryConfig() .withConstraintProviderClass(TestdataListWithShadowHistoryConstraintProvider.class))); - public static final SolverFactory SOLVER_FACTORY_WITH_CS = SolverFactory.create( - new SolverConfig() - .withSolutionClass(TestdataSolution.class) - .withEntityClasses(TestdataEntity.class) - .withConstraintProviderClass(TestdataConstraintProvider.class)); - public static final SolverFactory SOLVER_FACTORY_EASY = SolverFactory.create( - new SolverConfig() - .withSolutionClass(TestdataSolution.class) - .withEntityClasses(TestdataEntity.class) - .withScoreDirectorFactory( - new ScoreDirectorFactoryConfig().withEasyScoreCalculatorClass(TestdataEasyScoreCalculator.class))); @ParameterizedTest @EnumSource(SolutionManagerSource.class) @@ -353,38 +335,6 @@ void updateOnlyScoreFailsIfListVariableInconsistent(SolutionManagerSource soluti " has an oldInverseEntity (e2) which is not that entity."); } - @SuppressWarnings("unchecked") - @ParameterizedTest - @EnumSource(SolutionManagerSource.class) - void visualizeNodeNetwork(SolutionManagerSource solutionManagerSource) { - var solution = new TestdataSolution(); - var solutionManager = (DefaultSolutionManager) solutionManagerSource - .createSolutionManager(SOLVER_FACTORY_WITH_CS); - var result = solutionManager.visualizeNodeNetwork(solution); - assertThat(result).isEqualToIgnoringWhitespace( - """ - digraph { - rankdir=LR; - label=<Bavet Node Network for 'null'
1 constraints, 1 nodes>; - node0 -> impact0; - node0 [pad="0.2", fillcolor="#3e00ff", shape="plaintext", fontcolor="white", style="filled", label=<ForEachFilteredUni
(TestdataEntity)>, fontname="Courier New"]; - impact0 [pad="0.2", fillcolor="#3423a6", shape="plaintext", fontcolor="white", style="filled", label=<Always penalize
(Weight: -1)>, fontname="Courier New"]; - { rank=same; node0; } - }"""); - } - - @SuppressWarnings("unchecked") - @ParameterizedTest - @EnumSource(SolutionManagerSource.class) - void visualizeNodeNetworkNoBavet(SolutionManagerSource solutionManagerSource) { - var solution = new TestdataSolution(); - var solutionManager = (DefaultSolutionManager) solutionManagerSource - .createSolutionManager(SOLVER_FACTORY_EASY); - assertThatThrownBy(() -> solutionManager.visualizeNodeNetwork(solution)) - .isInstanceOf(UnsupportedOperationException.class) - .hasMessageContaining("Constraint Streams"); - } - @SuppressWarnings({ "unchecked", "rawtypes" }) public enum SolutionManagerSource { diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/BavetNodeDeactivationTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/BavetNodeDeactivationTest.java new file mode 100644 index 00000000000..380f6fac266 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/BavetNodeDeactivationTest.java @@ -0,0 +1,270 @@ +package ai.timefold.solver.core.impl.bavet; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import ai.timefold.solver.core.api.score.SimpleScore; +import ai.timefold.solver.core.api.score.stream.Constraint; +import ai.timefold.solver.core.api.score.stream.ConstraintCollectors; +import ai.timefold.solver.core.api.score.stream.ConstraintProvider; +import ai.timefold.solver.core.impl.bavet.common.AbstractNode; +import ai.timefold.solver.core.impl.bavet.common.StreamKind; +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; +import ai.timefold.solver.core.impl.bavet.uni.AbstractForEachUniNode; +import ai.timefold.solver.core.impl.score.constraint.ConstraintMatchPolicy; +import ai.timefold.solver.core.impl.score.director.InnerScoreDirector; +import ai.timefold.solver.core.impl.score.director.stream.BavetConstraintStreamScoreDirector; +import ai.timefold.solver.core.impl.score.stream.bavet.BavetConstraintStreamImplSupport; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishEntity; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishExtra; +import ai.timefold.solver.core.testdomain.score.lavish.TestdataLavishSolution; + +import org.junit.jupiter.api.Test; + +/** + * See {@link ActivitySupport}. + */ +class BavetNodeDeactivationTest { + + /** + * Builds the network for the given constraints + solution; setWorkingSolution settles it (read-only; no scoring). + *

+ * Typed as {@link AbstractBavetNodeNetwork}, not the concrete ConstraintStreamsBavetNodeNetwork: the seam methods + * getActiveNodes()/getNodes() are package-private in AbstractBavetNodeNetwork, and a subclass in a different + * package does not inherit package-private members (JLS 8.4.8), so a call on the concrete type would not compile. + * This test lives in package ai.timefold.solver.core.impl.bavet, so a base-typed reference resolves them. + *

+ * Scoring is performed by a Scorer (a TupleLifecycle) held outside the node network, so no node carries + * StreamKind.SCORING; the activity of a constraint's terminal node already reflects whether its scorer stayed + * active (terminal isActive() == canProduceTuples() && scorer.isActive()). + */ + private static AbstractBavetNodeNetwork settle(ConstraintProvider provider, TestdataLavishSolution solution) { + var implSupport = new BavetConstraintStreamImplSupport(ConstraintMatchPolicy.DISABLED); + InnerScoreDirector scoreDirector = + implSupport.buildScoreDirector(TestdataLavishSolution.buildSolutionDescriptor(), provider); + // setWorkingSolution -> afterSetWorkingSolution -> session.settle() (guarded); network is settled on return. + scoreDirector.setWorkingSolution(solution); + var director = (BavetConstraintStreamScoreDirector) scoreDirector; + var session = director.getSession(); + assertThat(session).isNotNull(); + return session.getNodeNetwork(); + } + + private static boolean active(AbstractBavetNodeNetwork net, StreamKind kind) { + return net.getActiveNodes() + .stream() + .anyMatch(n -> n.getStreamKind() == kind); + } + + private static boolean present(AbstractBavetNodeNetwork net, StreamKind kind) { + return net.getNodes() + .stream() + .anyMatch(n -> n.getStreamKind() == kind); + } + + private static boolean activeForEach(AbstractBavetNodeNetwork net, Class forEachClass) { + return matchesForEach(net.getActiveNodes(), forEachClass); + } + + private static long forEachCount(Collection nodes, Class forEachClass) { + return nodes.stream() + .filter(n -> n instanceof AbstractForEachUniNode fe && fe.getForEachClass().equals(forEachClass)) + .count(); + } + + private static boolean matchesForEach(Collection nodes, Class forEachClass) { + return forEachCount(nodes, forEachClass) > 0; + } + + private static TestdataLavishSolution solutionWithoutExtras() { + return TestdataLavishSolution.generateSolution(); // entityList populated; extraList empty by default + } + + private static TestdataLavishSolution solutionWithExtras() { + var solution = TestdataLavishSolution.generateSolution(); + var extras = new ArrayList<>(solution.getExtraList()); + extras.add(new TestdataLavishExtra("extra1")); + solution.setExtraList(extras); + return solution; + } + + @Test + void emptyJoinSideDeactivatesChain() { + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishExtra.class) + .penalize(SimpleScore.ONE) + .asConstraint("entityJoinExtra") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.JOIN)).isFalse(); + assertThat(present(net, StreamKind.JOIN)).isTrue(); // node exists but is inactive + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); + assertThat(activeForEach(net, TestdataLavishEntity.class)).isFalse(); // feeds only the dead join + } + + @Test + void nonEmptyJoinSideKeepsChainActive() { + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishExtra.class) + .penalize(SimpleScore.ONE) + .asConstraint("entityJoinExtra") + }; + var net = settle(provider, solutionWithExtras()); + + assertThat(active(net, StreamKind.JOIN)).isTrue(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isTrue(); + assertThat(activeForEach(net, TestdataLavishEntity.class)).isTrue(); + } + + @Test + void ifNotExistsWithEmptyRightStaysActive() { + // Right is irrelevant for ifNotExists: an empty right does not deactivate it. + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishEntity.class) + .ifNotExists(TestdataLavishExtra.class) + .penalize(SimpleScore.ONE) + .asConstraint("entityIfNotExistsExtra") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.IF_EXISTS)).isTrue(); + assertThat(activeForEach(net, TestdataLavishEntity.class)).isTrue(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); // empty right root + } + + @Test + void ifExistsWithEmptyRightDeactivates() { + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishEntity.class) + .ifExists(TestdataLavishExtra.class) + .penalize(SimpleScore.ONE) + .asConstraint("entityIfExistsExtra") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.IF_EXISTS)).isFalse(); + assertThat(present(net, StreamKind.IF_EXISTS)).isTrue(); + assertThat(activeForEach(net, TestdataLavishEntity.class)).isFalse(); + } + + @Test + void concatWithOneEmptySideStaysActive() { + // concat requires both sides to share arity and element type, so both forEach streams are mapped to Object. + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishExtra.class) + .map(extra -> (Object) extra) + .concat(factory.forEach(TestdataLavishEntity.class) + .map(entity -> (Object) entity)) + .penalize(SimpleScore.ONE) + .asConstraint("concatExtraEntity") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.CONCAT)).isTrue(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); // empty branch + assertThat(activeForEach(net, TestdataLavishEntity.class)).isTrue(); // populated branch + } + + @Test + void groupByOverEmptyClassDeactivates() { + // Locks the assumption that an empty groupBy emits nothing. + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishExtra.class) + .groupBy(ConstraintCollectors.count()) + .penalize(SimpleScore.ONE) + .asConstraint("countExtras") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.GROUP_BY)).isFalse(); + assertThat(present(net, StreamKind.GROUP_BY)).isTrue(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); + } + + @Test + void chainOverEmptyClassFullyDeactivates() { + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishExtra.class) + .filter(extra -> true) + .map(extra -> extra) + .penalize(SimpleScore.ONE) + .asConstraint("chainOverExtras") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(net.getActiveNodes()).isEmpty(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); + assertThat(active(net, StreamKind.MAP)).isFalse(); + // filter() compiles to a ConditionalTupleLifecycle WRAPPER, not a node, so there is no FILTER node. + assertThat(present(net, StreamKind.FILTER)).isFalse(); + } + + @Test + void flattenOverEmptyClassDeactivates() { + // FlattenLastUniNode inherits the single-input default rule (upstreamCanProduceTuples && downstream). + ConstraintProvider provider = factory -> new Constraint[] { + factory.forEach(TestdataLavishExtra.class) + .flattenLast(extra -> List.of(extra)) + .penalize(SimpleScore.ONE) + .asConstraint("flattenExtras") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.FLATTEN)).isFalse(); + assertThat(present(net, StreamKind.FLATTEN)).isTrue(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); + } + + @Test + void sharedForEachStaysActiveWhenOneBranchIsAlive() { + ConstraintProvider provider = factory -> { + var base = factory.forEach(TestdataLavishEntity.class); // shared by both constraints + return new Constraint[] { + base.penalize(SimpleScore.ONE) + .asConstraint("aliveBranch"), + base.join(TestdataLavishExtra.class) + .penalize(SimpleScore.ONE) + .asConstraint("deadBranch") + }; + }; + var net = settle(provider, solutionWithoutExtras()); + + // The shared Entity forEach is built exactly once and stays active via the alive branch, + // even though the join branch is dead (empty Extra). + assertThat(forEachCount(net.getNodes(), TestdataLavishEntity.class)).isEqualTo(1L); + assertThat(activeForEach(net, TestdataLavishEntity.class)).isTrue(); + assertThat(active(net, StreamKind.JOIN)).isFalse(); + assertThat(activeForEach(net, TestdataLavishExtra.class)).isFalse(); + } + + @Test + void precomputeOverEmptySourceDeactivates() { + ConstraintProvider provider = factory -> new Constraint[] { + factory.precompute(pf -> pf.forEachUnfiltered(TestdataLavishExtra.class)) + .penalize(SimpleScore.ONE) + .asConstraint("precomputeExtras") + }; + var net = settle(provider, solutionWithoutExtras()); + + assertThat(active(net, StreamKind.PRECOMPUTE)).isFalse(); + assertThat(present(net, StreamKind.PRECOMPUTE)).isTrue(); + } + + @Test + void precomputeOverPopulatedSourceStaysActive() { + ConstraintProvider provider = factory -> new Constraint[] { + factory.precompute(pf -> pf.forEachUnfiltered(TestdataLavishExtra.class)) + .penalize(SimpleScore.ONE) + .asConstraint("precomputeExtras") + }; + var net = settle(provider, solutionWithExtras()); + + assertThat(active(net, StreamKind.PRECOMPUTE)).isTrue(); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/bi/JoinBiNodeActivityTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/bi/JoinBiNodeActivityTest.java new file mode 100644 index 00000000000..a5c3e960a4c --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/bi/JoinBiNodeActivityTest.java @@ -0,0 +1,71 @@ +package ai.timefold.solver.core.impl.bavet.bi; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; +import ai.timefold.solver.core.impl.bavet.common.tuple.BiTuple; +import ai.timefold.solver.core.impl.bavet.common.tuple.InOutTupleStorePositionTracker; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; + +import org.junit.jupiter.api.Test; + +/** + * See {@link ActivitySupport}. + */ +class JoinBiNodeActivityTest { + + @SuppressWarnings("unchecked") + private static TupleLifecycle> mockDownstream(boolean active) { + TupleLifecycle> downstream = mock(TupleLifecycle.class); + when(downstream.isActive()).thenReturn(active); + return downstream; + } + + private static UnindexedJoinBiNode node(TupleLifecycle> downstream) { + var tracker = mock(InOutTupleStorePositionTracker.class); + return new UnindexedJoinBiNode<>(downstream, null, tracker); + } + + @Test + void activeWhenBothSidesProduce() { + var downstream = mockDownstream(true); + var node = node(downstream); + node.afterAllFactsInsertedLeft(true); + node.afterAllFactsInsertedRight(true); + assertThat(node.isActive()).isTrue(); + verify(downstream, times(1)).afterAllFactsInserted(true); + } + + @Test + void inactiveWhenRightCannotProduce() { + var downstream = mockDownstream(true); + var node = node(downstream); + node.afterAllFactsInsertedLeft(true); + node.afterAllFactsInsertedRight(false); + assertThat(node.isActive()).isFalse(); + verify(downstream, times(1)).afterAllFactsInserted(false); + } + + @Test + void inactiveWhenLeftCannotProduce() { + var downstream = mockDownstream(true); + var node = node(downstream); + node.afterAllFactsInsertedLeft(false); + node.afterAllFactsInsertedRight(true); + assertThat(node.isActive()).isFalse(); + verify(downstream, times(1)).afterAllFactsInserted(false); + } + + @Test + void inactiveWhenDownstreamInactive() { + var downstream = mockDownstream(false); + var node = node(downstream); + node.afterAllFactsInsertedLeft(true); + node.afterAllFactsInsertedRight(true); + assertThat(node.isActive()).isFalse(); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/SingleInputNodeActivityTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/SingleInputNodeActivityTest.java new file mode 100644 index 00000000000..d67e25c6327 --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/common/SingleInputNodeActivityTest.java @@ -0,0 +1,68 @@ +package ai.timefold.solver.core.impl.bavet.common; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.function.Function; + +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.impl.bavet.uni.MapUniToUniNode; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * See {@link ActivitySupport}. + */ +class SingleInputNodeActivityTest { + + private static MapUniToUniNode node(TupleLifecycle> downstream) { + Function identity = s -> s; + return new MapUniToUniNode<>(0, identity, downstream, 1); + } + + @SuppressWarnings("unchecked") + private static TupleLifecycle> mockDownstream(boolean active) { + TupleLifecycle> downstream = mock(TupleLifecycle.class); + when(downstream.isActive()).thenReturn(active); + return downstream; + } + + @Test + void activeWhenUpstreamAndDownstreamActive() { + var downstream = mockDownstream(true); + var node = node(downstream); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void inactiveWhenUpstreamCannotProduce() { + var downstream = mockDownstream(true); + var node = node(downstream); + node.afterAllFactsInserted(false); + assertThat(node.isActive()).isFalse(); + } + + @Test + void inactiveWhenDownstreamInactive() { + var downstream = mockDownstream(false); + var node = node(downstream); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isFalse(); + } + + @ParameterizedTest + @ValueSource(booleans = { true, false }) + void forwardsUpstreamCapabilityDownstream(boolean upstreamCanProduce) { + var downstream = mockDownstream(true); + var node = node(downstream); + node.afterAllFactsInserted(upstreamCanProduce); + verify(downstream).afterAllFactsInserted(upstreamCanProduce); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/ConcatUniUniNodeActivityTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/ConcatUniUniNodeActivityTest.java new file mode 100644 index 00000000000..c17f312710d --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/ConcatUniUniNodeActivityTest.java @@ -0,0 +1,68 @@ +package ai.timefold.solver.core.impl.bavet.uni; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; + +import org.junit.jupiter.api.Test; + +/** + * See {@link ActivitySupport}. + */ +class ConcatUniUniNodeActivityTest { + + @SuppressWarnings("unchecked") + private static TupleLifecycle> mockDownstream(boolean active) { + TupleLifecycle> downstream = mock(TupleLifecycle.class); + when(downstream.isActive()).thenReturn(active); + return downstream; + } + + private static ConcatUniUniNode node(TupleLifecycle> downstream) { + return new ConcatUniUniNode<>(downstream, 0, 1, 2); + } + + @Test + void activeWhenBothSidesProduce() { + var node = node(mockDownstream(true)); + node.afterAllFactsInsertedLeft(true); + node.afterAllFactsInsertedRight(true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void activeWhenOnlyLeftProduces() { + var node = node(mockDownstream(true)); + node.afterAllFactsInsertedLeft(true); + node.afterAllFactsInsertedRight(false); + assertThat(node.isActive()).isTrue(); + } + + @Test + void activeWhenOnlyRightProduces() { + var node = node(mockDownstream(true)); + node.afterAllFactsInsertedLeft(false); + node.afterAllFactsInsertedRight(true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void inactiveWhenNeitherProduces() { + var node = node(mockDownstream(true)); + node.afterAllFactsInsertedLeft(false); + node.afterAllFactsInsertedRight(false); + assertThat(node.isActive()).isFalse(); + } + + @Test + void inactiveWhenDownstreamInactive() { + var node = node(mockDownstream(false)); + node.afterAllFactsInsertedLeft(true); + node.afterAllFactsInsertedRight(true); + assertThat(node.isActive()).isFalse(); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUniNodeActivityTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUniNodeActivityTest.java new file mode 100644 index 00000000000..5017b89c41b --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/ForEachUniNodeActivityTest.java @@ -0,0 +1,76 @@ +package ai.timefold.solver.core.impl.bavet.uni; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; + +import org.junit.jupiter.api.Test; + +/** + * See {@link ActivitySupport}. + */ +class ForEachUniNodeActivityTest { + + @SuppressWarnings("unchecked") + private static TupleLifecycle> mockDownstream(boolean active) { + TupleLifecycle> downstream = mock(TupleLifecycle.class); + when(downstream.isActive()).thenReturn(active); + return downstream; + } + + @Test + void unfilteredInactiveWhenNoFacts() { + var downstream = mockDownstream(true); + var node = new ForEachUnfilteredUniNode<>(String.class, downstream, 1); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isFalse(); + } + + @Test + void unfilteredActiveWhenFactsExist() { + var downstream = mockDownstream(true); + var node = new ForEachUnfilteredUniNode<>(String.class, downstream, 1); + node.insert("a"); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void unfilteredInactiveWhenDownstreamInactive() { + var downstream = mockDownstream(false); + var node = new ForEachUnfilteredUniNode<>(String.class, downstream, 1); + node.insert("a"); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isFalse(); + } + + @Test + void filteredInactiveWhenNoFacts() { + var downstream = mockDownstream(true); + var node = new ForEachFilteredUniNode<>(String.class, s -> true, downstream, 1); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isFalse(); + } + + @Test + void filteredActiveWhenFactInsertedEvenIfFilteredOut() { + var downstream = mockDownstream(true); + var node = new ForEachFilteredUniNode<>(String.class, s -> false, downstream, 1); + node.insert("a"); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void filteredInactiveWhenDownstreamInactive() { + var downstream = mockDownstream(false); + var node = new ForEachFilteredUniNode<>(String.class, s -> true, downstream, 1); + node.insert("a"); + node.afterAllFactsInserted(true); + assertThat(node.isActive()).isFalse(); + } +} diff --git a/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/IfExistsUniNodeActivityTest.java b/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/IfExistsUniNodeActivityTest.java new file mode 100644 index 00000000000..8a3a5fa643a --- /dev/null +++ b/core/src/test/java/ai/timefold/solver/core/impl/bavet/uni/IfExistsUniNodeActivityTest.java @@ -0,0 +1,85 @@ +package ai.timefold.solver.core.impl.bavet.uni; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.timefold.solver.core.impl.bavet.common.tuple.ActivitySupport; +import ai.timefold.solver.core.impl.bavet.common.tuple.InTupleStorePositionTracker; +import ai.timefold.solver.core.impl.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.impl.bavet.common.tuple.UniTuple; + +import org.junit.jupiter.api.Test; + +/** + * See {@link ActivitySupport}. + */ +class IfExistsUniNodeActivityTest { + + @SuppressWarnings("unchecked") + private static TupleLifecycle> mockDownstream(boolean active) { + TupleLifecycle> downstream = mock(TupleLifecycle.class); + when(downstream.isActive()).thenReturn(active); + return downstream; + } + + private static UnindexedIfExistsUniNode node(boolean shouldExist, + TupleLifecycle> downstream) { + var tracker = mock(InTupleStorePositionTracker.class); + return new UnindexedIfExistsUniNode<>(shouldExist, downstream, tracker); + } + + private static void initBothSides(UnindexedIfExistsUniNode node, boolean left, boolean right) { + node.afterAllFactsInsertedLeft(left); + node.afterAllFactsInsertedRight(right); + } + + @Test + void ifExistsActiveWhenBothProduce() { + var node = node(true, mockDownstream(true)); + initBothSides(node, true, true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void ifExistsInactiveWhenRightEmpty() { + var node = node(true, mockDownstream(true)); + initBothSides(node, true, false); + assertThat(node.isActive()).isFalse(); + } + + @Test + void ifExistsInactiveWhenLeftEmpty() { + var node = node(true, mockDownstream(true)); + initBothSides(node, false, true); + assertThat(node.isActive()).isFalse(); + } + + @Test + void ifNotExistsActiveWhenRightEmpty() { + var node = node(false, mockDownstream(true)); + initBothSides(node, true, false); + assertThat(node.isActive()).isTrue(); + } + + @Test + void ifNotExistsActiveWhenRightProduces() { + var node = node(false, mockDownstream(true)); + initBothSides(node, true, true); + assertThat(node.isActive()).isTrue(); + } + + @Test + void ifNotExistsInactiveWhenLeftEmpty() { + var node = node(false, mockDownstream(true)); + initBothSides(node, false, false); + assertThat(node.isActive()).isFalse(); + } + + @Test + void ifNotExistsInactiveWhenDownstreamInactive() { + var node = node(false, mockDownstream(false)); + initBothSides(node, true, false); + assertThat(node.isActive()).isFalse(); + } +} diff --git a/docs/src/modules/ROOT/pages/constraints-and-score/performance.adoc b/docs/src/modules/ROOT/pages/constraints-and-score/performance.adoc index 90485bb6c26..6511571e716 100644 --- a/docs/src/modules/ROOT/pages/constraints-and-score/performance.adoc +++ b/docs/src/modules/ROOT/pages/constraints-and-score/performance.adoc @@ -88,6 +88,19 @@ set the xref:constraints-and-score/constraint-configuration.adoc#definingAndOver In xref:constraints-and-score/score-calculation.adoc#constraintStreams[Constraint Streams], this will result in the constraint not being added to the node network and therefore have no performance impact at all. +Should you forget to set the constraint weight to zero, +the constraint may still be deactivated at runtime, +if the engine can prove that it will never contribute to the score. +For example, a constraint using `forEach(Visit.class)` will be deactivated +if there are no `Visit` instances in the solution. + +This is inferior to manually setting the constraint weight to zero, +as the engine can only perform optimizations it can guarantee are safe +based on the limited information it has. +With full understanding of the domain, +you are in a much better position to know which constraints are relevant and which are not, +and you can set the weights accordingly. + [#buildInHardConstraint] == Built-in hard constraint @@ -100,7 +113,7 @@ to define that Course A should only be assigned a `Room` different than X. This can give a good performance gain in some use cases, not just because the move evaluation is faster, but mainly because most optimization algorithms will spend less time evaluating infeasible solutions. -However, usually this is not a good idea because there is a real risk of trading short-term benefits for long-term harm: +However, there is a real risk of trading short-term benefits for long-term harm: * Many optimization algorithms rely on the freedom to break hard constraints when changing planning entities, to get out of local optima.