diff --git a/fastsam/prompt.py b/fastsam/prompt.py index 4a2b900..e8097fa 100644 --- a/fastsam/prompt.py +++ b/fastsam/prompt.py @@ -166,16 +166,11 @@ def plot_to_result(self, plt.axis('off') fig = plt.gcf() - plt.draw() - - try: - buf = fig.canvas.tostring_rgb() - except AttributeError: - fig.canvas.draw() - buf = fig.canvas.tostring_rgb() + fig.canvas.draw() + buf = fig.canvas.buffer_rgba() cols, rows = fig.canvas.get_width_height() - img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3) - result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR) + img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4) + result = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR) plt.close() return result diff --git a/requirements.txt b/requirements.txt index b40e8ff..c17e219 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ scipy>=1.4.1 torch>=1.7.0 torchvision>=0.8.1 tqdm>=4.64.0 +psutil pandas>=1.1.4 seaborn>=0.11.0 diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 3c2ba06..496de4a 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +import pickle from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d, @@ -515,7 +516,7 @@ def torch_safe_load(weight): check_suffix(file=weight, suffix='.pt') file = attempt_download_asset(weight) # search online if missing locally try: - return torch.load(file, map_location='cpu'), file # load + return torch.load(file, map_location='cpu', weights_only=True), file # load except ModuleNotFoundError as e: # e.name is missing module name if e.name == 'models': raise TypeError( @@ -530,7 +531,12 @@ def torch_safe_load(weight): f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'") check_requirements(e.name) # install missing module - return torch.load(file, map_location='cpu'), file # load + return torch.load(file, map_location='cpu', weights_only=True), file # load + except (pickle.UnpicklingError, RuntimeError) as e: + LOGGER.warning(f"WARNING ⚠️ {weight} requires non-safe pickle loading. " + f"Falling back to weights_only=False. " + f"Only load weights from trusted sources.") + return torch.load(file, map_location='cpu', weights_only=False), file # load def attempt_load_weights(weights, device=None, inplace=True, fuse=False): diff --git a/utils/tools.py b/utils/tools.py index 9934c5c..25af72b 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -177,17 +177,12 @@ def fast_process( os.makedirs(save_path) plt.axis("off") fig = plt.gcf() - plt.draw() - - try: - buf = fig.canvas.tostring_rgb() - except AttributeError: - fig.canvas.draw() - buf = fig.canvas.tostring_rgb() - + fig.canvas.draw() + buf = fig.canvas.buffer_rgba() + cols, rows = fig.canvas.get_width_height() - img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3) - cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) + img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4) + cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)) # CPU post process