Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions alphatrion/run/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,16 @@ async def train_model():
@staticmethod
def sync_status(run_id: uuid.UUID, result: Any) -> None:
"""
Sync function result to run status.
Sync function result to run status and optional status_msg to metadata.

Looks for 'status' key in result dict. Status can be a string representation,
or integer value.
or integer value. Optionally looks for 'status_msg' key to store in metadata.

Example:
async def train_model():
return {
"status": "COMPLETED" # or 9
"status": "COMPLETED", # or 9
"status_msg": "Training completed successfully"
}

run = exp.run(train_model, post_run_hooks=[
Expand All @@ -79,28 +80,42 @@ async def train_model():
return

status = None

# Extract status from dict
if isinstance(result, dict) and "status" in result:
status_value = result["status"]

if isinstance(status_value, str):
try:
status = Status[status_value.upper()]
except (KeyError, AttributeError):
logger.warning(
f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync."
)
return
elif isinstance(status_value, int):
try:
status = Status(status_value)
except ValueError:
logger.warning(
f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync."
)
return

if status is not None:
status_msg = None

# Extract status and status_msg from dict
if isinstance(result, dict):
if "status" in result:
status_value = result["status"]

if isinstance(status_value, str):
try:
status = Status[status_value.upper()]
except (KeyError, AttributeError):
logger.warning(
f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync."
)
return
elif isinstance(status_value, int):
try:
status = Status(status_value)
except ValueError:
logger.warning(
f"PostRunHookFn.sync_status: Invalid status value '{status_value}' for run {run_id}. Skipping status sync."
)
return

if "status_msg" in result:
status_msg = result["status_msg"]

# Update both status and status_msg in a single call if needed
if status is not None or status_msg is not None:
metadb = global_runtime().metadb
metadb.update_run(run_id=run_id, status=status)
update_kwargs = {}

if status is not None:
update_kwargs["status"] = status

if status_msg is not None:
update_kwargs["meta"] = {"status_msg": status_msg}

metadb.update_run(run_id=run_id, **update_kwargs)
35 changes: 35 additions & 0 deletions tests/integration/test_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,38 @@ async def task_with_none_result():
# Metadata should be None or empty (hook didn't update it)
assert run_obj.meta is None or run_obj.meta == {}
assert run_obj.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_sync_status_with_status_msg(test_org_id, test_user_id, test_team_id):
"""Test sync_status hook with status_msg"""
alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id)

async def train_model():
await asyncio.sleep(0.1)
return {
"status": "COMPLETED",
"status_msg": "Training completed successfully with high accuracy",
}

async with CraftExperiment.start("test_status_msg") as exp:
run = exp.run(
train_model,
post_run_hooks=[PostRunHookFn.sync_metadata, PostRunHookFn.sync_status],
)
await exp.wait()

# Verify both hooks ran
metadb = global_runtime().metadb
run_obj = metadb.get_run(run_id=run.id)

# From sync_metadata hook
assert run_obj.meta["accuracy"] == 0.95
assert run_obj.meta["loss"] == 0.05

# From sync_status hook - status and status_msg
assert run_obj.status == Status.COMPLETED
assert (
run_obj.meta["status_msg"]
== "Training completed successfully with high accuracy"
)
42 changes: 42 additions & 0 deletions tests/unit/run/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,45 @@ def test_both_hooks_together(db):
assert run.meta["loss"] == 0.05
assert run.meta["num_epochs"] == 10
assert run.status == Status.COMPLETED


def test_sync_status_with_status_msg(db):
"""Test sync_status hook with status_msg"""
org_id = uuid.uuid4()
team_id = db.create_team(org_id=org_id, name="Test Team")
user_id = db.create_user(
org_id=org_id,
name="tester",
email="tester@example.com",
)
exp_id = db.create_experiment(
org_id=org_id,
team_id=team_id,
user_id=user_id,
name="test-exp",
)
run_id = db.create_run(
org_id=org_id,
team_id=team_id,
user_id=user_id,
experiment_id=exp_id,
)

# Mock result with status and status_msg
result = {
"status": "COMPLETED",
"status_msg": "Training completed successfully",
}

# Mock global_runtime
mock_runtime = Mock()
mock_runtime.metadb = db

with patch("alphatrion.run.hooks.global_runtime", return_value=mock_runtime):
# Call the hook
PostRunHookFn.sync_status(run_id, result)

# Verify status and status_msg were updated
run = db.get_run(run_id)
assert run.status == Status.COMPLETED
assert run.meta["status_msg"] == "Training completed successfully"
Loading