Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,33 @@
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public class ArtemisMBeanServerGuard implements GuardInvocationHandler {

private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private JMXAccessControlList jmxAccessControlList = JMXAccessControlList.createDefaultList();

private final Map<ObjectName, Boolean> bypassRBACCache = new ConcurrentHashMap<>();
private final Map<String, ObjectName> objectNameCache = new ConcurrentHashMap<>();

private static final class CachedRolesPrincipal implements Principal {
final Set<String> roles;

CachedRolesPrincipal(Set<String> roles) {
this.roles = Collections.unmodifiableSet(roles);
}

@Override public String getName() {
return "__cached_roles__";
}
}

public void init() {
ArtemisMBeanServerBuilder.setGuard(this);
}
Expand Down Expand Up @@ -122,18 +142,12 @@ private void handleSetAttributes(MBeanServer proxy, ObjectName objectName, Attri
}

private boolean canBypassRBAC(ObjectName objectName) {
return jmxAccessControlList.isInAllowList(objectName);
return bypassRBACCache.computeIfAbsent(objectName, name -> jmxAccessControlList.isInAllowList(name));
}

@Override
public boolean canInvoke(String object, String operationName) {
ObjectName objectName = null;
try {
objectName = ObjectName.getInstance(object);
} catch (MalformedObjectNameException e) {
logger.debug("can't check invoke rights as object name invalid: {}", object, e);
return false;
}

/*
* HawtIO calls this with a null operationName as a coarse grained way of authenticating against all the
* operations on an mbean. Until this addition this was throwing a null pointer on operationName later in this
Expand All @@ -142,7 +156,23 @@ public boolean canInvoke(String object, String operationName) {
* it. Since it is just an optimisation it is fine to always return true. Note that the alternative
* ArtemisRbacInvocationHandler does allow the ability to restrict a whole mbean.
*/
if (operationName == null || canBypassRBAC(objectName)) {
if (operationName == null) {
return true;
}
ObjectName objectName = objectNameCache.computeIfAbsent(object, key -> {
try {
return ObjectName.getInstance(key);
} catch (MalformedObjectNameException e) {
logger.debug("can't check invoke rights as object name invalid: {}", key, e);
return null;
}
});

if (objectName == null) {
return false;
}

if (canBypassRBAC(objectName)) {
return true;
}

Expand All @@ -151,15 +181,21 @@ public boolean canInvoke(String object, String operationName) {
if (paramListIndex > 0) {
operationName = operationName.substring(0, paramListIndex);
}
Set<String> currentUserRoles = getCurrentUserRoles();

List<String> requiredRoles = getRequiredRoles(objectName, operationName);
for (String role : requiredRoles) {
if (currentUserHasRole(role)) {
return true;
}
if (currentUserRoles.isEmpty()) {
return false;
}
logger.debug("{} {} false", object, operationName);
return false;
boolean authorized = authorizeUserForMethod(objectName, operationName, currentUserRoles);

if (authorized) {
logger.debug("{} {} true", object, operationName);
return true;
} else {
logger.debug("{} {} false", object, operationName);
return false;
}

}

void handleInvoke(ObjectName objectName, String operationName) throws IOException {
Expand All @@ -182,6 +218,10 @@ List<String> getRequiredRoles(ObjectName objectName, String methodName) {
return jmxAccessControlList.getRolesForObject(objectName, methodName);
}

boolean authorizeUserForMethod(ObjectName objectName, String operationName, Set<String> currentUserRoles) {
return jmxAccessControlList.authorizeUserForMethod(objectName, operationName, currentUserRoles);
}

public void setJMXAccessControlList(JMXAccessControlList JMXAccessControlList) {
this.jmxAccessControlList = JMXAccessControlList;
}
Expand Down Expand Up @@ -210,4 +250,26 @@ public static boolean currentUserHasRole(String requestedRole) {
}
return false;
}

public static Set<String> getCurrentUserRoles() {
Subject subject = SecurityManagerShim.currentSubject();
if (subject == null) {
return Collections.emptySet();
}

// Check if roles are already cached on the subject
Set<CachedRolesPrincipal> cached = subject.getPrincipals(CachedRolesPrincipal.class);
if (!cached.isEmpty()) {
return cached.iterator().next().roles;
}

// First call for this subject — build and cache
Set<String> roles = new HashSet<>();
for (Principal p : subject.getPrincipals()) {
roles.add(p.getName());
}
subject.getPrincipals().add(new CachedRolesPrincipal(roles));
return roles;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

import javax.management.ObjectName;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand All @@ -31,6 +33,36 @@
public class JMXAccessControlList {
private static final String WILDCARD = "*";

private record AccessEntry(Access access, String rawPattern) { }
private record Bucket(
Map<String, AccessEntry> exactMatches,
List<AccessEntry> regexPatterns
) { }

private final Map<String, Map<String, String>> keyPropertyCache =
Collections.synchronizedMap(new LinkedHashMap<String, Map<String, String>>(128, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Map<String, String>> eldest) {
return size() > 5000;
}
});

private final Map<String, TreeMap<String, Access>> domainCache =
Collections.synchronizedMap(new LinkedHashMap<String, TreeMap<String, Access>>(128, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, TreeMap<String, Access>> eldest) {
return size() > 5000;
}
});

private final Map<String, Map<String, Bucket>> bucketedDomainCache =
Collections.synchronizedMap(new LinkedHashMap<>(128, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Map<String, Bucket>> eldest) {
return size() > 1000;
}
});

private Access defaultAccess = new Access(WILDCARD);
private ConcurrentMap<String, TreeMap<String, Access>> domainAccess = new ConcurrentHashMap<>();
private ConcurrentMap<String, TreeMap<String, Access>> allowList = new ConcurrentHashMap<>();
Expand All @@ -48,6 +80,7 @@ public class JMXAccessControlList {
return key2.length() - key1.length();
};


public void addToAllowList(String domain, String key) {
TreeMap<String, Access> domainMap = new TreeMap<>(keyComparator);
domainMap = allowList.putIfAbsent(domain, domainMap);
Expand Down Expand Up @@ -81,6 +114,86 @@ public List<String> getRolesForObject(ObjectName objectName, String methodName)
return defaultAccess.getMatchingRolesForMethod(methodName);
}


public boolean authorizeUserForMethod(ObjectName objectName, String methodName, Set<String> userRoles) {

String domainKey = objectName.getDomain();

TreeMap<String, Access> domainMap = domainCache.computeIfAbsent(objectName.getDomain(), key ->
domainAccess.get(key)
);

Map<String, Bucket> bucketedMap = bucketedDomainCache.computeIfAbsent(domainKey, d -> {
TreeMap<String, Access> rawMap = domainAccess.get(d);
if (rawMap == null) {
return null;
}

Map<String, Bucket> grouped = new HashMap<>();
for (Access access : rawMap.values()) {
String rawPattern = access.getKeyPattern().pattern();
int eqIndex = rawPattern.indexOf('=');
String prefix = (eqIndex != -1) ? rawPattern.substring(0, eqIndex) : "";

// Initialize the Bucket (Map + List) instead of just an ArrayList
Bucket bucket = grouped.computeIfAbsent(prefix, k ->
new Bucket(new HashMap<>(), new ArrayList<>())
);

AccessEntry entry = new AccessEntry(access, rawPattern);

// Sort into Exact or Regex
if (rawPattern.contains("*") || rawPattern.contains("?") || rawPattern.contains("[")) {
bucket.regexPatterns().add(entry);
} else {
bucket.exactMatches().put(rawPattern, entry);
}
}
return grouped;
});

if (bucketedMap != null) {

String cacheKey = objectName.getCanonicalName();
Map<String, String> keyPropertyList = keyPropertyCache.get(cacheKey);
if (keyPropertyList == null) {
keyPropertyList = objectName.getKeyPropertyList();
keyPropertyCache.put(cacheKey, keyPropertyList);
}


for (Map.Entry<String, String> entry : keyPropertyList.entrySet()) {
String propKey = entry.getKey();
Bucket bucket = bucketedMap.get(propKey);

if (bucket != null) {
String normalizedValue = normalizeKey(propKey + "=" + entry.getValue());

// Priority 1: O(1) Exact Match Check
if (bucket.exactMatches().containsKey(normalizedValue)) {
return bucket.exactMatches().get(normalizedValue).access().authorizeUserForMethod(methodName, userRoles);
}

// Priority 2: O(N) Regex Match (but only for actual regexes)
for (AccessEntry regexEntry : bucket.regexPatterns()) {
if (regexEntry.access().getKeyPattern().matcher(normalizedValue).matches()) {
return regexEntry.access().authorizeUserForMethod(methodName, userRoles);
}
}
}
}

Access access = domainMap.get("");
if (access != null) {
return access.authorizeUserForMethod(methodName, userRoles);
}
}

return defaultAccess.authorizeUserForMethod(methodName, userRoles);
}



public boolean isInAllowList(ObjectName objectName) {
TreeMap<String, Access> domainMap = allowList.get(objectName.getDomain());

Expand Down Expand Up @@ -223,6 +336,20 @@ public List<String> getMatchingRolesForMethod(String methodName) {
}
return catchAllRoles;
}

public boolean authorizeUserForMethod(String methodName, Set<String> userRoles) {
List<String> roles = methodRoles.get(methodName);
if (roles != null) {
return !Collections.disjoint(roles, userRoles);

}
for (Map.Entry<String, List<String>> entry : methodPrefixRoles.entrySet()) {
if (methodName.startsWith(entry.getKey())) {
return !Collections.disjoint(entry.getValue(), userRoles);
}
}
return !Collections.disjoint(catchAllRoles, userRoles);
}
}

public static JMXAccessControlList createDefaultList() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void testCanInvokeMethodHasRole() throws Throwable {


@Test
public void testCanInvokeMethodDoeNotHasRole() throws Throwable {
public void testCanInvokeMethodDoesNotHaveRole() throws Throwable {
ArtemisMBeanServerGuard guard = new ArtemisMBeanServerGuard();
JMXAccessControlList controlList = new JMXAccessControlList();
guard.setJMXAccessControlList(controlList);
Expand Down
Loading