Set TF_CPP_MIN_LOG_LEVEL=0 in __init__.py#3131
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
|
🤖 Hi @igorts-git, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
📋 Review Summary
This Pull Request effectively addresses an issue where TF_CPP_MIN_LOG_LEVEL was not being set early enough to influence JAX's C++ logging, by moving the environment variable setting to __init__.py. The change is well-explained and improves logging visibility.
🔍 General Feedback
- The core change to set
TF_CPP_MIN_LOG_LEVEL=0earlier is a good improvement for debugging and visibility. - Consider adding an automated unit test to ensure the logging behavior is consistent across environments and to prevent future regressions.
src/MaxText/__init__.py
Outdated
| import os as _os | ||
| # In order to have any effect on the C++ logging this has to be set before we import anything from jax. | ||
| # When jax is imported, its `__init__.py` calls `cloud_tpu_init()`, which also initializes the C++ logger. | ||
| _os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "0") |
There was a problem hiding this comment.
🟡 It would be beneficial to add a unit test to programmatically verify that setting TF_CPP_MIN_LOG_LEVEL to '0' here correctly enables the desired C++ logging output, especially the trace path for profiler=xplane. This would help prevent regressions and ensure the intended behavior persists.
There was a problem hiding this comment.
I could not figure out how to add a reliable unit test for this. The tests are located outside of the "src/MaxText" directory and invoke train.main() as a function. This means that the order of imports in the test itself would affect the C++ logging behavior. Specifically, if the test imports JAX before "MaxText.train" the logging won't work.
Additionally, it is not easy to figure out how to capture C++ LOG(INFO) outputs in a unit tests. The standard Python unit testing libs only capture logs coming from Python.
f977063 to
67203de
Compare
…Setting this env variable after `import jax` is called has no effect, because the C++ logging system is already initialized.
67203de to
cdec380
Compare
gobbleturk
left a comment
There was a problem hiding this comment.
LGTM, why del os though?
I am not a Python expert, but I understand that this is the common convention to avoid pollution of the namespace. Without this |
Description
Set
TF_CPP_MIN_LOG_LEVEL=0in__init__.pyinstead oftrain.py.Setting this env variable after
import jaxis called has no effect, because the C++ logging system is already initialized.The
jax/__init__.pysets the default value ofTF_CPP_MIN_LOG_LEVELto "1" (code pointer). Then it calls_cloud_tpu_init(), which also initializes the C++ logging system.There are many useful C++ LOG(INFO) messages that are suppressed without
TF_CPP_MIN_LOG_LEVEL=0. For example, the path of the created profile trace (when running withprofiler=xplaneflag).FIXES: b/474109045
Tests
Local run of train.py with
profile=xplaneto confirm that the trace path is printed.Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.