diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 0d5083fe..96f04544 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -582,6 +582,7 @@ def _get_delta_table( DeltaTable : If one exists at the specified bucket / prefix, will retrieve the cached instance. """ + delta_timeout = f"{self.timeout}s" full_key = f"{bucket}/{prefix}" qb_label = label or full_key.replace("/", "_").replace("-", "_") if (uri := f"{connector}://{full_key}") not in self._delta_tables: @@ -590,12 +591,46 @@ def _get_delta_table( storage_options={ "AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1", + "timeout": delta_timeout, + "connect_timeout": delta_timeout, + "retry_delay": "3", + "max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}", }, ) self.query_builder.register(qb_label, self._delta_tables[uri]) return qb_label, self._delta_tables[uri] + def _query_delta_single(self, query: str) -> pa.Table: + """Execute a SQL query against a registered Delta table. + + Wraps the query execution in a try/except to provide a more + actionable error message when the underlying Delta query engine + fails (e.g., due to network timeouts, missing tables, or + malformed queries). + + Args: + query (str): A SQL query string compatible with the + QueryBuilder engine. + + Returns: + pa.Table: The query result as a PyArrow Table. + + Raises: + MPRestError: If query execution fails for any reason, + including network timeouts, connectivity issues, or + invalid queries. Inspect the chained exception for + the underlying cause. + """ + try: + return pa.table(self.query_builder.execute(query).read_all()) + except Exception as e: + raise MPRestError( + f"Failed to retrieve object due to: {e}. " + f"If this is a timeout error, try increasing the 'timeout' " + f"parameter on MPRester (current value: {self.timeout}s)." + ) from e + def _query_delta_backed( self, bucket: str, diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index 01cb6130..c97bf389 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -6,10 +6,7 @@ import pyarrow as pa from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos -from emmet.core.electronic_structure import ( - DOSProjectionType, - ElectronicStructureDoc, -) +from emmet.core.electronic_structure import DOSProjectionType, ElectronicStructureDoc from emmet.core.mpid import AlphaID from emmet.core.vasp.calc_types.enums import RunType from pymatgen.analysis.magnetism.analyzer import Ordering @@ -286,15 +283,19 @@ def get_bandstructure_from_task_id( label="bandstructure", ) - selection_string = f"""SELECT * -FROM {bs_lbl} -WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'""" + query = f""" + SELECT * + FROM {bs_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' + """ + if run_type: rt = RunType(run_type) if isinstance(run_type, str) else run_type - selection_string += f"\nAND run_type='{rt.value}'" + query += f"\nAND run_type='{rt.value}'" if path_type: - selection_string += f"\nAND path_convention='{path_type}'" - table = pa.table(self.query_builder.execute(selection_string)) + query += f"\nAND path_convention='{path_type}'" + + table = self._query_delta_single(query) if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: emmet_bs = ElectronicBS(**deser[0]) return emmet_bs.to_pmg( @@ -509,13 +510,17 @@ def get_dos_from_task_id( label="total_dos", ) - selection_string = f"""SELECT * -FROM {dos_lbl} -WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'""" + query = f""" + SELECT * + FROM {dos_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' + """ + if run_type: rt = RunType(run_type) if isinstance(run_type, str) else run_type - selection_string += f"\nAND run_type='{rt.value}'" - table = pa.table(self.query_builder.execute(selection_string)) + query += f"\nAND run_type='{rt.value}'" + + table = self._query_delta_single(query) if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: return ElectronicDos(**deser[0]).to_pmg() raise MPRestError( diff --git a/mp_api/client/routes/materials/tasks.py b/mp_api/client/routes/materials/tasks.py index 3f3a7e31..af5dae5e 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -41,15 +41,13 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]: label="traj", ) - traj_data = pa.table( - self.query_builder.execute( - f""" - SELECT * - FROM {traj_lbl} - WHERE identifier='{as_alpha}' - """ - ).read_all() - ).to_pylist(maps_as_pydicts="strict") + query = f""" + SELECT * + FROM {traj_lbl} + WHERE identifier='{as_alpha}' + """ + + traj_data = self._query_delta_single(query).to_pylist(maps_as_pydicts="strict") if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index d1ffd28a..47cdf159 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -171,16 +171,15 @@ def get_phase_diagram_from_chemsys( pd_lbl, pd_tbl = self._get_delta_table( "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams" ) - table = pa.table( - self.query_builder.execute( - f"""SELECT phase_diagram + + query = f""" + SELECT phase_diagram FROM {pd_lbl} WHERE chemsys='{sorted_chemsys}' - AND version='{version}' - AND thermo_type='{thermo_type}' - """ - ) - ) + AND version='{version}' + AND thermo_type='{thermo_type}' + """ + table = self._query_delta_single(query) as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict") pd: PhaseDiagram | None = None