From 0ee2c94df890d1d2cdb7516a927f9030b84d4b44 Mon Sep 17 00:00:00 2001 From: primepake Date: Wed, 4 Mar 2026 16:32:13 +0700 Subject: [PATCH] update support .pt/.pth --- pyproject.toml | 1 + src/sft/browser.py | 4 +-- src/sft/cli.py | 12 +++++---- src/sft/index.py | 66 +++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 73 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d1d7571..7f26796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ ] dependencies = [ "textual>=0.40", + "torch", "typer>=0.9", ] diff --git a/src/sft/browser.py b/src/sft/browser.py index 12f7437..a2ddb48 100644 --- a/src/sft/browser.py +++ b/src/sft/browser.py @@ -1,4 +1,4 @@ -"""Textual TUI application for browsing safetensors files.""" +"""Textual TUI application for browsing tensor checkpoint files.""" from __future__ import annotations @@ -680,7 +680,7 @@ def __init__(self) -> None: class SftApp(App): - """Interactive browser for .safetensors files.""" + """Interactive browser for tensor checkpoint files.""" TITLE = "sft" diff --git a/src/sft/cli.py b/src/sft/cli.py index ad5a4ad..e7a7bb2 100644 --- a/src/sft/cli.py +++ b/src/sft/cli.py @@ -6,9 +6,11 @@ from sft import __version__ +SUPPORTED_EXTENSIONS = {".safetensors", ".pt", ".pth"} + app = typer.Typer( name="sft", - help="An interactive terminal browser for .safetensors files.", + help="An interactive terminal browser for .safetensors and .pt/.pth files.", no_args_is_help=True, add_completion=False, ) @@ -25,7 +27,7 @@ def version_callback(value: bool) -> None: def main( file: Path = typer.Argument( ..., - help="Path to a .safetensors file to browse.", + help="Path to a .safetensors, .pt, or .pth file to browse.", exists=True, file_okay=True, dir_okay=False, @@ -41,11 +43,11 @@ def main( is_eager=True, ), ) -> None: - """Open an interactive browser for a .safetensors file.""" + """Open an interactive browser for a .safetensors, .pt, or .pth file.""" # Validate file extension - if file.suffix.lower() != ".safetensors": + if file.suffix.lower() not in SUPPORTED_EXTENSIONS: typer.secho( - f"Error: Expected a .safetensors file, got '{file.suffix}'", + f"Error: Expected a .safetensors, .pt, or .pth file, got '{file.suffix}'", fg=typer.colors.RED, err=True, ) diff --git a/src/sft/index.py b/src/sft/index.py index adcb833..fbf536d 100644 --- a/src/sft/index.py +++ b/src/sft/index.py @@ -1,8 +1,9 @@ -"""Data model and parsing for safetensors files.""" +"""Data model and parsing for tensor checkpoint files.""" from __future__ import annotations import re +import warnings from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -51,10 +52,21 @@ class TensorIndex: @classmethod def from_file(cls, path: Path) -> TensorIndex: - """Parse a safetensors file and extract tensor metadata (header only). + """Parse a tensor file based on extension. - This uses direct header parsing to avoid loading any tensor data. + Supports .safetensors, .pt, and .pth files. """ + suffix = path.suffix.lower() + if suffix == ".safetensors": + return cls._from_safetensors(path) + elif suffix in (".pt", ".pth"): + return cls._from_pt_file(path) + else: + raise ValueError(f"Unsupported file format: '{suffix}'") + + @classmethod + def _from_safetensors(cls, path: Path) -> TensorIndex: + """Parse a safetensors file and extract tensor metadata (header only).""" import json import struct @@ -100,6 +112,54 @@ def from_file(cls, path: Path) -> TensorIndex: return cls(tensors=tensors, metadata=metadata, file_path=path) + @classmethod + def _from_pt_file(cls, path: Path) -> TensorIndex: + """Parse a PyTorch .pt/.pth file and extract tensor metadata.""" + import torch + + try: + data = torch.load(path, map_location="cpu", weights_only=True) + except Exception: + warnings.warn( + "weights_only=True failed; falling back to weights_only=False. " + "Only load .pt files you trust.", + stacklevel=2, + ) + data = torch.load(path, map_location="cpu", weights_only=False) + + # Extract state dict from various formats + if isinstance(data, dict): + # Check for common state dict wrapper keys + for key in ("state_dict", "model_state_dict", "model"): + if key in data and isinstance(data[key], dict): + state_dict = data[key] + break + else: + state_dict = data + elif hasattr(data, "state_dict"): + state_dict = data.state_dict() + else: + raise ValueError( + "Unsupported .pt format: expected a dict or module with state_dict()" + ) + + tensors: list[TensorInfo] = [] + for name, value in state_dict.items(): + if not hasattr(value, "shape"): + continue + tensors.append( + TensorInfo( + full_name=name, + shape=tuple(value.shape), + dtype=str(value.dtype).replace("torch.", ""), + nbytes=value.nelement() * value.element_size(), + ) + ) + + tensors.sort(key=lambda t: natural_sort_key(t.full_name)) + + return cls(tensors=tensors, metadata={}, file_path=path) + @property def total_tensors(self) -> int: """Return the total number of tensors."""