Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions anylabeling/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import os

from PyQt6.QtCore import QObject, pyqtSignal, pyqtSlot

MESH_EXTENSIONS = [".obj", ".stl", ".ply"]


def is_mesh_file(filename):
"""Check if the filename is a mesh file"""
if not filename:
return False
return os.path.splitext(filename)[1].lower() in MESH_EXTENSIONS


class GenericWorker(QObject):
finished = pyqtSignal()
Expand Down
40 changes: 27 additions & 13 deletions anylabeling/views/labeling/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import PIL.Image

from anylabeling.utils import is_mesh_file

from ...app_info import __version__
from . import utils
from .logger import logger
Expand All @@ -31,12 +33,17 @@ def __init__(self, filename=None):
self.shapes = []
self.image_path = None
self.image_data = None
self.flags = {}
self.other_data = {}
if filename is not None:
self.load(filename)
self.filename = filename

@staticmethod
def load_image_file(filename):
if is_mesh_file(filename):
return None

try:
image_pil = PIL.Image.open(filename)
except OSError:
Expand Down Expand Up @@ -74,6 +81,7 @@ def load(self, filename):
"group_id",
"shape_type",
"flags",
"vertex_indices",
]
try:
with io_open(filename, "r") as f:
Expand All @@ -82,30 +90,36 @@ def load(self, filename):
if version is None:
logger.warning("Loading JSON file (%s) of unknown version", filename)

if data["imageData"] is not None:
image_path = data.get("imagePath", "")
if is_mesh_file(image_path):
image_data = None
elif data.get("imageData") is not None:
image_data = base64.b64decode(data["imageData"])
else:
elif image_path:
# relative path from label file to relative path from cwd
image_path = osp.join(osp.dirname(filename), data["imagePath"])
image_data = self.load_image_file(image_path)
abs_image_path = osp.join(osp.dirname(filename), image_path)
image_data = self.load_image_file(abs_image_path)
else:
image_data = None
flags = data.get("flags") or {}
image_path = data["imagePath"]
self._check_image_height_and_width(
base64.b64encode(image_data).decode("utf-8"),
data.get("imageHeight"),
data.get("imageWidth"),
)
if image_data:
self._check_image_height_and_width(
base64.b64encode(image_data).decode("utf-8"),
data.get("imageHeight"),
data.get("imageWidth"),
)
shapes = [
{
"label": s["label"],
"label": s.get("label", ""),
"text": s.get("text", ""),
"points": s["points"],
"points": s.get("points", []),
"shape_type": s.get("shape_type", "polygon"),
"vertex_indices": s.get("vertex_indices", []),
"flags": s.get("flags", {}),
"group_id": s.get("group_id"),
"other_data": {k: v for k, v in s.items() if k not in shape_keys},
}
for s in data["shapes"]
for s in data.get("shapes", [])
]
except Exception as e: # noqa
raise LabelFileError(e) from e
Expand Down
Loading
Loading