diff --git a/alphatrion/__init__.py b/alphatrion/__init__.py index 9f8bce1..2c2b303 100644 --- a/alphatrion/__init__.py +++ b/alphatrion/__init__.py @@ -1,4 +1,4 @@ -from alphatrion.log.load import load_dataset +from alphatrion.log.load import load_checkpoint, load_dataset from alphatrion.log.log import log_artifact, log_dataset, log_metrics, log_params from alphatrion.runtime.runtime import init @@ -9,4 +9,5 @@ "log_metrics", "log_dataset", "load_dataset", + "load_checkpoint", ] diff --git a/alphatrion/artifact/oci_backend.py b/alphatrion/artifact/oci_backend.py index 4d08c69..06a3ff1 100644 --- a/alphatrion/artifact/oci_backend.py +++ b/alphatrion/artifact/oci_backend.py @@ -76,22 +76,18 @@ def pull( ) -> list[str]: path = f"{repo_name}:{version}" target = f"{self._url}/{path}" - original_dir = None if output_dir: os.makedirs(output_dir, exist_ok=True) - original_dir = os.getcwd() - os.chdir(output_dir) + download_dir = os.path.abspath(output_dir) + else: + download_dir = os.getcwd() try: - filenames = self._client.pull(target, outdir="." if output_dir else None) - download_dir = os.getcwd() - return [os.path.abspath(os.path.join(download_dir, f)) for f in filenames] + filenames = self._client.pull(target, outdir=download_dir) + return [os.path.join(download_dir, f) for f in filenames] except Exception as e: raise RuntimeError(f"Failed to pull artifacts: {e}") from e - finally: - if output_dir and original_dir: - os.chdir(original_dir) def delete(self, repo_name: str, versions: str | list[str]): target = f"{self._url}/{repo_name}" diff --git a/alphatrion/artifact/s3_backend.py b/alphatrion/artifact/s3_backend.py index 8ec4261..c450929 100644 --- a/alphatrion/artifact/s3_backend.py +++ b/alphatrion/artifact/s3_backend.py @@ -104,12 +104,115 @@ def push( return repo_name if version is None else f"{repo_name}/{version}" def list_versions(self, repo_name: str) -> list[str]: - raise NotImplementedError("list_versions is not implemented for S3 backend") + """List all files directly under a repository path (ignores nested files). + + Returns at most 3000 files (3 pages). If you have more checkpoints than this, + consider using database metadata to track versions instead. + + :param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt") + :return: List of filenames sorted by LastModified (newest first), max 3000 items + """ + try: + prefix = f"{repo_name}/" + files_with_time = [] + continuation_token = None + max_pages = 3 # Limit to 3000 files (1000 per page) + pages_fetched = 0 + + # Handle pagination for >1000 files, up to max_pages + while pages_fetched < max_pages: + # Use delimiter to only list top-level files, ignoring nested directories + params = { + "Bucket": self._bucket, + "Prefix": prefix, + "Delimiter": "/", + } + if continuation_token: + params["ContinuationToken"] = continuation_token + + response = self._s3.list_objects_v2(**params) + pages_fetched += 1 + + if "Contents" in response: + # Extract filenames and timestamps + for obj in response["Contents"]: + s3_key = obj["Key"] + # Get filename: "repo_name/file.txt" -> "file.txt" + filename = s3_key[len(prefix) :] + if filename: # Skip empty + files_with_time.append((filename, obj["LastModified"])) + + # Check if there are more results + if response.get("IsTruncated") and pages_fetched < max_pages: + continuation_token = response.get("NextContinuationToken") + else: + break + + if not files_with_time: + return [] + + # Sort by LastModified descending (newest first) + files_with_time.sort(key=lambda x: x[1], reverse=True) + + return [f[0] for f in files_with_time] + except Exception as e: + error_msg = str(e).lower() + if ( + "404" in error_msg + or "not found" in error_msg + or "nosuchbucket" in error_msg + ): + return [] + raise RuntimeError(f"Failed to list versions: {e}") from e def pull( self, repo_name: str, version: str, output_dir: str | None = None ) -> list[str]: - raise NotImplementedError("pull is not implemented for S3 backend") + """Pull (download) files from S3. + + :param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt") + :param version: The filename to download (for flat structure) or folder name (for versioned structure) + :param output_dir: Optional directory to save files. If None, downloads to current directory. + :return: List of absolute paths to downloaded files + """ + if output_dir: + os.makedirs(output_dir, exist_ok=True) + download_dir = os.path.abspath(output_dir) + else: + download_dir = os.getcwd() + + try: + # Check if version looks like a filename (has extension) or version folder + if "." in version: + # Single file: repo_name/version (e.g., "ckpt/checkpoint_123.pt") + s3_key = f"{repo_name}/{version}" + local_path = os.path.join(download_dir, version) + + self._s3.download_file(self._bucket, s3_key, local_path) + return [local_path] + else: + # Version folder: repo_name/version/* (e.g., "ckpt/v1/*") + prefix = f"{repo_name}/{version}/" + + response = self._s3.list_objects_v2( + Bucket=self._bucket, Prefix=prefix, Delimiter="/" + ) + + if "Contents" not in response: + return [] + + downloaded_files = [] + for obj in response["Contents"]: + s3_key = obj["Key"] + filename = s3_key[len(prefix) :] + if filename: # Skip empty/directory markers + local_path = os.path.join(download_dir, filename) + self._s3.download_file(self._bucket, s3_key, local_path) + downloaded_files.append(local_path) + + return downloaded_files + except Exception as e: + raise RuntimeError(f"Failed to pull artifacts from S3: {e}") from e def delete(self, repo_name: str, versions: str | list[str]): raise NotImplementedError("delete is not implemented for S3 backend") diff --git a/alphatrion/log/load.py b/alphatrion/log/load.py index c693e96..6462e84 100644 --- a/alphatrion/log/load.py +++ b/alphatrion/log/load.py @@ -25,3 +25,45 @@ async def load_dataset(id: str | uuid.UUID, output_dir: str | None = None) -> li ) return result + + +async def load_checkpoint( + id: str | uuid.UUID, + version: str = "latest", + type: str = "experiment", + output_dir: str | None = None, +) -> list[str]: + """ + Load checkpoint from artifact registry. + + :param id: the id of the experiment. + :param version: the version of the checkpoint to load, default is "latest". + For oci backend, version is the tag of the artifact. + For s3 backend, version is the name of the file to load. + If version is "latest", the most recently modified file will be loaded. + :param type: the type of the checkpoint, can be "experiment" or "agent", default is "experiment". + :param output_dir: the directory to which the checkpoint will be loaded. + """ + runtime = global_runtime() + + if isinstance(id, str): + id = uuid.UUID(id) + + artifact = runtime.artifact + if artifact is None: + raise RuntimeError("Artifact storage is not initialized in the runtime.") + + repo_name = f"{runtime.org_id}/{runtime.team_id}/{id}/ckpt" + + versions = artifact.list_versions(repo_name) + if versions is None or len(versions) == 0: + return [] + + if version == "latest": + version = versions[0] # Assuming versions are sorted by time, newest first + + result = await asyncio.get_running_loop().run_in_executor( + None, artifact.pull, repo_name, version, output_dir + ) + + return result diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index 2d20e0d..c56098e 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -65,9 +65,7 @@ async def log_artifact( # Now validate that we have paths if not paths: # TODO: replace with logging library. - print( - "Warning: No paths provided for log_artifact. Nothing will be logged." - ) + print("Warning: No paths provided for log_artifact. Nothing will be logged.") # We should still run the post_save_hook even if there's nothing to log, # because the hook might have side effects that are important (e.g., cleanup). diff --git a/tests/integration/test_oci_backend.py b/tests/integration/test_oci_backend.py new file mode 100644 index 0000000..9b137e2 --- /dev/null +++ b/tests/integration/test_oci_backend.py @@ -0,0 +1,477 @@ +"""Integration tests for OCI artifact backend. + +Note: These tests require a running OCI registry (like Docker Registry). + +To start the test services: + docker-compose -f docker-compose.test.yaml up -d registry + +Run tests with: + pytest tests/integration/test_oci_backend.py -v + +Cleanup: + docker-compose -f docker-compose.test.yaml down +""" + +import os +import tempfile +import uuid + +import pytest + +import alphatrion as alpha + + +@pytest.fixture(autouse=True) +def oci_env_vars(): + """Set up OCI environment variables for testing.""" + original_env = {} + env_vars = { + "ALPHATRION_ARTIFACT_STORAGE_TYPE": "oci", + "ALPHATRION_ARTIFACT_REGISTRY_URL": "localhost:25001", + "ALPHATRION_ENABLE_ARTIFACT_STORAGE": "true", + } + + for key, value in env_vars.items(): + original_env[key] = os.environ.get(key) + os.environ[key] = value + + yield + + # Restore original environment + for key, value in original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + +@pytest.fixture +def artifact(): + """Create an artifact instance with OCI backend.""" + from alphatrion.artifact.artifact import Artifact + + return Artifact(insecure=True) + + +@pytest.fixture +def unique_repo(): + """Generate a unique repository name for test isolation.""" + return f"org123/team456/test-{uuid.uuid4().hex[:8]}" + + +def test_oci_backend_push_single_file(artifact, unique_repo): + """Test OCI backend push with single file.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test file + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + + # Push artifact + path = artifact.push(repo_name=unique_repo, paths=test_file, version="v1") + assert path == f"{unique_repo}:v1" + + +def test_oci_backend_push_multiple_files(artifact, unique_repo): + """Test OCI backend push with multiple files.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"file{i}.txt") + with open(file_path, "w") as f: + f.write(f"content {i}") + files.append(file_path) + + # Push multiple files + path = artifact.push(repo_name=unique_repo, paths=files, version="v2") + assert path == f"{unique_repo}:v2" + + +def test_oci_backend_push_folder(artifact, unique_repo): + """Test OCI backend push with folder.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create test files in folder + test_dir = os.path.join(tmpdir, "test_folder") + os.makedirs(test_dir) + + for i in range(3): + with open(os.path.join(test_dir, f"file{i}.txt"), "w") as f: + f.write(f"content {i}") + + # Push folder + path = artifact.push(repo_name=unique_repo, paths=test_dir, version="v3") + assert path == f"{unique_repo}:v3" + + +def test_oci_backend_push_auto_version(artifact, unique_repo): + """Test OCI backend push with auto-generated version.""" + with tempfile.TemporaryDirectory() as tmpdir: + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + + # Push without version + path = artifact.push(repo_name=unique_repo, paths=test_file) + + # Should return repo_name:auto_version + assert path.startswith(f"{unique_repo}:") + + +def test_oci_backend_push_empty_files_error(artifact, unique_repo): + """Test OCI backend push with no files raises error.""" + with pytest.raises(ValueError, match="no files specified to push"): + artifact.push(repo_name=unique_repo, paths=None, version="v1") + + with pytest.raises(ValueError, match="no files specified to push"): + artifact.push(repo_name=unique_repo, paths="", version="v1") + + +def test_oci_backend_push_empty_folder_error(artifact, unique_repo): + """Test OCI backend push with empty folder raises error.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create empty folder + empty_dir = os.path.join(tmpdir, "empty") + os.makedirs(empty_dir) + + with pytest.raises(ValueError, match="No files to push"): + artifact.push(repo_name=unique_repo, paths=empty_dir, version="v1") + + +def test_oci_backend_list_versions(artifact, unique_repo): + """Test OCI backend list_versions.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple versions + for i in range(3): + test_file = os.path.join(tmpdir, f"test{i}.txt") + with open(test_file, "w") as f: + f.write(f"content {i}") + + artifact.push(repo_name=unique_repo, paths=test_file, version=f"v{i}") + + # List versions + versions = artifact.list_versions(unique_repo) + assert len(versions) == 3 + assert "v0" in versions + assert "v1" in versions + assert "v2" in versions + + +def test_oci_backend_list_versions_empty(artifact): + """Test list_versions returns empty list for non-existent repo.""" + versions = artifact.list_versions("org123/team456/nonexistent") + assert versions == [] + + +def test_oci_backend_pull_single_file(artifact, unique_repo): + """Test OCI backend pull.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push a file + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + + artifact.push(repo_name=unique_repo, paths=test_file, version="v1") + + # Pull the file + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name=unique_repo, version="v1", output_dir=output_dir + ) + + # Verify file was downloaded + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint.pt" + + # Verify content + with open(result[0]) as f: + assert f.read() == "model weights" + + +def test_oci_backend_pull_multiple_files(artifact, unique_repo): + """Test pull with multiple files.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple files + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"file{i}.txt") + with open(file_path, "w") as f: + f.write(f"content {i}") + files.append(file_path) + + artifact.push(repo_name=unique_repo, paths=files, version="v1") + + # Pull the files + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name=unique_repo, version="v1", output_dir=output_dir + ) + + # Verify all files were downloaded + assert len(result) == 3 + + # Check that all expected files exist and have correct content + result_basenames = [os.path.basename(r) for r in result] + for i in range(3): + expected_filename = f"file{i}.txt" + assert expected_filename in result_basenames + + expected_file = os.path.join(output_dir, expected_filename) + assert os.path.exists(expected_file) + + with open(expected_file) as f: + assert f.read() == f"content {i}" + + +def test_oci_backend_pull_to_current_dir(artifact, unique_repo): + """Test pull without output_dir.""" + with tempfile.TemporaryDirectory() as tmpdir: + original_dir = os.getcwd() + try: + os.chdir(tmpdir) + + # Push a file + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + + artifact.push(repo_name=unique_repo, paths=test_file, version="v1") + + # Pull without output_dir + result = artifact.pull(repo_name=unique_repo, version="v1") + + # Should download to current directory + assert len(result) == 1 + assert os.path.basename(result[0]) == "test.txt" + assert os.path.exists(result[0]) + + with open(result[0]) as f: + assert f.read() == "test content" + finally: + os.chdir(original_dir) + + +def test_oci_backend_delete(artifact, unique_repo): + """Test OCI backend delete.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple versions with different content + test_file1 = os.path.join(tmpdir, "test1.txt") + with open(test_file1, "w") as f: + f.write("test content v1") + artifact.push(repo_name=unique_repo, paths=test_file1, version="v1") + + test_file2 = os.path.join(tmpdir, "test2.txt") + with open(test_file2, "w") as f: + f.write("test content v2") + artifact.push(repo_name=unique_repo, paths=test_file2, version="v2") + + # Verify both versions exist + versions = artifact.list_versions(unique_repo) + assert "v1" in versions + assert "v2" in versions + + # Delete v1 + artifact.delete(repo_name=unique_repo, versions="v1") + + # Verify v1 is deleted + versions = artifact.list_versions(unique_repo) + assert "v1" not in versions + assert "v2" in versions + + +def test_oci_backend_delete_multiple_versions(artifact, unique_repo): + """Test deleting multiple versions at once.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple versions with DIFFERENT content (so they have different blobs) + for i in range(3): + test_file = os.path.join(tmpdir, f"test{i}.txt") + with open(test_file, "w") as f: + f.write(f"test content version {i}") # Different content per version + + artifact.push(repo_name=unique_repo, paths=test_file, version=f"v{i}") + + # Verify all versions exist before delete + versions_before = artifact.list_versions(unique_repo) + assert len(versions_before) == 3 + + # Delete v0 and v1 + artifact.delete(repo_name=unique_repo, versions=["v0", "v1"]) + + # Verify only v2 remains + versions = artifact.list_versions(unique_repo) + assert "v0" not in versions + assert "v1" not in versions + assert "v2" in versions + + +@pytest.mark.asyncio +async def test_load_checkpoint_latest(artifact): + """Test load_checkpoint with 'latest' tag.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple checkpoint versions with different content and timestamps + import time + + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights version {i}") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + version=f"v{i}", + ) + if i < 2: + time.sleep(0.1) # Small delay for timestamp ordering + + # Load latest checkpoint (should be v2 for S3, but arbitrary for OCI) + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version="latest", output_dir=output_dir + ) + + # Verify checkpoint was downloaded + assert result is not None + assert len(result) == 1 + assert os.path.exists(result[0]) + + # Verify it's one of the versions + with open(result[0]) as f: + content = f.read() + assert content.startswith("model weights version") + + +@pytest.mark.asyncio +async def test_load_checkpoint_specific_version(artifact): + """Test load_checkpoint with specific version tag.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple checkpoint versions + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights version {i}") + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", + paths=test_file, + version=f"v{i}", + ) + + # Load specific version v1 + output_dir = os.path.join(tmpdir, "download") + + # Verify output_dir doesn't exist yet + assert not os.path.exists(output_dir), ( + "Output dir should not exist before load_checkpoint" + ) + + result = await alpha.load_checkpoint( + id=exp_id, version="v1", output_dir=output_dir + ) + + # Validate output_dir was created + assert os.path.exists(output_dir), ( + "Output dir should be created by load_checkpoint" + ) + assert os.path.isdir(output_dir), "Output path should be a directory" + + # Validate results + assert result is not None + assert len(result) == 1 + + # Validate file is in the correct output directory + downloaded_file = result[0] + # Use realpath to resolve symlinks (e.g., /var -> /private/var on macOS) + real_downloaded = os.path.realpath(downloaded_file) + real_output_dir = os.path.realpath(output_dir) + assert real_downloaded.startswith(real_output_dir), ( + f"File {real_downloaded} should be in output_dir {real_output_dir}" + ) + + # Verify the file actually exists in output_dir + filename = os.path.basename(downloaded_file) + expected_path = os.path.join(output_dir, filename) + assert os.path.exists(expected_path), f"File should exist at {expected_path}" + + # Verify it's the correct version + with open(result[0]) as f: + content = f.read() + assert content == "model weights version 1" + + +@pytest.mark.asyncio +async def test_load_checkpoint_nonexistent(artifact): + """Test load_checkpoint returns None for nonexistent experiment.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Try to load checkpoint from non-existent experiment + result = await alpha.load_checkpoint( + id=exp_id, version="latest", output_dir=tmpdir + ) + + assert result == [] + + +@pytest.mark.asyncio +async def test_load_checkpoint_multiple_files(artifact): + """Test load_checkpoint with multiple files in checkpoint.""" + org_id = uuid.uuid4() + team_id = uuid.uuid4() + user_id = uuid.uuid4() + exp_id = uuid.uuid4() + + alpha.init(org_id=org_id, team_id=team_id, user_id=user_id) + + with tempfile.TemporaryDirectory() as tmpdir: + # Push checkpoint with multiple files + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"layer_{i}.pt") + with open(file_path, "w") as f: + f.write(f"layer {i} weights") + files.append(file_path) + + artifact.push( + repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt", paths=files, version="v1" + ) + + # Load checkpoint + output_dir = os.path.join(tmpdir, "download") + result = await alpha.load_checkpoint( + id=exp_id, version="v1", output_dir=output_dir + ) + + # Verify all files were downloaded + assert result is not None + assert len(result) == 3 + + for i in range(3): + filename = f"layer_{i}.pt" + assert any(filename in r for r in result) + + file_path = os.path.join(output_dir, filename) + assert os.path.exists(file_path) + + with open(file_path) as f: + assert f.read() == f"layer {i} weights" diff --git a/tests/unit/artifact/test_s3_backend.py b/tests/unit/artifact/test_s3_backend.py index 53bb4e9..218bb15 100644 --- a/tests/unit/artifact/test_s3_backend.py +++ b/tests/unit/artifact/test_s3_backend.py @@ -191,24 +191,243 @@ def test_s3_backend_push_empty_folder_error(s3_client): ) -def test_s3_backend_list_versions_not_implemented(s3_client): - """Test that list_versions raises NotImplementedError for S3 backend.""" +def test_s3_backend_list_versions_empty(s3_client): + """Test list_versions returns empty list for non-existent repo.""" from alphatrion.artifact.artifact import Artifact artifact = Artifact() - with pytest.raises(NotImplementedError, match="list_versions is not implemented"): - artifact.list_versions("org123/team456/test-repo") + versions = artifact.list_versions("org123/team456/nonexistent") + assert versions == [] -def test_s3_backend_pull_not_implemented(s3_client): - """Test that pull raises NotImplementedError.""" +def test_s3_backend_list_versions_single_file(s3_client): + """Test list_versions with a single file.""" from alphatrion.artifact.artifact import Artifact artifact = Artifact() - with pytest.raises(NotImplementedError, match="pull is not implemented"): - artifact.pull(repo_name="org123/team456/test-repo", version="v1") + with tempfile.TemporaryDirectory() as tmpdir: + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + + # Push file directly (no version folder) + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=test_file) + + # List versions should return the filename + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 1 + assert "checkpoint.pt" in versions + + +def test_s3_backend_list_versions_multiple_files(s3_client): + """Test list_versions with multiple files sorted by time.""" + import time + + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push 3 files with small delays to ensure different timestamps + for i in range(3): + test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt") + with open(test_file, "w") as f: + f.write(f"model weights {i}") + + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=test_file) + time.sleep(1) # Small delay to ensure different timestamps + + # List versions should return files sorted by LastModified (newest first) + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 3 + # The newest file (checkpoint_2.pt) should be first + assert versions[0] == "checkpoint_2.pt" + assert versions[1] == "checkpoint_1.pt" + assert versions[2] == "checkpoint_0.pt" + + +def test_s3_backend_list_versions_ignores_nested(s3_client): + """Test list_versions ignores nested files (uses delimiter).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push top-level file + top_file = os.path.join(tmpdir, "checkpoint.pt") + with open(top_file, "w") as f: + f.write("top level") + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=top_file) + + # Manually create nested file in S3 (simulating accidental nested upload) + s3_client.put_object( + Bucket="test-bucket", + Key="org123/team456/exp1/ckpt/nested/file.txt", + Body=b"nested content", + ) + + # List versions should only return top-level file + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 1 + assert versions[0] == "checkpoint.pt" + + +def test_s3_backend_list_versions_pagination_limit(s3_client): + """Test list_versions respects 3000 file limit (3 pages).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Create 10 test files (simulating pagination scenario) + # In real scenario, we'd create 3500 files but that's slow for tests + files = [] + for i in range(10): + test_file = os.path.join(tmpdir, f"checkpoint_{i:04d}.pt") + with open(test_file, "w") as f: + f.write(f"model {i}") + files.append(test_file) + + # Push all files + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=files) + + # List versions should return all files (under the 3000 limit) + versions = artifact.list_versions("org123/team456/exp1/ckpt") + assert len(versions) == 10 + # Should be sorted by timestamp (newest first) + assert all(f"checkpoint_{i:04d}.pt" in versions for i in range(10)) + + +def test_s3_backend_pull_single_file(s3_client): + """Test pull with single file (flat structure).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push a file + test_file = os.path.join(tmpdir, "checkpoint.pt") + with open(test_file, "w") as f: + f.write("model weights") + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=test_file) + + # Pull the file to a new directory + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name="org123/team456/exp1/ckpt", + version="checkpoint.pt", + output_dir=output_dir, + ) + + # Verify file was downloaded + assert len(result) == 1 + assert os.path.exists(result[0]) + assert os.path.basename(result[0]) == "checkpoint.pt" + + # Verify content + with open(result[0]) as f: + assert f.read() == "model weights" + + +def test_s3_backend_pull_version_folder(s3_client): + """Test pull with version folder (versioned structure).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Push multiple files with version + files = [] + for i in range(3): + file_path = os.path.join(tmpdir, f"file{i}.txt") + with open(file_path, "w") as f: + f.write(f"content {i}") + files.append(file_path) + + artifact.push(repo_name="org123/team456/exp1/ckpt", paths=files, version="v1") + + # Pull the version folder + output_dir = os.path.join(tmpdir, "download") + result = artifact.pull( + repo_name="org123/team456/exp1/ckpt", version="v1", output_dir=output_dir + ) + + # Verify all files were downloaded + assert len(result) == 3 + for i in range(3): + expected_file = os.path.join(output_dir, f"file{i}.txt") + assert any(expected_file == r for r in result) + assert os.path.exists(expected_file) + + with open(expected_file) as f: + assert f.read() == f"content {i}" + + +def test_s3_backend_pull_to_current_dir(s3_client): + """Test pull without output_dir (downloads to current directory).""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Change to temp directory + original_dir = os.getcwd() + try: + os.chdir(tmpdir) + + # Push a file + test_file = os.path.join(tmpdir, "test.txt") + with open(test_file, "w") as f: + f.write("test content") + artifact.push(repo_name="org123/team456/test-repo", paths=test_file) + + # Pull without output_dir + result = artifact.pull( + repo_name="org123/team456/test-repo", version="test.txt" + ) + + # Should download to current directory + assert len(result) == 1 + assert os.path.basename(result[0]) == "test.txt" + assert os.path.exists(result[0]) + + with open(result[0]) as f: + assert f.read() == "test content" + finally: + os.chdir(original_dir) + + +def test_s3_backend_pull_nonexistent_file(s3_client): + """Test pull with non-existent file raises error.""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(RuntimeError, match="Failed to pull artifacts"): + artifact.pull( + repo_name="org123/team456/nonexistent", + version="missing.txt", + output_dir=tmpdir, + ) + + +def test_s3_backend_pull_empty_version_folder(s3_client): + """Test pull with empty version folder returns empty list.""" + from alphatrion.artifact.artifact import Artifact + + artifact = Artifact() + + with tempfile.TemporaryDirectory() as tmpdir: + # Pull non-existent version folder + result = artifact.pull( + repo_name="org123/team456/exp1/ckpt", version="v999", output_dir=tmpdir + ) + + # Should return empty list + assert result == [] def test_s3_backend_path_based_versioning(s3_client):