Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ def _get_delta_table(
DeltaTable : If one exists at the specified bucket / prefix,
will retrieve the cached instance.
"""
delta_timeout = f"{self.timeout}s"
full_key = f"{bucket}/{prefix}"
qb_label = label or full_key.replace("/", "_").replace("-", "_")
if (uri := f"{connector}://{full_key}") not in self._delta_tables:
Expand All @@ -590,12 +591,46 @@ def _get_delta_table(
storage_options={
"AWS_SKIP_SIGNATURE": "true",
"AWS_REGION": "us-east-1",
"timeout": delta_timeout,
"connect_timeout": delta_timeout,
"retry_delay": "3",
"max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}",
},
)
self.query_builder.register(qb_label, self._delta_tables[uri])

return qb_label, self._delta_tables[uri]

def _query_delta_single(self, query: str) -> pa.Table:
    """Run one SQL statement via the query builder and return a PyArrow table.

    The statement is handed to ``self.query_builder.execute``; the reader
    it returns is fully consumed with ``read_all()`` and the result is
    converted with ``pa.table``. Any exception raised along the way is
    re-raised as an ``MPRestError`` whose message embeds the original
    error text plus a hint about the client's ``timeout`` setting.

    Args:
        query (str): SQL text to pass to ``self.query_builder.execute``.

    Returns:
        pa.Table: The fully materialized query result.

    Raises:
        MPRestError: If executing the query or reading its result raises
            any ``Exception``. The original exception is attached as
            ``__cause__`` so the underlying failure can be inspected.
    """
    try:
        reader = self.query_builder.execute(query)
        result = pa.table(reader.read_all())
    except Exception as exc:
        # Wrap every failure in the client's own error type and point the
        # user at the one knob (timeout) they can adjust themselves.
        message = (
            f"Failed to retrieve object due to: {exc}. "
            "If this is a timeout error, try increasing the 'timeout' "
            f"parameter on MPRester (current value: {self.timeout}s)."
        )
        raise MPRestError(message) from exc
    else:
        return result

def _query_delta_backed(
self,
bucket: str,
Expand Down
35 changes: 20 additions & 15 deletions mp_api/client/routes/materials/electronic_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@

import pyarrow as pa
from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos
from emmet.core.electronic_structure import (
DOSProjectionType,
ElectronicStructureDoc,
)
from emmet.core.electronic_structure import DOSProjectionType, ElectronicStructureDoc
from emmet.core.mpid import AlphaID
from emmet.core.vasp.calc_types.enums import RunType
from pymatgen.analysis.magnetism.analyzer import Ordering
Expand Down Expand Up @@ -286,15 +283,19 @@ def get_bandstructure_from_task_id(
label="bandstructure",
)

selection_string = f"""SELECT *
FROM {bs_lbl}
WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'"""
query = f"""
SELECT *
FROM {bs_lbl}
WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'
"""

if run_type:
rt = RunType(run_type) if isinstance(run_type, str) else run_type
selection_string += f"\nAND run_type='{rt.value}'"
query += f"\nAND run_type='{rt.value}'"
if path_type:
selection_string += f"\nAND path_convention='{path_type}'"
table = pa.table(self.query_builder.execute(selection_string))
query += f"\nAND path_convention='{path_type}'"

table = self._query_delta_single(query)
if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0:
emmet_bs = ElectronicBS(**deser[0])
return emmet_bs.to_pmg(
Expand Down Expand Up @@ -509,13 +510,17 @@ def get_dos_from_task_id(
label="total_dos",
)

selection_string = f"""SELECT *
FROM {dos_lbl}
WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'"""
query = f"""
SELECT *
FROM {dos_lbl}
WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}'
"""

if run_type:
rt = RunType(run_type) if isinstance(run_type, str) else run_type
selection_string += f"\nAND run_type='{rt.value}'"
table = pa.table(self.query_builder.execute(selection_string))
query += f"\nAND run_type='{rt.value}'"

table = self._query_delta_single(query)
if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0:
return ElectronicDos(**deser[0]).to_pmg()
raise MPRestError(
Expand Down
16 changes: 7 additions & 9 deletions mp_api/client/routes/materials/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,13 @@ def get_trajectory(self, task_id: MPID | AlphaID | str) -> dict[str, Any]:
label="traj",
)

traj_data = pa.table(
self.query_builder.execute(
f"""
SELECT *
FROM {traj_lbl}
WHERE identifier='{as_alpha}'
"""
).read_all()
).to_pylist(maps_as_pydicts="strict")
query = f"""
SELECT *
FROM {traj_lbl}
WHERE identifier='{as_alpha}'
"""

traj_data = self._query_delta_single(query).to_pylist(maps_as_pydicts="strict")

if not traj_data:
raise MPRestError(f"No trajectory data for {task_id} found")
Expand Down
15 changes: 7 additions & 8 deletions mp_api/client/routes/materials/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,15 @@ def get_phase_diagram_from_chemsys(
pd_lbl, pd_tbl = self._get_delta_table(
"materialsproject-build", "objects/phase-diagrams", label="phase_diagrams"
)
table = pa.table(
self.query_builder.execute(
f"""SELECT phase_diagram

query = f"""
SELECT phase_diagram
FROM {pd_lbl}
WHERE chemsys='{sorted_chemsys}'
AND version='{version}'
AND thermo_type='{thermo_type}'
"""
)
)
AND version='{version}'
AND thermo_type='{thermo_type}'
"""
table = self._query_delta_single(query)
as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict")

pd: PhaseDiagram | None = None
Expand Down