diff --git a/alphatrion/run/hooks.py b/alphatrion/run/hooks.py index befd8d9..dc6dee7 100644 --- a/alphatrion/run/hooks.py +++ b/alphatrion/run/hooks.py @@ -57,15 +57,16 @@ async def train_model(): @staticmethod def sync_status(run_id: uuid.UUID, result: Any) -> None: """ - Sync function result to run status. + Sync function result to run status and optional status_msg to metadata. Looks for 'status' key in result dict. Status can be a string representation, - or integer value. + or integer value. Optionally looks for 'status_msg' key to store in metadata. Example: async def train_model(): return { - "status": "COMPLETED" # or 9 + "status": "COMPLETED", # or 9 + "status_msg": "Training completed successfully" } run = exp.run(train_model, post_run_hooks=[ @@ -79,28 +80,42 @@ async def train_model(): return status = None - - # Extract status from dict - if isinstance(result, dict) and "status" in result: - status_value = result["status"] - - if isinstance(status_value, str): - try: - status = Status[status_value.upper()] - except (KeyError, AttributeError): - logger.warning( - f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync." - ) - return - elif isinstance(status_value, int): - try: - status = Status(status_value) - except ValueError: - logger.warning( - f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync." - ) - return - - if status is not None: + status_msg = None + + # Extract status and status_msg from dict + if isinstance(result, dict): + if "status" in result: + status_value = result["status"] + + if isinstance(status_value, str): + try: + status = Status[status_value.upper()] + except (KeyError, AttributeError): + logger.warning( + f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync." + ) + return + elif isinstance(status_value, int): + try: + status = Status(status_value) + except ValueError: + logger.warning( + f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync." + ) + return + + if "status_msg" in result: + status_msg = result["status_msg"] + + # Update both status and status_msg in a single call if needed + if status is not None or status_msg is not None: metadb = global_runtime().metadb - metadb.update_run(run_id=run_id, status=status) + update_kwargs = {} + + if status is not None: + update_kwargs["status"] = status + + if status_msg is not None: + update_kwargs["meta"] = {"status_msg": status_msg} + + metadb.update_run(run_id=run_id, **update_kwargs) diff --git a/tests/integration/test_run_hooks.py b/tests/integration/test_run_hooks.py index b0d5b9a..ce6b214 100644 --- a/tests/integration/test_run_hooks.py +++ b/tests/integration/test_run_hooks.py @@ -306,3 +306,38 @@ async def task_with_none_result(): # Metadata should be None or empty (hook didn't update it) assert run_obj.meta is None or run_obj.meta == {} assert run_obj.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_sync_status_with_status_msg(test_org_id, test_user_id, test_team_id): + """Test sync_status hook with status_msg""" + alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id) + + async def train_model(): + await asyncio.sleep(0.1) + return { + "status": "COMPLETED", + "status_msg": "Training completed successfully with high accuracy", + } + + async with CraftExperiment.start("test_status_msg") as exp: + run = exp.run( + train_model, + post_run_hooks=[PostRunHookFn.sync_metadata, PostRunHookFn.sync_status], + ) + await exp.wait() + + # Verify both hooks ran + metadb = global_runtime().metadb + run_obj = metadb.get_run(run_id=run.id) + + # From sync_metadata hook + assert run_obj.meta["accuracy"] == 0.95 + assert run_obj.meta["loss"] == 0.05 + + # From sync_status hook - status and status_msg + assert run_obj.status == Status.COMPLETED + assert ( + run_obj.meta["status_msg"] + == "Training completed successfully with high accuracy" + ) diff --git a/tests/unit/run/test_hooks.py b/tests/unit/run/test_hooks.py index 3bee3d3..dfdbac9 100644 --- a/tests/unit/run/test_hooks.py +++ b/tests/unit/run/test_hooks.py @@ -455,3 +455,45 @@ def test_both_hooks_together(db): assert run.meta["loss"] == 0.05 assert run.meta["num_epochs"] == 10 assert run.status == Status.COMPLETED + + +def test_sync_status_with_status_msg(db): + """Test sync_status hook with status_msg""" + org_id = uuid.uuid4() + team_id = db.create_team(org_id=org_id, name="Test Team") + user_id = db.create_user( + org_id=org_id, + name="tester", + email="tester@example.com", + ) + exp_id = db.create_experiment( + org_id=org_id, + team_id=team_id, + user_id=user_id, + name="test-exp", + ) + run_id = db.create_run( + org_id=org_id, + team_id=team_id, + user_id=user_id, + experiment_id=exp_id, + ) + + # Mock result with status and status_msg + result = { + "status": "COMPLETED", + "status_msg": "Training completed successfully", + } + + # Mock global_runtime + mock_runtime = Mock() + mock_runtime.metadb = db + + with patch("alphatrion.run.hooks.global_runtime", return_value=mock_runtime): + # Call the hook + PostRunHookFn.sync_status(run_id, result) + + # Verify status and status_msg were updated + run = db.get_run(run_id) + assert run.status == Status.COMPLETED + assert run.meta["status_msg"] == "Training completed successfully"