Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.mx.path.gateway.connect.filter;

import com.mx.path.core.common.connect.Request;
import com.mx.path.core.common.connect.RequestFilterBase;
import com.mx.path.core.common.connect.Response;
import com.mx.path.core.context.RequestContext;
import com.mx.path.core.context.UpstreamRequestConfiguration;

/**
* Request Filter that adds forwarded headers to request
*/
public class HeaderForwarderFilter extends RequestFilterBase {

@Override
public final void execute(Request request, Response response) {
UpstreamRequestConfiguration upstreamRequestConfiguration = null;

if (RequestContext.current() != null) {
upstreamRequestConfiguration = RequestContext.current().getUpstreamRequestConfiguration();
}

if (upstreamRequestConfiguration != null && upstreamRequestConfiguration.getForwardedRequestHeaders() != null) {
upstreamRequestConfiguration.getForwardedRequestHeaders().forEach((key, value) -> {
request.withHeader(key, value.toString());
});
}

next(request, response);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package com.mx.path.gateway.connect.filter

import static org.mockito.Mockito.mock

import com.mx.path.core.common.connect.Request
import com.mx.path.core.common.connect.RequestFilter
import com.mx.path.core.common.connect.Response
import com.mx.path.core.context.RequestContext
import com.mx.path.core.context.UpstreamRequestConfiguration

import spock.lang.Specification

class HeaderForwarderFilterTest extends Specification {
class TestRequest extends Request<TestRequest, TestResponse> {
TestRequest(RequestFilter filterChain) {
super(filterChain)
}

@Override
TestResponse execute() {
return null
}
}

class TestResponse extends Response<TestRequest, TestResponse> {
}

def "forwards all headers in current RequestContext forwardedHeader"() {
given:
def requestFilter = mock(RequestFilter)
def subject = new HeaderForwarderFilter()
def request = new TestRequest(requestFilter)
def response = new TestResponse()
RequestContext.builder()
.upstreamRequestConfiguration(UpstreamRequestConfiguration.builder()
.forwardedHeader("mx_forwarded_important_header", "12345")
.forwardedHeader("mx-forwarded-important-header", "12345")
.build())
.build()
.register()

when:
subject.execute(request, response)

then:
request.headers.containsKey("mx_forwarded_important_header")
request.headers.containsKey("mx-forwarded-important-header")
}
}