Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ classifiers = [
]
dependencies = [
"textual>=0.40",
"torch",
"typer>=0.9",
]

Expand Down
4 changes: 2 additions & 2 deletions src/sft/browser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Textual TUI application for browsing safetensors files."""
"""Textual TUI application for browsing tensor checkpoint files."""

from __future__ import annotations

Expand Down Expand Up @@ -680,7 +680,7 @@ def __init__(self) -> None:


class SftApp(App):
"""Interactive browser for .safetensors files."""
"""Interactive browser for tensor checkpoint files."""

TITLE = "sft"

Expand Down
12 changes: 7 additions & 5 deletions src/sft/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

from sft import __version__

SUPPORTED_EXTENSIONS = {".safetensors", ".pt", ".pth"}

app = typer.Typer(
name="sft",
help="An interactive terminal browser for .safetensors files.",
help="An interactive terminal browser for .safetensors and .pt/.pth files.",
no_args_is_help=True,
add_completion=False,
)
Expand All @@ -25,7 +27,7 @@ def version_callback(value: bool) -> None:
def main(
file: Path = typer.Argument(
...,
help="Path to a .safetensors file to browse.",
help="Path to a .safetensors, .pt, or .pth file to browse.",
exists=True,
file_okay=True,
dir_okay=False,
Expand All @@ -41,11 +43,11 @@ def main(
is_eager=True,
),
) -> None:
"""Open an interactive browser for a .safetensors file."""
"""Open an interactive browser for a .safetensors, .pt, or .pth file."""
# Validate file extension
if file.suffix.lower() != ".safetensors":
if file.suffix.lower() not in SUPPORTED_EXTENSIONS:
typer.secho(
f"Error: Expected a .safetensors file, got '{file.suffix}'",
f"Error: Expected a .safetensors, .pt, or .pth file, got '{file.suffix}'",
fg=typer.colors.RED,
err=True,
)
Expand Down
66 changes: 63 additions & 3 deletions src/sft/index.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Data model and parsing for safetensors files."""
"""Data model and parsing for tensor checkpoint files."""

from __future__ import annotations

import re
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -51,10 +52,21 @@ class TensorIndex:

@classmethod
def from_file(cls, path: Path) -> TensorIndex:
"""Parse a safetensors file and extract tensor metadata (header only).
"""Parse a tensor file based on extension.

This uses direct header parsing to avoid loading any tensor data.
Supports .safetensors, .pt, and .pth files.
"""
suffix = path.suffix.lower()
if suffix == ".safetensors":
return cls._from_safetensors(path)
elif suffix in (".pt", ".pth"):
return cls._from_pt_file(path)
else:
raise ValueError(f"Unsupported file format: '{suffix}'")

@classmethod
def _from_safetensors(cls, path: Path) -> TensorIndex:
"""Parse a safetensors file and extract tensor metadata (header only)."""
import json
import struct

Expand Down Expand Up @@ -100,6 +112,54 @@ def from_file(cls, path: Path) -> TensorIndex:

return cls(tensors=tensors, metadata=metadata, file_path=path)

@classmethod
def _from_pt_file(cls, path: Path) -> TensorIndex:
"""Parse a PyTorch .pt/.pth file and extract tensor metadata."""
import torch

try:
data = torch.load(path, map_location="cpu", weights_only=True)
except Exception:
warnings.warn(
"weights_only=True failed; falling back to weights_only=False. "
"Only load .pt files you trust.",
stacklevel=2,
)
data = torch.load(path, map_location="cpu", weights_only=False)

# Extract state dict from various formats
if isinstance(data, dict):
# Check for common state dict wrapper keys
for key in ("state_dict", "model_state_dict", "model"):
if key in data and isinstance(data[key], dict):
state_dict = data[key]
break
else:
state_dict = data
elif hasattr(data, "state_dict"):
state_dict = data.state_dict()
else:
raise ValueError(
"Unsupported .pt format: expected a dict or module with state_dict()"
)

tensors: list[TensorInfo] = []
for name, value in state_dict.items():
if not hasattr(value, "shape"):
continue
tensors.append(
TensorInfo(
full_name=name,
shape=tuple(value.shape),
dtype=str(value.dtype).replace("torch.", ""),
nbytes=value.nelement() * value.element_size(),
)
)

tensors.sort(key=lambda t: natural_sort_key(t.full_name))

return cls(tensors=tensors, metadata={}, file_path=path)

@property
def total_tensors(self) -> int:
"""Return the total number of tensors."""
Expand Down