Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions fastsam/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,11 @@ def plot_to_result(self,

plt.axis('off')
fig = plt.gcf()
plt.draw()

try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()
fig.canvas.draw()
buf = fig.canvas.buffer_rgba()
cols, rows = fig.canvas.get_width_height()
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
Comment thread
Xyc2016 marked this conversation as resolved.
result = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
plt.close()
return result

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ scipy>=1.4.1
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
psutil

pandas>=1.1.4
seaborn>=0.11.0
Expand Down
10 changes: 8 additions & 2 deletions ultralytics/nn/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
import torch.nn as nn
import pickle

from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
Expand Down Expand Up @@ -515,7 +516,7 @@ def torch_safe_load(weight):
check_suffix(file=weight, suffix='.pt')
file = attempt_download_asset(weight) # search online if missing locally
try:
return torch.load(file, map_location='cpu'), file # load
return torch.load(file, map_location='cpu', weights_only=True), file # load
except ModuleNotFoundError as e: # e.name is missing module name
if e.name == 'models':
raise TypeError(
Expand All @@ -530,7 +531,12 @@ def torch_safe_load(weight):
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
check_requirements(e.name) # install missing module

return torch.load(file, map_location='cpu'), file # load
return torch.load(file, map_location='cpu', weights_only=True), file # load
except (pickle.UnpicklingError, RuntimeError) as e:
LOGGER.warning(f"WARNING ⚠️ {weight} requires non-safe pickle loading. "
f"Falling back to weights_only=False. "
f"Only load weights from trusted sources.")
return torch.load(file, map_location='cpu', weights_only=False), file # load
Copy link

Copilot AI Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch_safe_load is intended to be a safer loader, but setting weights_only=False re-enables full pickle deserialization, which can lead to arbitrary code execution when loading untrusted .pt files. Recommendation: attempt torch.load(..., weights_only=True) first and only fall back to weights_only=False for a narrowly-scoped set of trusted/official weights (or gate it behind an explicit user opt-in / clear warning), or use PyTorch safe loading mechanisms (e.g., allowlisting needed globals) so third-party weights remain safe by default.

Copilot uses AI. Check for mistakes.


def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
Expand Down
15 changes: 5 additions & 10 deletions utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,12 @@ def fast_process(
os.makedirs(save_path)
plt.axis("off")
fig = plt.gcf()
plt.draw()

try:
buf = fig.canvas.tostring_rgb()
except AttributeError:
fig.canvas.draw()
buf = fig.canvas.tostring_rgb()

fig.canvas.draw()
buf = fig.canvas.buffer_rgba()

cols, rows = fig.canvas.get_width_height()
img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR))


# CPU post process
Expand Down