Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions SPECS/pytorch/CVE-2026-24747.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
From 52cc26db222976bbdf940ce110ad28bb5ea1cfc5 Mon Sep 17 00:00:00 2001
From: AllSpark <allspark@microsoft.com>
Date: Thu, 29 Jan 2026 14:25:44 +0000
Subject: [PATCH] override SWALR.state_dict and load_state_dict (#163122)

Fixes #163105

- Add typing_extensions.override
- Use _set_anneal_func to set anneal function
- Implement state_dict and load_state_dict for SWALR excluding optimizer and anneal_func

Signed-off-by: Azure Linux Security Servicing Account <azurelinux-security@microsoft.com>
Upstream-reference: AI Backport of https://github.com/pytorch/pytorch/commit/167ad09be5af5c52666759412a3804068c6955d1.patch
---
test/test_optim.py | 16 ++++++++++++++++
torch/optim/swa_utils.py | 37 +++++++++++++++++++++++++++++++++----
2 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/test/test_optim.py b/test/test_optim.py
index 1608478b..d3dd4567 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -3968,6 +3968,22 @@ class TestLRScheduler(TestCase):

self.assertLessEqual(last_lr, max_lr)

+ @parametrize("LRClass", [partial(SWALR, swa_lr=0.01)])
+ @parametrize("weights_only", [True, False])
+ def test_lr_scheduler_state_dict_load(self, LRClass, weights_only):
+ scheduler = LRClass(self.opt)
+ state_dict = scheduler.state_dict()
+
+ with tempfile.TemporaryFile() as f:
+ torch.save(state_dict, f)
+ f.seek(0)
+ state_dict_loaded = torch.load(f, weights_only=weights_only)
+ self.assertEqual(state_dict, state_dict_loaded)
+ # Make sure state_dict can be loaded
+ scheduler2 = LRClass(self.opt)
+ scheduler2.load_state_dict(state_dict_loaded)
+ self.assertEqual(scheduler2.state_dict(), state_dict)
+

class SWATestDNN(torch.nn.Module):
def __init__(self, input_features):
diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py
index dda4b8ad..d18084e2 100644
--- a/torch/optim/swa_utils.py
+++ b/torch/optim/swa_utils.py
@@ -2,6 +2,7 @@ import itertools
import math
from copy import deepcopy
import warnings
+from typing_extensions import override

import torch
from torch.nn import Module
@@ -247,10 +248,7 @@ class SWALR(LRScheduler):
if anneal_strategy not in ['cos', 'linear']:
raise ValueError("anneal_strategy must by one of 'cos' or 'linear', "
f"instead got {anneal_strategy}")
- elif anneal_strategy == 'cos':
- self.anneal_func = self._cosine_anneal
- elif anneal_strategy == 'linear':
- self.anneal_func = self._linear_anneal
+ self._set_anneal_func(anneal_strategy)
if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
raise ValueError(f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}")
self.anneal_epochs = anneal_epochs
@@ -296,3 +294,34 @@ class SWALR(LRScheduler):
alpha = self.anneal_func(t)
return [group['swa_lr'] * alpha + lr * (1 - alpha)
for group, lr in zip(self.optimizer.param_groups, prev_lrs)]
+
+ def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]):
+ self._anneal_strategy = anneal_strategy
+ if anneal_strategy == "cos":
+ self.anneal_func = self._cosine_anneal
+ else:
+ self.anneal_func = self._linear_anneal
+
+ @override
+ def state_dict(self) -> dict[str, Any]:
+ """Return the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer or anneal_func.
+ """
+ return {
+ key: value
+ for key, value in self.__dict__.items()
+ if key not in ("optimizer", "anneal_func")
+ }
+
+ @override
+ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+ """Load the scheduler's state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+ self._set_anneal_func(self._anneal_strategy)
--
2.45.4

6 changes: 5 additions & 1 deletion SPECS/pytorch/pytorch.spec
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration.
Name: pytorch
Version: 2.0.0
Release: 12%{?dist}
Release: 13%{?dist}
License: BSD-3-Clause
Vendor: Microsoft Corporation
Distribution: Mariner
Expand All @@ -23,6 +23,7 @@ Patch8: CVE-2025-2953.patch
Patch9: CVE-2025-55552.patch
Patch10: CVE-2025-55560.patch
Patch11: CVE-2025-3001.patch
Patch12: CVE-2026-24747.patch

BuildRequires: cmake
BuildRequires: gcc
Expand Down Expand Up @@ -95,6 +96,9 @@ cp -arf docs %{buildroot}/%{_pkgdocdir}
%{_docdir}/*

%changelog
* Thu Jan 29 2026 Azure Linux Security Servicing Account <azurelinux-security@microsoft.com> - 2.0.0-13
- Patch for CVE-2026-24747

* Thu Dec 25 2025 Azure Linux Security Servicing Account <azurelinux-security@microsoft.com> - 2.0.0-12
- Patch for CVE-2025-3001

Expand Down
Loading