Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 131 additions & 51 deletions src/passes/GlobalEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
// PassOptions structure; see more details there.
//

#include <ranges>

#include "ir/effects.h"
#include "ir/module-utils.h"
#include "pass.h"
#include "support/graph_traversal.h"
#include "support/strongly_connected_components.h"
#include "wasm.h"

Expand All @@ -39,6 +42,9 @@ struct FuncInfo {

// Directly-called functions from this function.
std::unordered_set<Name> calledFunctions;

// Types that are targets of indirect calls.
std::unordered_set<HeapType> indirectCalledTypes;
};

std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
Expand Down Expand Up @@ -83,11 +89,21 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
if (auto* call = curr->dynCast<Call>()) {
// Note the direct call.
funcInfo.calledFunctions.insert(call->target);
} else if (effects.calls && options.closedWorld) {
HeapType type;
if (auto* callRef = curr->dynCast<CallRef>()) {
// call_ref on unreachable does not have a call effect,
// so this must be a HeapType.
type = callRef->target->type.getHeapType();
Comment thread
stevenfontanella marked this conversation as resolved.
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
type = callIndirect->heapType;
} else {
WASM_UNREACHABLE("Unexpected call type");
}

funcInfo.indirectCalledTypes.insert(type);
} else if (effects.calls) {
// This is an indirect call of some sort, so we must assume the
// worst. To do so, clear the effects, which indicates nothing
// is known (so anything is possible).
// TODO: We could group effects by function type etc.
assert(!options.closedWorld);
funcInfo.effects = UnknownEffects;
} else {
// No call here, but update throwing if we see it. (Only do so,
Expand All @@ -107,22 +123,86 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
return std::move(analysis.map);
}

using CallGraph = std::unordered_map<Function*, std::unordered_set<Function*>>;
using CallGraphNode = std::variant<Function*, HeapType>;
Comment thread
stevenfontanella marked this conversation as resolved.

/*
Call graph for indirect and direct calls.

key (caller) -> value (callee)
Function -> Function : direct call
Function -> HeapType : indirect call to the given HeapType
HeapType -> Function : The function `callee` has the type `caller`. The
HeapType may essentially 'call' any of its
potential implementations.
HeapType -> HeapType : `callee` is a subtype of `caller`. A call_ref
could target any subtype of the ref, so we need to
aggregate effects of subtypes of the target type.

If we're running in an open world, we only include Function -> Function edges,
and don't compute effects for indirect calls, conservatively assuming the
worst.
*/
using CallGraph =
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>;

