diff --git a/alphatrion/artifact/artifact.py b/alphatrion/artifact/artifact.py index ea6ed46..bdaa463 100644 --- a/alphatrion/artifact/artifact.py +++ b/alphatrion/artifact/artifact.py @@ -6,27 +6,36 @@ SUCCESS_CODE = 201 +ARTIFACT_TYPE_S3 = "s3" +ARTIFACT_TYPE_OCI = "oci" + class Artifact: """Artifact storage client with pluggable backends (OCI or S3).""" def __init__(self, insecure: bool = False): - storage_type = os.environ.get(envs.ARTIFACT_STORAGE_TYPE, "oci").lower() + self._storage_type = os.environ.get( + envs.ARTIFACT_STORAGE_TYPE, ARTIFACT_TYPE_OCI + ).lower() - if storage_type == "s3": + if self._storage_type == ARTIFACT_TYPE_S3: from alphatrion.artifact.s3_backend import S3Backend self._backend = S3Backend() - elif storage_type == "oci": + elif self._storage_type == ARTIFACT_TYPE_OCI: from alphatrion.artifact.oci_backend import OCIBackend self._backend = OCIBackend(insecure=insecure) else: raise ValueError( - f"Unsupported artifact storage type: {storage_type}. " - f"Supported types: 'oci', 's3'" + f"Unsupported artifact storage type: {self._storage_type}. " + f"Supported types: '{ARTIFACT_TYPE_OCI}', '{ARTIFACT_TYPE_S3}'" ) + @property + def storage_type(self): + return self._storage_type + def push( self, repo_name: str, diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index 6462e84..4b64eb6 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -1,6 +1,7 @@ import asyncio import uuid +from alphatrion.artifact.artifact import ARTIFACT_TYPE_S3 from alphatrion.runtime.runtime import global_runtime @@ -55,11 +56,13 @@ async def load_checkpoint( repo_name = f"{runtime.org_id}/{runtime.team_id}/{id}/ckpt" - versions = artifact.list_versions(repo_name) - if versions is None or len(versions) == 0: - return [] + # We only need to do this for s3 backend, because for oci backend, + # the version is the tag and "latest" tag will always point to the latest version. + if version == "latest" and artifact.storage_type == ARTIFACT_TYPE_S3: + versions = artifact.list_versions(repo_name) + if versions is None or len(versions) == 0: + return [] - if version == "latest": version = versions[0] # Assuming versions are sorted by time, newest first result = await asyncio.get_running_loop().run_in_executor( diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py index 9b137e2..9d215a6 100644 --- a/tests/integration/test_oci_backend.py +++ b/tests/integration/test_oci_backend.py @@ -316,9 +316,7 @@ async def test_load_checkpoint_latest(artifact): alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) with tempfile.TemporaryDirectory() as tmpdir: - # Push multiple checkpoint versions with different content and timestamps - import time - + # Push multiple checkpoint versions with different content for i in range(3): test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") with open(test_file, "w") as f: @@ -329,10 +327,19 @@ async def test_load_checkpoint_latest(artifact): paths=test_file, version=f"v{i}", ) - if i < 2: - time.sleep(0.1) # Small delay for timestamp ordering - # Load latest checkpoint (should be v2 for S3, but arbitrary for OCI) + # For OCI, also push a version tagged as "latest" + latest_file = os.path.join(tmpdir, "checkpoint_latest.pt") + with open(latest_file, "w") as f: + f.write("model weights version latest") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=latest_file, + version="latest", + ) + + # Load latest checkpoint output_dir = os.path.join(tmpdir, "download") result = await alpha.load_checkpoint( id=exp_id, version="latest", output_dir=output_dir @@ -343,10 +350,10 @@ async def test_load_checkpoint_latest(artifact): assert len(result) == 1 assert os.path.exists(result[0]) - # Verify it's one of the versions + # For OCI, "latest" tag should return the checkpoint tagged as "latest" with open(result[0]) as f: content = f.read() - assert content.startswith("model weights version") + assert content == "model weights version latest" @pytest.mark.asyncio @@ -416,7 +423,7 @@ async def test_load_checkpoint_specific_version(artifact): @pytest.mark.asyncio async def test_load_checkpoint_nonexistent(artifact): - """Test load_checkpoint returns None for nonexistent experiment.""" + """Test load_checkpoint with nonexistent checkpoint tag raises error for OCI.""" org_id = uuid.uuid4() team_id = uuid.uuid4() user_id = uuid.uuid4() @@ -425,12 +432,12 @@ async def test_load_checkpoint_nonexistent(artifact): alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) with tempfile.TemporaryDirectory() as tmpdir: - # Try to load checkpoint from non-existent experiment - result = await alpha.load_checkpoint( - id=exp_id, version="latest", output_dir=tmpdir - ) - - assert result == [] + # For OCI, trying to pull a non-existent tag should raise an error + # (unlike S3 which returns [] when no files exist) + with pytest.raises(RuntimeError, match="Failed to pull artifacts"): + await alpha.load_checkpoint( + id=exp_id, version="latest", output_dir=tmpdir + ) @pytest.mark.asyncio