diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java index b6dc6a95e461a..e9926bee1f787 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/AdaptiveGraphManager.java @@ -42,6 +42,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -430,7 +431,9 @@ private void setVertexNonChainedOutputsConfig( private void connectToFinishedUpStreamVertex(JobVertexBuildContext jobVertexBuildContext) { Map chainInfos = jobVertexBuildContext.getChainInfosInOrder(); for (OperatorChainInfo chainInfo : chainInfos.values()) { - List transitiveInEdges = chainInfo.getTransitiveInEdges(); + List transitiveInEdges = + getTransitiveInEdgesInOrder( + chainInfo.getTransitiveInEdges(), jobVertexBuildContext); for (StreamEdge transitiveInEdge : transitiveInEdges) { NonChainedOutput output = intermediateOutputsCaches @@ -447,6 +450,55 @@ private void connectToFinishedUpStreamVertex(JobVertexBuildContext jobVertexBuil } } + private List getTransitiveInEdgesInOrder( + List transitiveInEdges, JobVertexBuildContext jobVertexBuildContext) { + final List transitiveInEdgesInOrder = + transitiveInEdges.stream() + .sorted( + Comparator.comparing( + inEdge -> getStartNodeId(inEdge.getSourceId()))) + .collect(Collectors.toList()); + final List uidTransitiveInEdges = + transitiveInEdgesInOrder.stream() + .filter(this::hasUidBackedUpstream) + .sorted( + Comparator.comparing( + inEdge -> + getStartNodeJobVertexId( + inEdge, jobVertexBuildContext))) + .collect(Collectors.toList()); + + if (uidTransitiveInEdges.size() < 2) { + return transitiveInEdgesInOrder; + } + + int uidTransitiveInEdgeIndex = 0; + for (int transitiveInEdgeIndex = 0; + transitiveInEdgeIndex < transitiveInEdgesInOrder.size(); + transitiveInEdgeIndex++) { + if (hasUidBackedUpstream(transitiveInEdgesInOrder.get(transitiveInEdgeIndex))) { + transitiveInEdgesInOrder.set( + transitiveInEdgeIndex, + uidTransitiveInEdges.get(uidTransitiveInEdgeIndex++)); + } + } + return transitiveInEdgesInOrder; + } + + private boolean hasUidBackedUpstream(StreamEdge inEdge) { + return streamGraph + .getStreamNode(getStartNodeId(inEdge.getSourceId())) + .getTransformationUID() + != null; + } + + private JobVertexID getStartNodeJobVertexId( + StreamEdge inEdge, JobVertexBuildContext jobVertexBuildContext) { + return new JobVertexID( + Preconditions.checkNotNull( + jobVertexBuildContext.getHash(getStartNodeId(inEdge.getSourceId())))); + } + private void recordCreatedJobVerticesInfo(JobVertexBuildContext jobVertexBuildContext) { Map chainInfos = jobVertexBuildContext.getChainInfosInOrder(); for (OperatorChainInfo chainInfo : chainInfos.values()) { @@ -473,11 +525,17 @@ private void createOperatorChainInfos( final Map chainEntryPoints = buildAndGetChainEntryPoints(streamNodes, jobVertexBuildContext); + chainEntryPoints.values().stream() + .filter( + chainInfo -> + streamGraph + .getStreamNode(chainInfo.getStartNodeId()) + .getTransformationUID() + != null) + .forEach(chainInfo -> generateHashesByStreamNodeId(chainInfo.getStartNodeId())); final Collection initialEntryPoints = - chainEntryPoints.entrySet().stream() - .sorted(Map.Entry.comparingByKey()) - .map(Map.Entry::getValue) - .collect(Collectors.toList()); + StreamingJobGraphGenerator.getInitialEntryPoints( + chainEntryPoints.values(), jobVertexBuildContext); for (OperatorChainInfo info : initialEntryPoints) { // We use generateHashesByStreamNodeId to subscribe the visited stream node id and diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index 66776c03fb5fd..27ee75c6add6b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -607,10 +607,7 @@ private void setChaining() { final Map chainEntryPoints = buildChainedInputsAndGetHeadInputs(); final Collection initialEntryPoints = - chainEntryPoints.entrySet().stream() - .sorted(Comparator.comparing(Map.Entry::getKey)) - .map(Map.Entry::getValue) - .collect(Collectors.toList()); + getInitialEntryPoints(chainEntryPoints.values(), jobVertexBuildContext); // iterate over a copy of the values, because this map gets concurrently modified for (OperatorChainInfo info : initialEntryPoints) { @@ -626,6 +623,57 @@ private void setChaining() { } } + static Collection getInitialEntryPoints( + Collection chainEntryPoints, + JobVertexBuildContext jobVertexBuildContext) { + final List initialEntryPoints = + chainEntryPoints.stream() + .sorted(Comparator.comparing(OperatorChainInfo::getStartNodeId)) + .collect(Collectors.toList()); + final List uidEntryPoints = + initialEntryPoints.stream() + .filter( + chainInfo -> + hasTransformationUid( + chainInfo, jobVertexBuildContext.getStreamGraph())) + .sorted( + Comparator.comparing( + chainInfo -> + getStartNodeJobVertexId( + chainInfo, jobVertexBuildContext))) + .collect(Collectors.toList()); + + if (uidEntryPoints.size() < 2) { + return initialEntryPoints; + } + + int uidEntryPointIndex = 0; + for (int entryPointIndex = 0; + entryPointIndex < initialEntryPoints.size(); + entryPointIndex++) { + if (hasTransformationUid( + initialEntryPoints.get(entryPointIndex), + jobVertexBuildContext.getStreamGraph())) { + // Input gates are restored by position, so align source traversal with the stable + // JobVertexID for UIDed heads while keeping the other heads in their legacy + // positions. + initialEntryPoints.set(entryPointIndex, uidEntryPoints.get(uidEntryPointIndex++)); + } + } + return initialEntryPoints; + } + + private static boolean hasTransformationUid( + OperatorChainInfo chainInfo, StreamGraph streamGraph) { + return streamGraph.getStreamNode(chainInfo.getStartNodeId()).getTransformationUID() != null; + } + + private static JobVertexID getStartNodeJobVertexId( + OperatorChainInfo chainInfo, JobVertexBuildContext jobVertexBuildContext) { + return new JobVertexID( + checkNotNull(jobVertexBuildContext.getHash(chainInfo.getStartNodeId()))); + } + public static List createChain( final Integer currentNodeId, final int chainIndex, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/JobGraphGeneratorTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/JobGraphGeneratorTestBase.java index 65260d0d0df3a..4dde81304b3e8 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/JobGraphGeneratorTestBase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/JobGraphGeneratorTestBase.java @@ -1187,6 +1187,24 @@ void testDeterministicUnionOrder() { } } + @Test + void testStableUnionInputOrderWithOperatorUids() { + assertThat(getInputSourceNames(getUidUnionJobGraph(false))) + .isEqualTo(getInputSourceNames(getUidUnionJobGraph(true))); + } + + @Test + void testStableUnionInputOrderWithOperatorUidsAndUnrelatedSource() { + assertThat(getMultiInputSourceNames(getUidUnionJobGraphWithUnrelatedSource(false))) + .isEqualTo(getMultiInputSourceNames(getUidUnionJobGraphWithUnrelatedSource(true))); + } + + @Test + void testUidHashUnionInputOrderKeepsDeclarationOrder() { + assertThat(getInputSourceNames(getUidHashUnionJobGraph())) + .containsExactly("Source: source-b", "Source: source-a"); + } + private JobGraph getUnionJobGraph(StreamExecutionEnvironment env) { createSource(env, 1) @@ -1202,6 +1220,87 @@ private DataStream createSource(StreamExecutionEnvironment env, int ind return env.fromData(index).name("source" + index).map(i -> i).name("map" + index); } + private JobGraph getUidUnionJobGraph(boolean reverseSources) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(1); + env.disableOperatorChaining(); + + DataStream firstSource = + createUidSource( + env, reverseSources ? "source-b" : "source-a", reverseSources ? 2 : 1); + DataStream secondSource = + createUidSource( + env, reverseSources ? "source-a" : "source-b", reverseSources ? 1 : 2); + + firstSource.union(secondSource).sinkTo(new DiscardingSink<>()).name("sink").uid("sink"); + + return createJobGraph(env.getStreamGraph()); + } + + private DataStream createUidSource( + StreamExecutionEnvironment env, String sourceName, int value) { + return env.fromData(value).name(sourceName).uid(sourceName); + } + + private JobGraph getUidUnionJobGraphWithUnrelatedSource(boolean reverseSources) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(1); + env.disableOperatorChaining(); + + env.fromData(0) + .name("unrelated-source") + .sinkTo(new DiscardingSink<>()) + .name("unrelated-sink"); + + DataStream firstSource = + createUidSource( + env, reverseSources ? "source-b" : "source-a", reverseSources ? 2 : 1); + DataStream secondSource = + createUidSource( + env, reverseSources ? "source-a" : "source-b", reverseSources ? 1 : 2); + + firstSource.union(secondSource).sinkTo(new DiscardingSink<>()).name("sink").uid("sink"); + + return createJobGraph(env.getStreamGraph()); + } + + private JobGraph getUidHashUnionJobGraph() { + StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment(1); + env.disableOperatorChaining(); + + DataStream firstSource = + createUidHashSource(env, "source-b", 2, "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"); + DataStream secondSource = + createUidHashSource(env, "source-a", 1, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + + firstSource.union(secondSource).sinkTo(new DiscardingSink<>()).name("sink").uid("sink"); + + return createJobGraph(env.getStreamGraph()); + } + + private DataStream createUidHashSource( + StreamExecutionEnvironment env, String sourceName, int value, String uidHash) { + return env.fromData(value).name(sourceName).setUidHash(uidHash); + } + + private List getInputSourceNames(JobGraph jobGraph) { + JobVertex jobSink = Iterables.getLast(jobGraph.getVerticesSortedTopologicallyFromSources()); + return getInputSourceNames(jobSink); + } + + private List getMultiInputSourceNames(JobGraph jobGraph) { + JobVertex jobSink = + jobGraph.getVerticesSortedTopologicallyFromSources().stream() + .filter(jobVertex -> jobVertex.getInputs().size() > 1) + .findFirst() + .orElseThrow(() -> new AssertionError("Expected a multi-input sink")); + return getInputSourceNames(jobSink); + } + + private List getInputSourceNames(JobVertex jobSink) { + return jobSink.getInputs().stream() + .map(edge -> edge.getSource().getProducer().getName()) + .collect(Collectors.toList()); + } + @Test void testNotSupportInputSelectableOperatorIfCheckpointing() { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();