CallGraph buildCallGraph(const Module& module,
const std::map<Function*, FuncInfo>& funcInfos) {
const std::map<Function*, FuncInfo>& funcInfos,
bool closedWorld) {
CallGraph callGraph;
for (const auto& [func, info] : funcInfos) {
if (info.calledFunctions.empty()) {
continue;
if (!closedWorld) {
for (const auto& [caller, callerInfo] : funcInfos) {
Comment on lines +152 to +153
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we have to add the function -> function direct call edges whether or not we're in a closed world? I would expect this condition to guard just the early return below.

auto& callees = callGraph[caller];

// Function -> Function
for (Name calleeFunction : callerInfo.calledFunctions) {
callees.insert(module.getFunction(calleeFunction));
}
}

auto& callees = callGraph[func];
for (Name callee : info.calledFunctions) {
callees.insert(module.getFunction(callee));
return callGraph;
}

std::unordered_set<HeapType> allFunctionTypes;
Comment thread
stevenfontanella marked this conversation as resolved.
for (const auto& [caller, callerInfo] : funcInfos) {
auto& callees = callGraph[caller];

// Function -> Function
for (Name calleeFunction : callerInfo.calledFunctions) {
callees.insert(module.getFunction(calleeFunction));
}

// Function -> Type
allFunctionTypes.insert(caller->type.getHeapType());
for (HeapType calleeType : callerInfo.indirectCalledTypes) {
callees.insert(calleeType);

// Add the key to ensure the lookup doesn't fail for indirect calls to
// uninhabited types.
callGraph[calleeType];
}

// Type -> Function
callGraph[caller->type.getHeapType()].insert(caller);
}

// Type -> Type
// Do a DFS up the type heirarchy for all function implementations.
// We are essentially walking up each supertype chain and adding edges from
// super -> subtype, but doing it via DFS to avoid repeated work.
Graph superTypeGraph(allFunctionTypes.begin(),
allFunctionTypes.end(),
[&callGraph](auto&& push, HeapType t) {
// Not needed except that during lookup we expect the
// key to exist.
callGraph[t];

if (auto super = t.getDeclaredSuperType()) {
callGraph[*super].insert(t);
push(*super);
}
});
(void)superTypeGraph.traverseDepthFirst();

return callGraph;
}

Expand Down Expand Up @@ -152,63 +232,60 @@ void propagateEffects(const Module& module,
const PassOptions& passOptions,
std::map<Function*, FuncInfo>& funcInfos,
const CallGraph& callGraph) {
// We only care about Functions that are roots, not types.
// A type would be a root if a function exists with that type, but no-one
// indirect calls the type.
auto funcNodes = std::views::keys(callGraph) |
std::views::filter([](auto node) {
return std::holds_alternative<Function*>(node);
}) |
std::views::common;
Comment thread
tlively marked this conversation as resolved.
using funcNodesType = decltype(funcNodes);

struct CallGraphSCCs
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs> {
: SCCs<std::ranges::iterator_t<funcNodesType>, CallGraphSCCs> {

const std::map<Function*, FuncInfo>& funcInfos;
const std::unordered_map<Function*, std::unordered_set<Function*>>&
callGraph;
const CallGraph& callGraph;
const Module& module;

CallGraphSCCs(
const std::vector<Function*>& funcs,
const std::map<Function*, FuncInfo>& funcInfos,
const std::unordered_map<Function*, std::unordered_set<Function*>>&
callGraph,
const Module& module)
: SCCs<std::vector<Function*>::const_iterator, CallGraphSCCs>(
funcs.begin(), funcs.end()),
CallGraphSCCs(funcNodesType&& nodes,
const std::map<Function*, FuncInfo>& funcInfos,
const CallGraph& callGraph,
const Module& module)
: SCCs<std::ranges::iterator_t<funcNodesType>, CallGraphSCCs>(
std::ranges::begin(nodes), std::ranges::end(nodes)),
funcInfos(funcInfos), callGraph(callGraph), module(module) {}

void pushChildren(Function* f) {
auto callees = callGraph.find(f);
if (callees == callGraph.end()) {
return;
}

for (auto* callee : callees->second) {
void pushChildren(CallGraphNode node) {
for (CallGraphNode callee : callGraph.at(node)) {
push(callee);
}
}
};

std::vector<Function*> allFuncs;
for (auto& [func, info] : funcInfos) {
allFuncs.push_back(func);
}
CallGraphSCCs sccs(allFuncs, funcInfos, callGraph, module);
CallGraphSCCs sccs(std::move(funcNodes), funcInfos, callGraph, module);

std::vector<std::optional<EffectAnalyzer>> componentEffects;
// Points to an index in componentEffects
std::unordered_map<Function*, Index> funcComponents;
std::unordered_map<CallGraphNode, Index> nodeComponents;

for (auto ccIterator : sccs) {
std::optional<EffectAnalyzer>& ccEffects =
componentEffects.emplace_back(std::in_place, passOptions, module);
std::vector<CallGraphNode> cc(ccIterator.begin(), ccIterator.end());

std::vector<Function*> ccFuncs(ccIterator.begin(), ccIterator.end());

for (Function* f : ccFuncs) {
funcComponents.emplace(f, componentEffects.size() - 1);
std::vector<Function*> ccFuncs;
for (CallGraphNode node : cc) {
nodeComponents.emplace(node, componentEffects.size() - 1);
if (auto** func = std::get_if<Function*>(&node)) {
ccFuncs.push_back(*func);
}
}

std::unordered_set<int> calleeSccs;
for (Function* caller : ccFuncs) {
auto callees = callGraph.find(caller);
if (callees == callGraph.end()) {
continue;
}
for (auto* callee : callees->second) {
calleeSccs.insert(funcComponents.at(callee));
for (CallGraphNode caller : cc) {
for (CallGraphNode callee : callGraph.at(caller)) {
calleeSccs.insert(nodeComponents.at(callee));
}
}

Expand All @@ -219,11 +296,13 @@ void propagateEffects(const Module& module,
}

// Add trap effects for potential cycles.
if (ccFuncs.size() > 1) {
if (cc.size() > 1) {
if (ccEffects != UnknownEffects) {
ccEffects->trap = true;
}
} else {
} else if (ccFuncs.size() == 1) {
// It's possible for a CC to only contain 1 type, but that is not a
// cycle in the call graph.
auto* func = ccFuncs[0];
if (funcInfos.at(func).calledFunctions.contains(func->name)) {
if (ccEffects != UnknownEffects) {
Expand Down Expand Up @@ -267,7 +346,8 @@ struct GenerateGlobalEffects : public Pass {
std::map<Function*, FuncInfo> funcInfos =
analyzeFuncs(*module, getPassOptions());

auto callGraph = buildCallGraph(*module, funcInfos);
auto callGraph =
buildCallGraph(*module, funcInfos, getPassOptions().closedWorld);

propagateEffects(*module, getPassOptions(), funcInfos, callGraph);

Expand Down
72 changes: 72 additions & 0 deletions src/support/graph_traversal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright 2026 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <concepts>
#include <functional>
#include <iterator>
#include <unordered_set>

namespace wasm {

template<typename T, typename SuccessorFunction>
requires std::
invocable<SuccessorFunction, std::function<void(const T&)>&, const T&>
class Graph {
public:
template<std::input_iterator It, std::sentinel_for<It> Sen>
requires std::convertible_to<std::iter_reference_t<It>, T>
Graph(It rootsBegin, Sen rootsEnd, auto&& successors)
: roots(rootsBegin, rootsEnd),
successors(std::forward<decltype(successors)>(successors)) {}

// Traverse the graph depth-first, calling `successors` exactly once for each
// node (unless the node appears multiple times in `roots`). Return the set of
// nodes visited.
std::unordered_set<T> traverseDepthFirst() const {
std::vector<T> stack(roots.begin(), roots.end());
std::unordered_set<T> visited(roots.begin(), roots.end());

auto maybePush = [&](const T& t) {
if (visited.contains(t)) {
return;
}

visited.insert(t);
Comment on lines +43 to +47
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can combine these lookups by using the bool result of insert.

stack.push_back(t);
};

while (!stack.empty()) {
auto curr = std::move(stack.back());
stack.pop_back();

successors(maybePush, curr);
}

return visited;
}

private:
std::vector<T> roots;
SuccessorFunction successors;
};

template<std::input_iterator It,
std::sentinel_for<It> Sen,
typename SuccessorFunction>
Graph(It, Sen, SuccessorFunction)
-> Graph<std::iter_value_t<It>, std::decay_t<SuccessorFunction>>;

} // namespace wasm
Loading
Loading