Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions alphatrion/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,26 +154,15 @@ async def log_metrics(metrics: dict[str, float]):
should_early_stop = False
should_stop_on_target = False
for key, value in metrics.items():
if not isinstance(value, (int, float)):
# TODO: replace with logging library.
print(
f"Warning: Metric '{key}' has non-numeric value '{value}' and will be skipped for best metric tracking."
)
continue

float_value = float(value)

# Always call the should_checkpoint_on_best first because
# it also updates the best metric.
should_checkpoint |= exp.should_checkpoint_on_best(
metric_key=key, metric_value=float_value
metric_key=key, metric_value=value
)

should_early_stop |= exp.should_early_stop(
metric_key=key, metric_value=float_value
)
should_early_stop |= exp.should_early_stop(metric_key=key, metric_value=value)
should_stop_on_target |= exp.should_stop_on_target_metric(
metric_key=key, metric_value=float_value
metric_key=key, metric_value=value
)

if should_checkpoint:
Expand Down
4 changes: 1 addition & 3 deletions tests/integration/test_oci_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,7 @@ async def test_load_checkpoint_nonexistent(artifact):
# For OCI, trying to pull a non-existent tag should raise an error
# (unlike S3 which returns [] when no files exist)
with pytest.raises(RuntimeError, match="Failed to pull artifacts"):
await alpha.load_checkpoint(
id=exp_id, version="latest", output_dir=tmpdir
)
await alpha.load_checkpoint(id=exp_id, version="latest", output_dir=tmpdir)


@pytest.mark.asyncio
Expand Down
8 changes: 6 additions & 2 deletions tests/unit/artifact/test_s3_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,9 @@ def test_s3_backend_pull_version_folder(s3_client):
# Pull the version folder
output_dir = os.path.join(tmpdir, "download")
result = artifact.pull(
repo_name="org123/team456/exp1/ckpt", version_or_filename="v1", output_dir=output_dir
repo_name="org123/team456/exp1/ckpt",
version_or_filename="v1",
output_dir=output_dir,
)

# Verify all files were downloaded
Expand Down Expand Up @@ -423,7 +425,9 @@ def test_s3_backend_pull_empty_version_folder(s3_client):
with tempfile.TemporaryDirectory() as tmpdir:
# Pull non-existent version folder
result = artifact.pull(
repo_name="org123/team456/exp1/ckpt", version_or_filename="v999", output_dir=tmpdir
repo_name="org123/team456/exp1/ckpt",
version_or_filename="v999",
output_dir=tmpdir,
)

# Should return empty list
Expand Down
1 change: 1 addition & 0 deletions tests/unit/experiment/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_config(self):
),
)


@pytest.mark.asyncio
async def test_experiment_with_done():
init(
Expand Down
Loading