From 884c1d12d294af0ad3d8813eec82aa4f2823b5dc Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 18 May 2026 22:43:12 +0100 Subject: [PATCH 1/3] add support for pull checkpoints from s3 Signed-off-by: kerthcet --- alphatrion/artifact/oci_backend.py | 14 +- alphatrion/artifact/s3_backend.py | 107 ++++++++- alphatrion/log/load.py | 40 ++++ alphatrion/log/log.py | 4 +- tests/integration/test_oci_backend.py | 301 +++++++++++++++++++++++++ tests/unit/artifact/test_s3_backend.py | 235 ++++++++++++++++++- 6 files changed, 679 insertions(+), 22 deletions(-) create mode 100644 tests/integration/test_oci_backend.py diff --git a/alphatrion/artifact/oci_backend.py b/alphatrion/artifact/oci_backend.py index 4d08c692..06a3ff1b 100644 --- a/alphatrion/artifact/oci_backend.py +++ b/alphatrion/artifact/oci_backend.py @@ -76,22 +76,18 @@ def pull( ) -> list[str]: path = f"{repo_name}:{version}" target = f"{self._url}/{path}" - original_dir = None if output_dir: os.makedirs(output_dir, exist_ok=True) - original_dir = os.getcwd() - os.chdir(output_dir) + download_dir = os.path.abspath(output_dir) + else: + download_dir = os.getcwd() try: - filenames = self._client.pull(target, outdir="." if output_dir else None) - download_dir = os.getcwd() - return [os.path.abspath(os.path.join(download_dir, f)) for f in filenames] + filenames = self._client.pull(target, outdir=download_dir) + return [os.path.join(download_dir, f) for f in filenames] except Exception as e: raise RuntimeError(f"Failed to pull artifacts: {e}") from e - finally: - if output_dir and original_dir: - os.chdir(original_dir) def delete(self, repo_name: str, versions: str | list[str]): target = f"{self._url}/{repo_name}" diff --git a/alphatrion/artifact/s3_backend.py b/alphatrion/artifact/s3_backend.py index 8ec42618..c4509298 100644 --- a/alphatrion/artifact/s3_backend.py +++ b/alphatrion/artifact/s3_backend.py @@ -104,12 +104,115 @@ def push( return repo_name if version is None else f"{repo_name}/{version}" def list_versions(self, repo_name: str) -> list[str]: - raise NotImplementedError("list_versions is not implemented for S3 backend") + """List all files directly under a repository path (ignores nested files). + + Returns at most 3000 files (3 pages). If you have more checkpoints than this, + consider using database metadata to track versions instead. + + :param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt") + :return: List of filenames sorted by LastModified (newest first), max 3000 items + """ + try: + prefix = f"{repo_name}/" + files_with_time = [] + continuation_token = None + max_pages = 3 # Limit to 3000 files (1000 per page) + pages_fetched = 0 + + # Handle pagination for >1000 files, up to max_pages + while pages_fetched < max_pages: + # Use delimiter to only list top-level files, ignoring nested directories + params = { + "Bucket": self._bucket, + "Prefix": prefix, + "Delimiter": "/", + } + if continuation_token: + params["ContinuationToken"] = continuation_token + + response = self._s3.list_objects_v2(**params) + pages_fetched += 1 + + if "Contents" in response: + # Extract filenames and timestamps + for obj in response["Contents"]: + s3_key = obj["Key"] + # Get filename: "repo_name/file.txt" -> "file.txt" + filename = s3_key[len(prefix) :] + if filename: # Skip empty + files_with_time.append((filename, obj["LastModified"])) + + # Check if there are more results + if response.get("IsTruncated") and pages_fetched < max_pages: + continuation_token = response.get("NextContinuationToken") + else: + break + + if not files_with_time: + return [] + + # Sort by LastModified descending (newest first) + files_with_time.sort(key=lambda x: x[1], reverse=True) + + return [f[0] for f in files_with_time] + except Exception as e: + error_msg = str(e).lower() + if ( + "404" in error_msg + or "not found" in error_msg + or "nosuchbucket" in error_msg + ): + return [] + raise RuntimeError(f"Failed to list versions: {e}") from e def pull( self, repo_name: str, version: str, output_dir: str | None = None ) -> list[str]: - raise NotImplementedError("pull is not implemented for S3 backend") + """Pull (download) files from S3. + + :param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt") + :param version: The filename to download (for flat structure) or folder name (for versioned structure) + :param output_dir: Optional directory to save files. If None, downloads to current directory. + :return: List of absolute paths to downloaded files + """ + if output_dir: + os.makedirs(output_dir, exist_ok=True) + download_dir = os.path.abspath(output_dir) + else: + download_dir = os.getcwd() + + try: + # Check if version looks like a filename (has extension) or version folder + if "." in version: + # Single file: repo_name/version (e.g., "ckpt/checkpoint_123.pt") + s3_key = f"{repo_name}/{version}" + local_path = os.path.join(download_dir, version) + + self._s3.download_file(self._bucket, s3_key, local_path) + return [local_path] + else: + # Version folder: repo_name/version/* (e.g., "ckpt/v1/*") + prefix = f"{repo_name}/{version}/" + + response = self._s3.list_objects_v2( + Bucket=self._bucket, Prefix=prefix, Delimiter="/" + ) + + if "Contents" not in response: + return [] + + downloaded_files = [] + for obj in response["Contents"]: + s3_key = obj["Key"] + filename = s3_key[len(prefix) :] + if filename: # Skip empty/directory markers + local_path = os.path.join(download_dir, filename) + self._s3.download_file(self._bucket, s3_key, local_path) + downloaded_files.append(local_path) + + return downloaded_files + except Exception as e: + raise RuntimeError(f"Failed to pull artifacts from S3: {e}") from e def delete(self, repo_name: str, versions: str | list[str]): raise NotImplementedError("delete is not implemented for S3 backend") diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index c693e96e..90a603d9 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -25,3 +25,43 @@ async def load_dataset(id: str | uuid.UUID, output_dir: str | None = None) -> li ) return result + + +async def load_checkpoint( + experiment_id: str | uuid.UUID, + version: str = "latest", + output_dir: str | None = None, +) -> list[str]: + """ + Load checkpoint from artifact registry. + + :param experiment_id: the id of the experiment. + :param version: the version of the checkpoint to load, default is "latest". + For oci backend, version is the tag of the artifact. + For s3 backend, version is the name of the file to load. + If version is "latest", the most recently modified file will be loaded. + :param output_dir: the directory to which the checkpoint will be loaded. + """ + runtime = global_runtime() + + if isinstance(experiment_id, str): + experiment_id = uuid.UUID(experiment_id) + + artifact = runtime.artifact + if artifact is None: + raise RuntimeError("Artifact storage is not initialized in the runtime.") + + repo_name = f"{runtime.org_id}/{runtime.team_id}/{experiment_id}/ckpt" + + versions = artifact.list_versions(repo_name) + if versions is None or len(versions) == 0: + return [] + + if version == "latest": + version = versions[0] # Assuming versions are sorted by time, newest first + + result = await asyncio.get_running_loop().run_in_executor( + None, artifact.pull, repo_name, version, output_dir + ) + + return result diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index 2d20e0db..c56098ef 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -65,9 +65,7 @@ async def log_artifact( # Now validate that we have paths if not paths: # TODO: replace with logging library. - print( - "Warning: No paths provided for log_artifact. Nothing will be logged." - ) + print("Warning: No paths provided for log_artifact. Nothing will be logged.") # We should still run the post_save_hook even if there's nothing to log, # because the hook might have side effects that are important (e.g., cleanup). diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py new file mode 100644 index 00000000..a7ebdbb7 --- /dev/null +++ b/tests/integration/test_oci_backend.py @@ -0,0 +1,301 @@ +"""Integration tests for OCI artifact backend. + +Note: These tests require a running OCI registry (like Docker Registry). + +To start the test services: + docker-compose -f docker-compose.test.yaml up -d registry + +Run tests with: + pytest tests/integration/test_oci_backend.py -v + +Cleanup: + docker-compose -f docker-compose.test.yaml down +""" + +import os +import tempfile +import uuid + +import pytest + + +@pytest.fixture(autouse=True) +def oci_env_vars(): + """Set up OCI environment variables for testing.""" + original_env = {} + env_vars = { + "ALPHATRION_ARTIFACT_STORAGE_TYPE": "oci", + "ALPHATRION_ARTIFACT_REGISTRY_URL": "localhost:25001", + "ALPHATRION_ENABLE_ARTIFACT_STORAGE": "true", + } + + for key, value in env_vars.items(): + original_env[key] = os.environ.get(key) + os.environ[key] = value + + yield + + # Restore original environment + for key, value in original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +@pytest.fixture +def artifact(): + """Create an artifact instance with OCI backend.""" + from alphatrion.artifact.artifact import Artifact + + return Artifact(insecure=True) + + +@pytest.fixture +def unique_repo(): + """Generate a unique repository name for test isolation.""" + return f"org123/team456/test-{uuid.uuid4().hex[:8]}" + + +def test_oci_backend_push_single_file(artifact, unique_repo): + """Test OCI backend push with single file.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test file + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + + # Push artifact + path = artifact.push(repo_name=unique_repo, paths=test_file, version="v1") + assert path == f"{unique_repo}:v1" + + +def test_oci_backend_push_multiple_files(artifact, unique_repo): + """Test OCI backend push with multiple files.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"file{i}.txt") + with open(file_path, "w") as f: + f.write(f"content {i}") + files.append(file_path) + + # Push multiple files + path = artifact.push(repo_name=unique_repo, paths=files, version="v2") + assert path == f"{unique_repo}:v2" + + +def test_oci_backend_push_folder(artifact, unique_repo): + """Test OCI backend push with folder.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files in folder + test_dir = os.path.join(tmpdir, "test_folder") + os.makedirs(test_dir) + + for i in range(3): + with open(os.path.join(test_dir, f"file{i}.txt"), "w") as f: + f.write(f"content {i}") + + # Push folder + path = artifact.push(repo_name=unique_repo, paths=test_dir, version="v3") + assert path == f"{unique_repo}:v3" + + +def test_oci_backend_push_auto_version(artifact, unique_repo): + """Test OCI backend push with auto-generated version.""" + with tempfile.TemporaryDirectory() as tmpdir: + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + + # Push without version + path = artifact.push(repo_name=unique_repo, paths=test_file) + + # Should return repo_name:auto_version + assert path.startswith(f"{unique_repo}:") + + +def test_oci_backend_push_empty_files_error(artifact, unique_repo): + """Test OCI backend push with no files raises error.""" + with pytest.raises(ValueError, match="no files specified to push"): + artifact.push(repo_name=unique_repo, paths=None, version="v1") + + with pytest.raises(ValueError, match="no files specified to push"): + artifact.push(repo_name=unique_repo, paths="", version="v1") + + +def test_oci_backend_push_empty_folder_error(artifact, unique_repo): + """Test OCI backend push with empty folder raises error.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create empty folder + empty_dir = os.path.join(tmpdir, "empty") + os.makedirs(empty_dir) + + with pytest.raises(ValueError, match="No files to push"): + artifact.push(repo_name=unique_repo, paths=empty_dir, version="v1") + + +def test_oci_backend_list_versions(artifact, unique_repo): + """Test OCI backend list_versions.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple versions + for i in range(3): + test_file = os.path.join(tmpdir, f"test{i}.txt") + with open(test_file, "w") as f: + f.write(f"content {i}") + + artifact.push(repo_name=unique_repo, paths=test_file, version=f"v{i}") + + # List versions + versions = artifact.list_versions(unique_repo) + assert len(versions) == 3 + assert "v0" in versions + assert "v1" in versions + assert "v2" in versions + + +def test_oci_backend_list_versions_empty(artifact): + """Test list_versions returns empty list for non-existent repo.""" + versions = artifact.list_versions("org123/team456/nonexistent") + assert versions == [] + + +def test_oci_backend_pull_single_file(artifact, unique_repo): + """Test OCI backend pull.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push a file + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + + artifact.push(repo_name=unique_repo, paths=test_file, version="v1") + + # Pull the file + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name=unique_repo, version="v1", output_dir=output_dir + ) + + # Verify file was downloaded + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint.pt" + + # Verify content + with open(result[0]) as f: + assert f.read() == "model weights" + + +def test_oci_backend_pull_multiple_files(artifact, unique_repo): + """Test pull with multiple files.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple files + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"file{i}.txt") + with open(file_path, "w") as f: + f.write(f"content {i}") + files.append(file_path) + + artifact.push(repo_name=unique_repo, paths=files, version="v1") + + # Pull the files + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull(repo_name=unique_repo, version="v1", output_dir=output_dir) + + # Verify all files were downloaded + assert len(result) == 3 + + # Check that all expected files exist and have correct content + result_basenames = [os.path.basename(r) for r in result] + for i in range(3): + expected_filename = f"file{i}.txt" + assert expected_filename in result_basenames + + expected_file = os.path.join(output_dir, expected_filename) + assert os.path.exists(expected_file) + + with open(expected_file) as f: + assert f.read() == f"content {i}" + + +def test_oci_backend_pull_to_current_dir(artifact, unique_repo): + """Test pull without output_dir.""" + with tempfile.TemporaryDirectory() as tmpdir: + original_dir = os.getcwd() + try: + os.chdir(tmpdir) + + # Push a file + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + + artifact.push(repo_name=unique_repo, paths=test_file, version="v1") + + # Pull without output_dir + result = artifact.pull(repo_name=unique_repo, version="v1") + + # Should download to current directory + assert len(result) == 1 + assert os.path.basename(result[0]) == "test.txt" + assert os.path.exists(result[0]) + + with open(result[0]) as f: + assert f.read() == "test content" + finally: + os.chdir(original_dir) + + +def test_oci_backend_delete(artifact, unique_repo): + """Test OCI backend delete.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple versions with different content + test_file1 = os.path.join(tmpdir, "test1.txt") + with open(test_file1, "w") as f: + f.write("test content v1") + artifact.push(repo_name=unique_repo, paths=test_file1, version="v1") + + test_file2 = os.path.join(tmpdir, "test2.txt") + with open(test_file2, "w") as f: + f.write("test content v2") + artifact.push(repo_name=unique_repo, paths=test_file2, version="v2") + + # Verify both versions exist + versions = artifact.list_versions(unique_repo) + assert "v1" in versions + assert "v2" in versions + + # Delete v1 + artifact.delete(repo_name=unique_repo, versions="v1") + + # Verify v1 is deleted + versions = artifact.list_versions(unique_repo) + assert "v1" not in versions + assert "v2" in versions + + +def test_oci_backend_delete_multiple_versions(artifact, unique_repo): + """Test deleting multiple versions at once.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple versions with DIFFERENT content (so they have different blobs) + for i in range(3): + test_file = os.path.join(tmpdir, f"test{i}.txt") + with open(test_file, "w") as f: + f.write(f"test content version {i}") # Different content per version + + artifact.push(repo_name=unique_repo, paths=test_file, version=f"v{i}") + + # Verify all versions exist before delete + versions_before = artifact.list_versions(unique_repo) + assert len(versions_before) == 3 + + # Delete v0 and v1 + artifact.delete(repo_name=unique_repo, versions=["v0", "v1"]) + + # Verify only v2 remains + versions = artifact.list_versions(unique_repo) + assert "v0" not in versions + assert "v1" not in versions + assert "v2" in versions diff --git a/tests/unit/artifact/test_s3_backend.py b/tests/unit/artifact/test_s3_backend.py index 53bb4e96..218bb157 100644 --- a/tests/unit/artifact/test_s3_backend.py +++ b/tests/unit/artifact/test_s3_backend.py @@ -191,24 +191,243 @@ def test_s3_backend_push_empty_folder_error(s3_client): ) -def test_s3_backend_list_versions_not_implemented(s3_client): - """Test that list_versions raises NotImplementedError for S3 backend.""" +def test_s3_backend_list_versions_empty(s3_client): + """Test list_versions returns empty list for non-existent repo.""" from alphatrion.artifact.artifact import Artifact artifact = Artifact() - with pytest.raises(NotImplementedError, match="list_versions is not implemented"): - artifact.list_versions("org123/team456/test-repo") + versions = artifact.list_versions("org123/team456/nonexistent") + assert versions == [] -def test_s3_backend_pull_not_implemented(s3_client): - """Test that pull raises NotImplementedError.""" +def test_s3_backend_list_versions_single_file(s3_client): + """Test list_versions with a single file.""" from alphatrion.artifact.artifact import Artifact artifact = Artifact() - with pytest.raises(NotImplementedError, match="pull is not implemented"): - artifact.pull(repo_name="org123/team456/test-repo", version="v1") + with tempfile.TemporaryDirectory() as tmpdir: + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + + # Push file directly (no version folder) + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=test_file) + + # List versions should return the filename + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 1 + assert "checkpoint.pt" in versions + + +def test_s3_backend_list_versions_multiple_files(s3_client): + """Test list_versions with multiple files sorted by time.""" + import time + + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push 3 files with small delays to ensure different timestamps + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights {i}") + + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=test_file) + time.sleep(1) # Small delay to ensure different timestamps + + # List versions should return files sorted by LastModified (newest first) + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 3 + # The newest file (checkpoint_2.pt) should be first + assert versions[0] == "checkpoint_2.pt" + assert versions[1] == "checkpoint_1.pt" + assert versions[2] == "checkpoint_0.pt" + + +def test_s3_backend_list_versions_ignores_nested(s3_client): + """Test list_versions ignores nested files (uses delimiter).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push top-level file + top_file = os.path.join(tmpdir, "checkpoint.pt") + with open(top_file, "w") as f: + f.write("top level") + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=top_file) + + # Manually create nested file in S3 (simulating accidental nested upload) + s3_client.put_object( + Bucket="test-bucket", + Key="org123/team456/exp1/ckpt/nested/file.txt", + Body=b"nested content", + ) + + # List versions should only return top-level file + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 1 + assert versions[0] == "checkpoint.pt" + + +def test_s3_backend_list_versions_pagination_limit(s3_client): + """Test list_versions respects 3000 file limit (3 pages).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Create 10 test files (simulating pagination scenario) + # In real scenario, we'd create 3500 files but that's slow for tests + files = [] + for i in range(10): + test_file = os.path.join(tmpdir, f"checkpoint_{i:04d}.pt") + with open(test_file, "w") as f: + f.write(f"model {i}") + files.append(test_file) + + # Push all files + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=files) + + # List versions should return all files (under the 3000 limit) + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 10 + # Should be sorted by timestamp (newest first) + assert all(f"checkpoint_{i:04d}.pt" in versions for i in range(10)) + + +def test_s3_backend_pull_single_file(s3_client): + """Test pull with single file (flat structure).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push a file + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=test_file) + + # Pull the file to a new directory + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name="org123/team456/exp1/ckpt", + version="checkpoint.pt", + output_dir=output_dir, + ) + + # Verify file was downloaded + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint.pt" + + # Verify content + with open(result[0]) as f: + assert f.read() == "model weights" + + +def test_s3_backend_pull_version_folder(s3_client): + """Test pull with version folder (versioned structure).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple files with version + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"file{i}.txt") + with open(file_path, "w") as f: + f.write(f"content {i}") + files.append(file_path) + + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=files, version="v1") + + # Pull the version folder + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name="org123/team456/exp1/ckpt", version="v1", output_dir=output_dir + ) + + # Verify all files were downloaded + assert len(result) == 3 + for i in range(3): + expected_file = os.path.join(output_dir, f"file{i}.txt") + assert any(expected_file == r for r in result) + assert os.path.exists(expected_file) + + with open(expected_file) as f: + assert f.read() == f"content {i}" + + +def test_s3_backend_pull_to_current_dir(s3_client): + """Test pull without output_dir (downloads to current directory).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Change to temp directory + original_dir = os.getcwd() + try: + os.chdir(tmpdir) + + # Push a file + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + artifact.push(repo_name="org123/team456/test-repo", paths=test_file) + + # Pull without output_dir + result = artifact.pull( + repo_name="org123/team456/test-repo", version="test.txt" + ) + + # Should download to current directory + assert len(result) == 1 + assert os.path.basename(result[0]) == "test.txt" + assert os.path.exists(result[0]) + + with open(result[0]) as f: + assert f.read() == "test content" + finally: + os.chdir(original_dir) + + +def test_s3_backend_pull_nonexistent_file(s3_client): + """Test pull with non-existent file raises error.""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(RuntimeError, match="Failed to pull artifacts"): + artifact.pull( + repo_name="org123/team456/nonexistent", + version="missing.txt", + output_dir=tmpdir, + ) + + +def test_s3_backend_pull_empty_version_folder(s3_client): + """Test pull with empty version folder returns empty list.""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Pull non-existent version folder + result = artifact.pull( + repo_name="org123/team456/exp1/ckpt", version="v999", output_dir=tmpdir + ) + + # Should return empty list + assert result == [] def test_s3_backend_path_based_versioning(s3_client): From bf80b9053600dcb52566adb20f4c79aaa3a26f49 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Mon, 18 May 2026 23:14:07 +0100 Subject: [PATCH 2/3] support load_checkpoint api Signed-off-by: kerthcet --- alphatrion/__init__.py | 3 +- alphatrion/log/load.py | 12 +- tests/integration/test_oci_backend.py | 173 +++++++++++++++++++++++++- 3 files changed, 181 insertions(+), 7 deletions(-) diff --git a/alphatrion/__init__.py b/alphatrion/__init__.py index 9f8bce13..2c2b303a 100644 --- a/alphatrion/__init__.py +++ b/alphatrion/__init__.py @@ -1,4 +1,4 @@ -from alphatrion.log.load import load_dataset +from alphatrion.log.load import load_checkpoint, load_dataset from alphatrion.log.log import log_artifact, log_dataset, log_metrics, log_params from alphatrion.runtime.runtime import init @@ -9,4 +9,5 @@ "log_metrics", "log_dataset", "load_dataset", + "load_checkpoint", ] diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index 90a603d9..6462e846 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -28,30 +28,32 @@ async def load_dataset(id: str | uuid.UUID, output_dir: str | None = None) -> li async def load_checkpoint( - experiment_id: str | uuid.UUID, + id: str | uuid.UUID, version: str = "latest", + type: str = "experiment", output_dir: str | None = None, ) -> list[str]: """ Load checkpoint from artifact registry. - :param experiment_id: the id of the experiment. + :param id: the id of the experiment. :param version: the version of the checkpoint to load, default is "latest". For oci backend, version is the tag of the artifact. For s3 backend, version is the name of the file to load. If version is "latest", the most recently modified file will be loaded. + :param type: the type of the checkpoint, can be "experiment" or "agent", default is "experiment". :param output_dir: the directory to which the checkpoint will be loaded. """ runtime = global_runtime() - if isinstance(experiment_id, str): - experiment_id = uuid.UUID(experiment_id) + if isinstance(id, str): + id = uuid.UUID(id) artifact = runtime.artifact if artifact is None: raise RuntimeError("Artifact storage is not initialized in the runtime.") - repo_name = f"{runtime.org_id}/{runtime.team_id}/{experiment_id}/ckpt" + repo_name = f"{runtime.org_id}/{runtime.team_id}/{id}/ckpt" versions = artifact.list_versions(repo_name) if versions is None or len(versions) == 0: diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py index a7ebdbb7..b89b19d0 100644 --- a/tests/integration/test_oci_backend.py +++ b/tests/integration/test_oci_backend.py @@ -18,6 +18,8 @@ import pytest +import alphatrion as alpha + @pytest.fixture(autouse=True) def oci_env_vars(): @@ -202,7 +204,9 @@ def test_oci_backend_pull_multiple_files(artifact, unique_repo): # Pull the files output_dir = os.path.join(tmpdir, "download") - result = artifact.pull(repo_name=unique_repo, version="v1", output_dir=output_dir) + result = artifact.pull( + repo_name=unique_repo, version="v1", output_dir=output_dir + ) # Verify all files were downloaded assert len(result) == 3 @@ -299,3 +303,170 @@ def test_oci_backend_delete_multiple_versions(artifact, unique_repo): assert "v0" not in versions assert "v1" not in versions assert "v2" in versions + + +@pytest.mark.asyncio +async def test_load_checkpoint_latest(artifact): + """Test load_checkpoint with 'latest' tag.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple checkpoint versions with different content and timestamps + import time + + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights version {i}") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + version=f"v{i}", + ) + if i < 2: + time.sleep(0.1) # Small delay for timestamp ordering + + # Load latest checkpoint (should be v2 for S3, but arbitrary for OCI) + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version="latest", output_dir=output_dir + ) + + # Verify checkpoint was downloaded + assert result is not None + assert len(result) == 1 + assert os.path.exists(result[0]) + + # Verify it's one of the versions + with open(result[0]) as f: + content = f.read() + assert content.startswith("model weights version") + + +@pytest.mark.asyncio +async def test_load_checkpoint_specific_version(artifact): + """Test load_checkpoint with specific version tag.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple checkpoint versions + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights version {i}") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + version=f"v{i}", + ) + + # Load specific version v1 + output_dir = os.path.join(tmpdir, "download") + + # Verify output_dir doesn't exist yet + assert not os.path.exists(output_dir), "Output dir should not exist before load_checkpoint" + + result = await alpha.load_checkpoint( + id=exp_id, version="v1", output_dir=output_dir + ) + + # Validate output_dir was created + assert os.path.exists(output_dir), "Output dir should be created by load_checkpoint" + assert os.path.isdir(output_dir), "Output path should be a directory" + + # Validate results + assert result is not None + assert len(result) == 1 + + # Validate file is in the correct output directory + downloaded_file = result[0] + # Use realpath to resolve symlinks (e.g., /var -> /private/var on macOS) + real_downloaded = os.path.realpath(downloaded_file) + real_output_dir = os.path.realpath(output_dir) + assert real_downloaded.startswith(real_output_dir), \ + f"File {real_downloaded} should be in output_dir {real_output_dir}" + + # Verify the file actually exists in output_dir + filename = os.path.basename(downloaded_file) + expected_path = os.path.join(output_dir, filename) + assert os.path.exists(expected_path), f"File should exist at {expected_path}" + + # Verify it's the correct version + with open(result[0]) as f: + content = f.read() + assert content == "model weights version 1" + + +@pytest.mark.asyncio +async def test_load_checkpoint_nonexistent(artifact): + """Test load_checkpoint returns None for nonexistent experiment.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Try to load checkpoint from non-existent experiment + result = await alpha.load_checkpoint( + id=exp_id, version="latest", output_dir=tmpdir + ) + + assert result == [] + + +@pytest.mark.asyncio +async def test_load_checkpoint_multiple_files(artifact): + """Test load_checkpoint with multiple files in checkpoint.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push checkpoint with multiple files + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"layer_{i}.pt") + with open(file_path, "w") as f: + f.write(f"layer {i} weights") + files.append(file_path) + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", paths=files, version="v1" + ) + + # Load checkpoint + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version="v1", output_dir=output_dir + ) + + # Verify all files were downloaded + assert result is not None + assert len(result) == 3 + + for i in range(3): + filename = f"layer_{i}.pt" + assert any(filename in r for r in result) + + file_path = os.path.join(output_dir, filename) + assert os.path.exists(file_path) + + with open(file_path) as f: + assert f.read() == f"layer {i} weights" From 2197c88f4d71c09017917c2940a8d9627c256abb Mon Sep 17 00:00:00 2001 From: kerthcet Date: Tue, 19 May 2026 00:50:54 +0100 Subject: [PATCH 3/3] fix lint Signed-off-by: kerthcet --- tests/integration/test_oci_backend.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py index b89b19d0..9b137e2b 100644 --- a/tests/integration/test_oci_backend.py +++ b/tests/integration/test_oci_backend.py @@ -376,14 +376,18 @@ async def test_load_checkpoint_specific_version(artifact): output_dir = os.path.join(tmpdir, "download") # Verify output_dir doesn't exist yet - assert not os.path.exists(output_dir), "Output dir should not exist before load_checkpoint" + assert not os.path.exists(output_dir), ( + "Output dir should not exist before load_checkpoint" + ) result = await alpha.load_checkpoint( id=exp_id, version="v1", output_dir=output_dir ) # Validate output_dir was created - assert os.path.exists(output_dir), "Output dir should be created by load_checkpoint" + assert os.path.exists(output_dir), ( + "Output dir should be created by load_checkpoint" + ) assert os.path.isdir(output_dir), "Output path should be a directory" # Validate results @@ -395,8 +399,9 @@ async def test_load_checkpoint_specific_version(artifact): # Use realpath to resolve symlinks (e.g., /var -> /private/var on macOS) real_downloaded = os.path.realpath(downloaded_file) real_output_dir = os.path.realpath(output_dir) - assert real_downloaded.startswith(real_output_dir), \ + assert real_downloaded.startswith(real_output_dir), ( f"File {real_downloaded} should be in output_dir {real_output_dir}" + ) # Verify the file actually exists in output_dir filename = os.path.basename(downloaded_file)