Skip to content

Commit 32a227f

Browse files
authored
Merge pull request #113 from hzhwcmhf/master
fix compatibility with python 3.5.2
2 parents ffe9075 + 485adde commit 32a227f

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

pytorch_pretrained_bert/file_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@ def url_to_filename(url: str, etag: str = None) -> str:
4545
return filename
4646

4747

48-
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
48+
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
4949
"""
5050
Return the url and etag (which may be ``None``) stored for `filename`.
5151
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
5252
"""
5353
if cache_dir is None:
5454
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
55+
if isinstance(cache_dir, Path):
56+
cache_dir = str(cache_dir)
5557

5658
cache_path = os.path.join(cache_dir, filename)
5759
if not os.path.exists(cache_path):
@@ -69,7 +71,7 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
6971
return url, etag
7072

7173

72-
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
74+
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
7375
"""
7476
Given something that might be a URL (or might be a local path),
7577
determine which. If it's a URL, download the file and cache it, and
@@ -80,6 +82,8 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
8082
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
8183
if isinstance(url_or_filename, Path):
8284
url_or_filename = str(url_or_filename)
85+
if isinstance(cache_dir, Path):
86+
cache_dir = str(cache_dir)
8387

8488
parsed = urlparse(url_or_filename)
8589

@@ -158,13 +162,15 @@ def http_get(url: str, temp_file: IO) -> None:
158162
progress.close()
159163

160164

161-
def get_from_cache(url: str, cache_dir: str = None) -> str:
165+
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
162166
"""
163167
Given a URL, look for the corresponding dataset in the local cache.
164168
If it's not there, download it. Then return the path to the cached file.
165169
"""
166170
if cache_dir is None:
167171
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
172+
if isinstance(cache_dir, Path):
173+
cache_dir = str(cache_dir)
168174

169175
os.makedirs(cache_dir, exist_ok=True)
170176

0 commit comments

Comments
 (0)