Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion alphatrion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from alphatrion.log.load import load_dataset
from alphatrion.log.load import load_checkpoint, load_dataset
from alphatrion.log.log import log_artifact, log_dataset, log_metrics, log_params
from alphatrion.runtime.runtime import init

Expand All @@ -9,4 +9,5 @@
"log_metrics",
"log_dataset",
"load_dataset",
"load_checkpoint",
]
14 changes: 5 additions & 9 deletions alphatrion/artifact/oci_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,18 @@ def pull(
) -> list[str]:
path = f"{repo_name}:{version}"
target = f"{self._url}/{path}"
original_dir = None

if output_dir:
os.makedirs(output_dir, exist_ok=True)
original_dir = os.getcwd()
os.chdir(output_dir)
download_dir = os.path.abspath(output_dir)
else:
download_dir = os.getcwd()

try:
filenames = self._client.pull(target, outdir="." if output_dir else None)
download_dir = os.getcwd()
return [os.path.abspath(os.path.join(download_dir, f)) for f in filenames]
filenames = self._client.pull(target, outdir=download_dir)
return [os.path.join(download_dir, f) for f in filenames]
except Exception as e:
raise RuntimeError(f"Failed to pull artifacts: {e}") from e
finally:
if output_dir and original_dir:
os.chdir(original_dir)

def delete(self, repo_name: str, versions: str | list[str]):
target = f"{self._url}/{repo_name}"
Expand Down
107 changes: 105 additions & 2 deletions alphatrion/artifact/s3_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,115 @@ def push(
return repo_name if version is None else f"{repo_name}/{version}"

def list_versions(self, repo_name: str) -> list[str]:
    """List the files directly under a repository path, newest first.

    Only top-level objects are returned: the listing uses ``Delimiter="/"``,
    so keys nested below ``repo_name/<something>/`` are ignored.

    At most 3 pages (1000 keys per page, i.e. 3000 files) are fetched. If you
    have more checkpoints than this, consider using database metadata to
    track versions instead.

    :param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt")
    :return: Filenames sorted by LastModified (newest first), max 3000 items.
        An empty list is returned when the bucket/prefix does not exist.
    :raises RuntimeError: if listing fails for any other reason.
    """
    prefix = f"{repo_name}/"
    max_pages = 3  # Limit to 3000 files (1000 per page)
    files_with_time = []
    continuation_token = None

    try:
        for _ in range(max_pages):
            params = {
                "Bucket": self._bucket,
                "Prefix": prefix,
                # Delimiter keeps the listing to top-level files only.
                "Delimiter": "/",
            }
            if continuation_token:
                params["ContinuationToken"] = continuation_token

            response = self._s3.list_objects_v2(**params)

            for obj in response.get("Contents", []):
                # "repo_name/file.txt" -> "file.txt"
                filename = obj["Key"][len(prefix):]
                if filename:  # Skip the directory marker ("repo_name/")
                    files_with_time.append((filename, obj["LastModified"]))

            continuation_token = response.get("NextContinuationToken")
            # Stop when the listing is complete. Also stop if the service says
            # "truncated" but gives no token: looping again would re-read the
            # first page and produce duplicate entries.
            if not response.get("IsTruncated") or not continuation_token:
                break

        # Newest first, so index 0 is the most recently modified file.
        files_with_time.sort(key=lambda item: item[1], reverse=True)
        return [name for name, _ in files_with_time]
    except Exception as e:
        # Prefer the structured error code (botocore-style ``e.response``)
        # when one is available; substring matching on the message is only a
        # fallback because "404" can appear in unrelated text such as a port
        # number in an endpoint URL.
        details = getattr(e, "response", None)
        code = ""
        if isinstance(details, dict):
            code = str((details.get("Error") or {}).get("Code", "")).lower()

        if code:
            missing = code in {"404", "notfound", "nosuchbucket", "nosuchkey"}
        else:
            error_msg = str(e).lower()
            missing = (
                "404" in error_msg
                or "not found" in error_msg
                or "nosuchbucket" in error_msg
            )

        if missing:
            return []
        raise RuntimeError(f"Failed to list versions: {e}") from e

def pull(
self, repo_name: str, version: str, output_dir: str | None = None
) -> list[str]:
raise NotImplementedError("pull is not implemented for S3 backend")
"""Pull (download) files from S3.

:param repo_name: Repository path (e.g., "org_id/team_id/exp_id/ckpt")
:param version: The filename to download (for flat structure) or folder name (for versioned structure)
:param output_dir: Optional directory to save files. If None, downloads to current directory.
:return: List of absolute paths to downloaded files
"""
if output_dir:
os.makedirs(output_dir, exist_ok=True)
download_dir = os.path.abspath(output_dir)
else:
download_dir = os.getcwd()

try:
# Check if version looks like a filename (has extension) or version folder
if "." in version:
# Single file: repo_name/version (e.g., "ckpt/checkpoint_123.pt")
s3_key = f"{repo_name}/{version}"
local_path = os.path.join(download_dir, version)

self._s3.download_file(self._bucket, s3_key, local_path)
return [local_path]
else:
# Version folder: repo_name/version/* (e.g., "ckpt/v1/*")
prefix = f"{repo_name}/{version}/"

response = self._s3.list_objects_v2(
Bucket=self._bucket, Prefix=prefix, Delimiter="/"
)

if "Contents" not in response:
return []

downloaded_files = []
for obj in response["Contents"]:
s3_key = obj["Key"]
filename = s3_key[len(prefix) :]
if filename: # Skip empty/directory markers
local_path = os.path.join(download_dir, filename)
self._s3.download_file(self._bucket, s3_key, local_path)
downloaded_files.append(local_path)

return downloaded_files
except Exception as e:
raise RuntimeError(f"Failed to pull artifacts from S3: {e}") from e

def delete(self, repo_name: str, versions: str | list[str]):
raise NotImplementedError("delete is not implemented for S3 backend")
Expand Down
42 changes: 42 additions & 0 deletions alphatrion/log/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,45 @@ async def load_dataset(id: str | uuid.UUID, output_dir: str | None = None) -> li
)

