diff --git a/multiapps-mta/src/main/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilder.java b/multiapps-mta/src/main/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilder.java index cf65634a..addb942c 100644 --- a/multiapps-mta/src/main/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilder.java +++ b/multiapps-mta/src/main/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilder.java @@ -1,8 +1,10 @@ package org.cloudfoundry.multiapps.mta.builders; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.collections4.CollectionUtils; @@ -16,6 +18,8 @@ public class ExtensionDescriptorChainBuilder { private final boolean isStrict; + private ExtensionDescriptor secureExtensionDescriptor; + public ExtensionDescriptorChainBuilder() { this(true); } @@ -24,8 +28,11 @@ public ExtensionDescriptorChainBuilder(boolean isStrict) { this.isStrict = isStrict; } - public List build(DeploymentDescriptor deploymentDescriptor, List extensionDescriptors) - throws ContentException { + private boolean isSecureDescriptor(ExtensionDescriptor extensionDescriptor) { + return extensionDescriptor.getId().equals("__mta.secure"); + } + + public List build(DeploymentDescriptor deploymentDescriptor, List extensionDescriptors) throws ContentException { Map extensionDescriptorsPerParent = getExtensionDescriptorsPerParent(extensionDescriptors); return build(deploymentDescriptor, extensionDescriptorsPerParent); } @@ -34,11 +41,21 @@ private List build(DeploymentDescriptor deploymentDescripto Map extensionDescriptorsPerParent) { List chain = new ArrayList<>(); Descriptor currentDescriptor = deploymentDescriptor; + ExtensionDescriptor secureDescriptor = null; while (currentDescriptor != null) { ExtensionDescriptor nextDescriptor = extensionDescriptorsPerParent.remove(currentDescriptor.getId()); + if(nextDescriptor != null && isSecureDescriptor(nextDescriptor)) { + secureDescriptor = nextDescriptor; + continue; + } + CollectionUtils.addIgnoreNull(chain, nextDescriptor); currentDescriptor = nextDescriptor; } + CollectionUtils.addIgnoreNull(chain, secureDescriptor); + if(secureDescriptor == null && this.secureExtensionDescriptor != null) { + CollectionUtils.addIgnoreNull(chain, this.secureExtensionDescriptor); + } if (!extensionDescriptorsPerParent.isEmpty() && isStrict) { throw new ContentException(Messages.CANNOT_BUILD_EXTENSION_DESCRIPTOR_CHAIN_BECAUSE_DESCRIPTORS_0_HAVE_AN_UNKNOWN_PARENT, String.join(",", Descriptor.getIds(extensionDescriptorsPerParent.values()))); @@ -47,24 +64,40 @@ private List build(DeploymentDescriptor deploymentDescripto } private Map getExtensionDescriptorsPerParent(List extensionDescriptors) { - Map> extensionDescriptorsPerParent = extensionDescriptors.stream() - .collect(Collectors.groupingBy(ExtensionDescriptor::getParentId)); + Map> extensionDescriptorsPerParent = extensionDescriptors.stream().collect(Collectors.groupingBy(ExtensionDescriptor::getParentId)); + validateSingleExtensionDescriptorPerParent(extensionDescriptorsPerParent); return prune(extensionDescriptorsPerParent); } private Map prune(Map> extensionDescriptorsPerParent) { - validateSingleExtensionDescriptorPerParent(extensionDescriptorsPerParent); + this.secureExtensionDescriptor = null; return extensionDescriptorsPerParent.entrySet() .stream() - .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue() - .get(0))); + .collect(Collectors.toMap(Map.Entry::getKey, entry -> { + List localList = entry.getValue(); + + for(ExtensionDescriptor extensionDescriptor : localList) { + if(extensionDescriptor.getId().equals("__mta.secure")) { + this.secureExtensionDescriptor = extensionDescriptor; + } + } + + for(ExtensionDescriptor extensionDescriptor : localList) { + if(!extensionDescriptor.getId().equals("__mta.secure")) { + return extensionDescriptor; + } + } + + return localList.get(0); + })); } private void validateSingleExtensionDescriptorPerParent(Map> extensionDescriptorsPerParent) { for (Map.Entry> extensionDescriptorsForParent : extensionDescriptorsPerParent.entrySet()) { String parent = extensionDescriptorsForParent.getKey(); List extensionDescriptors = extensionDescriptorsForParent.getValue(); - if (extensionDescriptors.size() > 1 && isStrict) { + long nonSecureCountOfExtensionDescriptors = extensionDescriptors.stream().filter(descriptor -> !descriptor.getId().equals("__mta.secure")).count(); + if (nonSecureCountOfExtensionDescriptors > 1 && isStrict) { throw new ContentException(Messages.MULTIPLE_EXTENSION_DESCRIPTORS_EXTEND_THE_PARENT_0, parent); } } diff --git a/multiapps-mta/src/test/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilderTest.java b/multiapps-mta/src/test/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilderTest.java index 1911e7e1..b4cd9fff 100644 --- a/multiapps-mta/src/test/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilderTest.java +++ b/multiapps-mta/src/test/java/org/cloudfoundry/multiapps/mta/builders/ExtensionDescriptorChainBuilderTest.java @@ -1,7 +1,10 @@ package org.cloudfoundry.multiapps.mta.builders; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.text.MessageFormat; import java.util.Arrays; @@ -102,6 +105,86 @@ void testLaxBuildWithMultipleDescriptorsExtendingTheSameParent() { assertEquals(Collections.emptyList(), extensionDescriptorChain); } + @Test + void testBuildWithNormalAndSecureExtensionDescriptorsWhereSecureIsLastInTheChain() { + DeploymentDescriptor deploymentDescriptor = buildDeploymentDescriptor(FOO_ID); + ExtensionDescriptor normalExtensionDescriptor = buildExtensionDescriptor(BAR_ID, FOO_ID); + ExtensionDescriptor secureExtensionDescriptor = buildSecureExtensionDescriptor(FOO_ID); + + List extensionDescriptors = Arrays.asList(secureExtensionDescriptor, normalExtensionDescriptor); + + ExtensionDescriptorChainBuilder extensionDescriptorChainBuilder = new ExtensionDescriptorChainBuilder(); + List extensionDescriptorChain = extensionDescriptorChainBuilder.build(deploymentDescriptor, extensionDescriptors); + + assertEquals(Arrays.asList(normalExtensionDescriptor, secureExtensionDescriptor), extensionDescriptorChain); + } + + @Test + void testBuildWithOnlySecureExtensionDescriptor() { + DeploymentDescriptor deploymentDescriptor = buildDeploymentDescriptor(FOO_ID); + ExtensionDescriptor secureExtensionDescriptor = buildSecureExtensionDescriptor(FOO_ID); + + List extensionDescriptors = Collections.singletonList(secureExtensionDescriptor); + + ExtensionDescriptorChainBuilder extensionDescriptorChainBuilder = new ExtensionDescriptorChainBuilder(); + List extensionDescriptorChain = extensionDescriptorChainBuilder.build(deploymentDescriptor, extensionDescriptors); + + assertEquals(Collections.singletonList(secureExtensionDescriptor), extensionDescriptorChain); + } + + @Test + void testBuildWhenSecureExtensionDescriptorIsExtendedStrictMode() { + DeploymentDescriptor deploymentDescriptor = buildDeploymentDescriptor(FOO_ID); + ExtensionDescriptor firstExtensionDescriptor = buildExtensionDescriptor(BAR_ID, FOO_ID); + ExtensionDescriptor secureExtensionDescriptor = buildSecureExtensionDescriptor(BAR_ID); + ExtensionDescriptor extensionDescriptorAfterSecureOne = buildExtensionDescriptor(QUX_ID, "__mta.secure"); + + List extensionDescriptors = Arrays.asList(extensionDescriptorAfterSecureOne, firstExtensionDescriptor, secureExtensionDescriptor); + + ExtensionDescriptorChainBuilder extensionDescriptorChainBuilder = new ExtensionDescriptorChainBuilder(true); + + ContentException contentException = assertThrows(ContentException.class, () -> extensionDescriptorChainBuilder.build(deploymentDescriptor, extensionDescriptors)); + + String expectedMessage = MessageFormat.format( + Messages.CANNOT_BUILD_EXTENSION_DESCRIPTOR_CHAIN_BECAUSE_DESCRIPTORS_0_HAVE_AN_UNKNOWN_PARENT, + extensionDescriptorAfterSecureOne.getId() + ); + assertEquals(expectedMessage, contentException.getMessage()); + } + + @Test + void testBuildWhenSecureExtensionDescriptorIsExtendedNonStrictMode() { + DeploymentDescriptor deploymentDescriptor = buildDeploymentDescriptor(FOO_ID); + ExtensionDescriptor firstExtensionDescriptor = buildExtensionDescriptor(BAR_ID, FOO_ID); + ExtensionDescriptor secureExtensionDescriptor = buildSecureExtensionDescriptor(BAR_ID); + ExtensionDescriptor extensionDescriptorAfterSecureOne = buildExtensionDescriptor(QUX_ID, "__mta.secure"); + + List extensionDescriptors = Arrays.asList(extensionDescriptorAfterSecureOne, secureExtensionDescriptor, firstExtensionDescriptor); + + ExtensionDescriptorChainBuilder extensionDescriptorChainBuilder = new ExtensionDescriptorChainBuilder(false); + List extensionDescriptorChain = extensionDescriptorChainBuilder.build(deploymentDescriptor, extensionDescriptors); + + assertEquals(Arrays.asList(firstExtensionDescriptor, secureExtensionDescriptor), extensionDescriptorChain); + } + + @Test + void testBuildWhenTwoSecureExtensionDescriptorsWhereOnlyTheLastOneGetsAppended() { + DeploymentDescriptor deploymentDescriptor = buildDeploymentDescriptor(FOO_ID); + ExtensionDescriptor firstExtensionDescriptor = buildExtensionDescriptor(BAR_ID, FOO_ID); + ExtensionDescriptor firstSecureExtensionDescriptor = buildSecureExtensionDescriptor(FOO_ID); + ExtensionDescriptor secondSecureExtensionDescriptor = buildSecureExtensionDescriptor(FOO_ID); + + List extensionDescriptors = Arrays.asList(firstSecureExtensionDescriptor, firstExtensionDescriptor, secondSecureExtensionDescriptor); + + ExtensionDescriptorChainBuilder extensionDescriptorChainBuilder = new ExtensionDescriptorChainBuilder(); + List extensionDescriptorChain = extensionDescriptorChainBuilder.build(deploymentDescriptor, extensionDescriptors); + + assertEquals(2, extensionDescriptorChain.size()); + assertEquals(firstExtensionDescriptor, extensionDescriptorChain.get(0)); + assertEquals("__mta.secure", extensionDescriptorChain.get(1).getId()); + assertSame(secondSecureExtensionDescriptor, extensionDescriptorChain.get(1)); + } + private DeploymentDescriptor buildDeploymentDescriptor(String id) { return DeploymentDescriptor.createV2() .setId(id); @@ -113,4 +196,10 @@ private ExtensionDescriptor buildExtensionDescriptor(String id, String parentId) .setId(id); } + private ExtensionDescriptor buildSecureExtensionDescriptor(String parentId) { + return ExtensionDescriptor.createV2() + .setParentId(parentId) + .setId("__mta.secure"); + } + }