Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/gigl/env/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
str
] = "COMPUTE_CLUSTER_LOCAL_WORLD_SIZE"

# Environment variable to indicate the component of the job.
# Values: "train", "inference"
GIGL_COMPONENT_ENV_KEY: Final[str] = "GIGL_COMPONENT"


@dataclass(frozen=True)
class DistributedContext:
Expand Down
19 changes: 19 additions & 0 deletions python/gigl/src/common/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os

from gigl.env.distributed import GIGL_COMPONENT_ENV_KEY
from gigl.src.common.constants.components import GiGLComponents


def get_component() -> GiGLComponents:
"""Get the component of the current job.

Returns:
GiGLComponents: The component of the current job.
Raises:
ValueError: If the component is not valid.
"""
if GIGL_COMPONENT_ENV_KEY not in os.environ:
raise KeyError(
f"Environment variable {GIGL_COMPONENT_ENV_KEY} is not set. Cannot determine the component of the current job. Please set the environment variable like `export GIGL_COMPONENT=trainer`."
)
return GiGLComponents(os.environ[GIGL_COMPONENT_ENV_KEY])
14 changes: 9 additions & 5 deletions python/gigl/src/common/vertex_ai_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
)
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.env.distributed import COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY
from gigl.env.distributed import (
COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
GIGL_COMPONENT_ENV_KEY,
)
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
GiglResourceConfigWrapper,
Expand Down Expand Up @@ -67,7 +70,6 @@ def launch_single_pool_job(
cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
container_uri = cpu_docker_uri if is_cpu_execution else cuda_docker_uri

job_config = _build_job_config(
job_name=job_name,
task_config_uri=task_config_uri,
Expand All @@ -77,7 +79,10 @@ def launch_single_pool_job(
use_cuda=is_cpu_execution,
container_uri=container_uri,
vertex_ai_resource_config=vertex_ai_resource_config,
env_vars=[env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3")],
env_vars=[
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
env_var.EnvVar(name=GIGL_COMPONENT_ENV_KEY, value=component.value),
],
labels=resource_config_wrapper.get_resource_labels(component=component),
)
logger.info(f"Launching {component.value} job with config: {job_config}")
Expand Down Expand Up @@ -150,10 +155,10 @@ def launch_graph_store_enabled_job(
name=COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
value=str(num_compute_processes),
),
env_var.EnvVar(name=GIGL_COMPONENT_ENV_KEY, value=component.value),
]

labels = resource_config_wrapper.get_resource_labels(component=component)

# Create compute pool job config
compute_job_config = _build_job_config(
job_name=job_name,
Expand Down Expand Up @@ -221,7 +226,6 @@ def _build_job_config(

Args:
job_name (str): The base name for the job. Will be prefixed with "gigl_train_" or "gigl_infer_".
is_inference (bool): Whether this is an inference job (True) or training job (False).
task_config_uri (Uri): URI to the task configuration file.
resource_config_uri (Uri): URI to the resource configuration file.
command_str (str): The command to run in the container (will be split on spaces).
Expand Down
33 changes: 33 additions & 0 deletions python/tests/unit/src/common/env_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os
import unittest
from unittest import mock

from gigl.env.distributed import GIGL_COMPONENT_ENV_KEY
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.env import get_component


class TestGetComponent(unittest.TestCase):
"""Test suite for get_component function."""

@mock.patch.dict(os.environ, {GIGL_COMPONENT_ENV_KEY: GiGLComponents.Trainer.value})
def test_get_component_valid_value(self):
"""Test get_component returns correct component when env var is valid."""
result = get_component()
self.assertEqual(result, GiGLComponents.Trainer)

@mock.patch.dict(os.environ, {GIGL_COMPONENT_ENV_KEY: "invalid_component"})
def test_get_component_invalid_value(self):
"""Test get_component raises ValueError when env var is invalid."""
with self.assertRaises(ValueError):
get_component()

@mock.patch.dict(os.environ, {}, clear=True)
def test_get_component_not_set(self):
"""Test get_component raises KeyError when env var is not set."""
with self.assertRaises(KeyError):
get_component()


if __name__ == "__main__":
unittest.main()