return result


async def load_checkpoint(
    id: str | uuid.UUID,
    version: str = "latest",
    type: str = "experiment",
    output_dir: str | None = None,
) -> list[str]:
    """
    Load checkpoint from artifact registry.

    :param id: the id of the experiment. A string is parsed as a UUID.
    :param version: the version of the checkpoint to load, default is "latest".
                    For oci backend, version is the tag of the artifact.
                    For s3 backend, version is the name of the file to load.
                    If version is "latest", the first entry returned by the
                    backend's ``list_versions`` is loaded (assumed to be the
                    most recent one).
    :param type: the type of the checkpoint, can be "experiment" or "agent",
                 default is "experiment". NOTE: this argument is currently
                 not used when building the repository path.
    :param output_dir: the directory to which the checkpoint will be loaded.
    :return: the paths returned by the backend's ``pull``; an empty list when
             the repository has no versions.
    :raises RuntimeError: if artifact storage is not initialized in the runtime.
    """
    runtime = global_runtime()

    if isinstance(id, str):
        id = uuid.UUID(id)

    artifact = runtime.artifact
    if artifact is None:
        raise RuntimeError("Artifact storage is not initialized in the runtime.")

    repo_name = f"{runtime.org_id}/{runtime.team_id}/{id}/ckpt"
    loop = asyncio.get_running_loop()

    # list_versions is a blocking backend call; run it in the default executor
    # (like pull below) so the event loop is not stalled while it runs.
    versions = await loop.run_in_executor(None, artifact.list_versions, repo_name)
    if not versions:
        return []

    if version == "latest":
        version = versions[0]  # Assuming versions are sorted by time, newest first

    return await loop.run_in_executor(
        None, artifact.pull, repo_name, version, output_dir
    )
4 changes: 1 addition & 3 deletions alphatrion/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ async def log_artifact(
# Now validate that we have paths
if not paths:
# TODO: replace with logging library.
print(
"Warning: No paths provided for log_artifact. Nothing will be logged."
)
print("Warning: No paths provided for log_artifact. Nothing will be logged.")

# We should still run the post_save_hook even if there's nothing to log,
# because the hook might have side effects that are important (e.g., cleanup).
Expand Down
Loading
Loading