diff --git a/extension_cpp/requirements.txt b/extension_cpp/requirements.txt index af3149e..f9c596d 100644 --- a/extension_cpp/requirements.txt +++ b/extension_cpp/requirements.txt @@ -1,2 +1,3 @@ torch numpy +packaging diff --git a/extension_cpp/setup.py b/extension_cpp/setup.py index 33a2a99..c0d017e 100644 --- a/extension_cpp/setup.py +++ b/extension_cpp/setup.py @@ -8,6 +8,7 @@ import glob from setuptools import find_packages, setup +from packaging.version import Version from torch.utils.cpp_extension import ( CppExtension, @@ -18,10 +19,7 @@ library_name = "extension_cpp" -if torch.__version__ >= "2.6.0": - py_limited_api = True -else: - py_limited_api = False +py_limited_api = Version(torch.__version__) >= Version("2.6.0") def get_extensions(): diff --git a/extension_cpp_stable/requirements.txt b/extension_cpp_stable/requirements.txt index af3149e..f9c596d 100644 --- a/extension_cpp_stable/requirements.txt +++ b/extension_cpp_stable/requirements.txt @@ -1,2 +1,3 @@ torch numpy +packaging diff --git a/extension_cpp_stable/setup.py b/extension_cpp_stable/setup.py index ca19488..435dc9b 100644 --- a/extension_cpp_stable/setup.py +++ b/extension_cpp_stable/setup.py @@ -8,6 +8,7 @@ import glob from setuptools import find_packages, setup +from packaging.version import Version from torch.utils.cpp_extension import ( CppExtension, @@ -19,10 +20,7 @@ library_name = "extension_cpp_stable" -if torch.__version__ >= "2.6.0": - py_limited_api = True -else: - py_limited_api = False +py_limited_api = Version(torch.__version__) >= Version("2.6.0") def get_extensions():