Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions alphatrion/artifact/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,36 @@

SUCCESS_CODE = 201

ARTIFACT_TYPE_S3 = "s3"
ARTIFACT_TYPE_OCI = "oci"


class Artifact:
"""Artifact storage client with pluggable backends (OCI or S3)."""

def __init__(self, insecure: bool = False):
storage_type = os.environ.get(envs.ARTIFACT_STORAGE_TYPE, "oci").lower()
self._storage_type = os.environ.get(
envs.ARTIFACT_STORAGE_TYPE, ARTIFACT_TYPE_OCI
).lower()

if storage_type == "s3":
if self._storage_type == ARTIFACT_TYPE_S3:
from alphatrion.artifact.s3_backend import S3Backend

self._backend = S3Backend()
elif storage_type == "oci":
elif self._storage_type == ARTIFACT_TYPE_OCI:
from alphatrion.artifact.oci_backend import OCIBackend

self._backend = OCIBackend(insecure=insecure)
else:
raise ValueError(
f"Unsupported artifact storage type: {storage_type}. "
f"Supported types: 'oci', 's3'"
f"Unsupported artifact storage type: {self._storage_type}. "
f"Supported types: '{ARTIFACT_TYPE_OCI}', '{ARTIFACT_TYPE_S3}'"
)

@property
def storage_type(self):
return self._storage_type

def push(
self,
repo_name: str,
Expand Down
11 changes: 7 additions & 4 deletions alphatrion/log/load.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import uuid

from alphatrion.artifact.artifact import ARTIFACT_TYPE_S3
from alphatrion.runtime.runtime import global_runtime


Expand Down Expand Up @@ -55,11 +56,13 @@ async def load_checkpoint(

repo_name = f"{runtime.org_id}/{runtime.team_id}/{id}/ckpt"

versions = artifact.list_versions(repo_name)
if versions is None or len(versions) == 0:
return []
# Only the S3 backend needs "latest" resolved here by listing versions.
# For the OCI backend the version is the tag itself, so "latest" is pulled
# directly as a tag (it must exist in the registry, otherwise the pull fails).
if version == "latest" and artifact.storage_type == ARTIFACT_TYPE_S3:
versions = artifact.list_versions(repo_name)
if versions is None or len(versions) == 0:
return []

if version == "latest":
version = versions[0] # Assuming versions are sorted by time, newest first

result = await asyncio.get_running_loop().run_in_executor(
Expand Down
37 changes: 22 additions & 15 deletions tests/integration/test_oci_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,7 @@ async def test_load_checkpoint_latest(artifact):
alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# Push multiple checkpoint versions with different content and timestamps
import time

# Push multiple checkpoint versions with different content
for i in range(3):
test_file = os.path.join(tmpdir, f"checkpoint_{i}.pt")
with open(test_file, "w") as f:
Expand All @@ -329,10 +327,19 @@ async def test_load_checkpoint_latest(artifact):
paths=test_file,
version=f"v{i}",
)
if i < 2:
time.sleep(0.1) # Small delay for timestamp ordering

# Load latest checkpoint (should be v2 for S3, but arbitrary for OCI)
# For OCI, also push a version tagged as "latest"
latest_file = os.path.join(tmpdir, "checkpoint_latest.pt")
with open(latest_file, "w") as f:
f.write("model weights version latest")

artifact.push(
repo_name=f"{org_id}/{team_id}/{exp_id}/ckpt",
paths=latest_file,
version="latest",
)

# Load latest checkpoint
output_dir = os.path.join(tmpdir, "download")
result = await alpha.load_checkpoint(
id=exp_id, version="latest", output_dir=output_dir
Expand All @@ -343,10 +350,10 @@ async def test_load_checkpoint_latest(artifact):
assert len(result) == 1
assert os.path.exists(result[0])

# Verify it's one of the versions
# For OCI, "latest" tag should return the checkpoint tagged as "latest"
with open(result[0]) as f:
content = f.read()
assert content.startswith("model weights version")
assert content == "model weights version latest"


@pytest.mark.asyncio
Expand Down Expand Up @@ -416,7 +423,7 @@ async def test_load_checkpoint_specific_version(artifact):

@pytest.mark.asyncio
async def test_load_checkpoint_nonexistent(artifact):
"""Test load_checkpoint returns None for nonexistent experiment."""
"""Test that load_checkpoint raises an error for a nonexistent checkpoint tag on OCI."""
org_id = uuid.uuid4()
team_id = uuid.uuid4()
user_id = uuid.uuid4()
Expand All @@ -425,12 +432,12 @@ async def test_load_checkpoint_nonexistent(artifact):
alpha.init(org_id=org_id, team_id=team_id, user_id=user_id)

with tempfile.TemporaryDirectory() as tmpdir:
# Try to load checkpoint from non-existent experiment
result = await alpha.load_checkpoint(
id=exp_id, version="latest", output_dir=tmpdir
)

assert result == []
# For OCI, trying to pull a non-existent tag should raise an error
# (unlike S3 which returns [] when no files exist)
with pytest.raises(RuntimeError, match="Failed to pull artifacts"):
await alpha.load_checkpoint(
id=exp_id, version="latest", output_dir=tmpdir
)


@pytest.mark.asyncio
Expand Down
Loading