@@ -45,13 +45,15 @@ def url_to_filename(url: str, etag: str = None) -> str:
4545 return filename
4646
4747
48- def filename_to_url (filename : str , cache_dir : str = None ) -> Tuple [str , str ]:
48+ def filename_to_url (filename : str , cache_dir : Union [ str , Path ] = None ) -> Tuple [str , str ]:
4949 """
5050 Return the url and etag (which may be ``None``) stored for `filename`.
5151 Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
5252 """
5353 if cache_dir is None :
5454 cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
55+ if isinstance (cache_dir , Path ):
56+ cache_dir = str (cache_dir )
5557
5658 cache_path = os .path .join (cache_dir , filename )
5759 if not os .path .exists (cache_path ):
@@ -69,7 +71,7 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
6971 return url , etag
7072
7173
72- def cached_path (url_or_filename : Union [str , Path ], cache_dir : str = None ) -> str :
74+ def cached_path (url_or_filename : Union [str , Path ], cache_dir : Union [ str , Path ] = None ) -> str :
7375 """
7476 Given something that might be a URL (or might be a local path),
7577 determine which. If it's a URL, download the file and cache it, and
@@ -80,6 +82,8 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str
8082 cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
8183 if isinstance (url_or_filename , Path ):
8284 url_or_filename = str (url_or_filename )
85+ if isinstance (cache_dir , Path ):
86+ cache_dir = str (cache_dir )
8387
8488 parsed = urlparse (url_or_filename )
8589
@@ -158,13 +162,15 @@ def http_get(url: str, temp_file: IO) -> None:
158162 progress .close ()
159163
160164
161- def get_from_cache (url : str , cache_dir : str = None ) -> str :
165+ def get_from_cache (url : str , cache_dir : Union [ str , Path ] = None ) -> str :
162166 """
163167 Given a URL, look for the corresponding dataset in the local cache.
164168 If it's not there, download it. Then return the path to the cached file.
165169 """
166170 if cache_dir is None :
167171 cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
172+ if isinstance (cache_dir , Path ):
173+ cache_dir = str (cache_dir )
168174
169175 os .makedirs (cache_dir , exist_ok = True )
170176
0 commit comments