diff --git a/dev/generate_mcp_tools.py b/dev/generate_mcp_tools.py index 4bd57700..57ca9511 100644 --- a/dev/generate_mcp_tools.py +++ b/dev/generate_mcp_tools.py @@ -53,13 +53,13 @@ def regenerate_tools( from datetime import datetime from typing import Literal +from emmet.core.band_theory import BSPathType from emmet.core.chemenv import ( COORDINATION_GEOMETRIES, COORDINATION_GEOMETRIES_IUCR, COORDINATION_GEOMETRIES_IUPAC, COORDINATION_GEOMETRIES_NAMES, ) -from emmet.core.band_theory import BSPathType from emmet.core.electronic_structure import DOSProjectionType from emmet.core.grain_boundary import GBTypeEnum from emmet.core.mpid import MPID diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index dac0076a..5398fbdb 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -43,6 +43,7 @@ from tqdm.auto import tqdm from urllib3.util.retry import Retry +from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -52,7 +53,6 @@ from mp_api.client.core.utils import ( MPDataset, load_json, - validate_api_key, validate_endpoint, validate_ids, ) @@ -70,6 +70,17 @@ except PackageNotFoundError: # pragma: no cover __version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "") +STATIC_COLLECTIONS = [ + "eos", + "grain_boundaries", + "jcesr", + "molecules", + "phonon", + "snls", + "surface-properties", + "synth-descriptions", + "xas", +] hdlr = logging.StreamHandler() fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s") @@ -105,33 +116,26 @@ def get(self, item: str, default: Any = None) -> Any: return default -class BaseRester: - """Base client class with core stubs.""" - - suffix: str = "" - document_model: type[BaseModel] = _DictLikeAccess - primary_key: str = "material_id" - delta_backed: bool = False +class _Rester: + """Define base attributes of a REST client.""" def __init__( self, api_key: str | None = None, 
endpoint: str | None = None, include_user_agent: bool = True, - session: requests.Session | None = None, - s3_client: Any | None = None, - debug: bool = False, use_document_model: bool = True, - timeout: int = 20, + session: requests.Session | None = None, headers: dict | None = None, mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, local_dataset_cache: ( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, + query_builder: QueryBuilder | None = None, **kwargs, - ): - """Initialize the REST API helper class. + ) -> None: + """Initialize a RESTer. Arguments: api_key: A String API key for accessing the MaterialsProject @@ -150,49 +154,49 @@ def __init__( making the API request. This helps MP support pymatgen users, and is similar to what most web browsers send with each page request. Set to False to disable the user agent. - session: requests Session object with which to connect to the API, for - advanced usage only. - s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores. - debug: if True, print the URL for every request use_document_model: If False, skip the creating the document model and return data as a dictionary. This can be simpler to work with but bypasses data validation and will not give auto-complete for available fields. - timeout: Time in seconds to wait until a request timeout error is thrown + session: requests Session object with which to connect to the API, for + advanced usage only. headers: Custom headers for localhost connections. mute_progress_bars: Whether to disable progress bars. local_dataset_cache: Target directory for downloading full datasets. 
Defaults to 'mp_datasets' in the user's home directory force_renew: Option to overwrite existing local dataset + query_builder : Instance of deltalake QueryBuilder to use in querying delta tables **kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = validate_api_key(api_key) - self.base_endpoint = validate_endpoint(endpoint) - self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) + self.api_key = get_user_api_key(api_key=api_key) + self.endpoint = validate_endpoint(endpoint) - self.debug = debug self.include_user_agent = include_user_agent self.use_document_model = use_document_model - self.timeout = timeout - self.headers = headers or {} - self.mute_progress_bars = mute_progress_bars - ( - self.db_version, - self.access_controlled_batch_ids, - ) = BaseRester._get_heartbeat_info(self.base_endpoint) + self.headers = headers or get_consumer() + self._session = session or _Rester._create_session( + api_key=self.api_key, + include_user_agent=self.include_user_agent, + headers=self.headers, + ) - self.local_dataset_cache: Path = Path(local_dataset_cache) - self.force_renew = force_renew + if is_dev_env(): + self._session.headers["x-api-key"] = self.api_key or "" - self._session = session - self._s3_client = s3_client + self.use_document_model = use_document_model + self.mute_progress_bars = mute_progress_bars + self.local_dataset_cache = Path(local_dataset_cache) + self.force_renew = force_renew + self._query_builder = query_builder if "monty_decode" in kwargs: + # Pop to not repeatedly trigger warning to the user + kwargs.pop("monty_decode", None) warnings.warn( "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." 
"The client by default returns results consistent with `monty_decode=True`.", - category=MPRestWarning, stacklevel=2, + category=MPRestWarning, ) @property @@ -204,13 +208,10 @@ def session(self) -> requests.Session: return self._session @property - def s3_client(self): - if not self._s3_client: - self._s3_client = boto3.client( - "s3", - config=Config(signature_version=UNSIGNED), # type: ignore - ) - return self._s3_client + def query_builder(self): + if not self._query_builder: + self._query_builder = QueryBuilder() + return self._query_builder @staticmethod def _create_session(api_key, include_user_agent, headers): @@ -250,6 +251,105 @@ def __exit__(self, exc_type, exc_val, exc_tb): # pragma: no cover self.session.close() self._session = None + +class BaseRester(_Rester): + """Base client class with core stubs.""" + + suffix: str = "" + document_model: type[BaseModel] = _DictLikeAccess + primary_key: str = "material_id" + delta_backed: bool = True + + def __init__( + self, + api_key: str | None = None, + endpoint: str | None = None, + include_user_agent: bool = True, + use_document_model: bool = True, + session: requests.Session | None = None, + headers: dict | None = None, + mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS, + local_dataset_cache: ( + str | os.PathLike + ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, + force_renew: bool = False, + query_builder: QueryBuilder | None = None, + s3_client: Any | None = None, + timeout: int = 20, + **kwargs, + ): + """Initialize the REST API helper class. + + s3_client: boto3 S3 client object with which to connect to the object stores. + timeout: Time in seconds to wait until a request timeout error is thrown + + Arguments: + api_key: A String API key for accessing the MaterialsProject + REST interface. Please obtain your API key at + https://www.materialsproject.org/dashboard. If this is None, + the code will check if there is a "PMG_MAPI_KEY" setting. + If so, it will use that environment variable. 
This makes it + easier for heavy users to simply add this environment variable to + their setups and MPRester can then be called without any arguments. + endpoint: Url of endpoint to access the MaterialsProject REST + interface. Defaults to the standard Materials Project REST + address at "https://api.materialsproject.org", but + can be changed to other urls implementing a similar interface. + include_user_agent: If True, will include a user agent with the + HTTP request including information on pymatgen and system version + making the API request. This helps MP support pymatgen users, and + is similar to what most web browsers send with each page request. + Set to False to disable the user agent. + session: requests Session object with which to connect to the API, for + advanced usage only. + use_document_model: If False, skip creating the document model and return data + as a dictionary. This can be simpler to work with but bypasses data validation + and will not give auto-complete for available fields. + headers: Custom headers for localhost connections. + mute_progress_bars: Whether to disable progress bars. + local_dataset_cache: Target directory for downloading full datasets. Defaults + to 'mp_datasets' in the user's home directory + force_renew: Option to overwrite existing local dataset + query_builder : Instance of deltalake QueryBuilder to use in querying delta tables + s3_client: boto3 S3 client object with which to connect to the object stores.
+ timeout: Time in seconds to wait until a request timeout error is thrown + **kwargs: access to legacy kwargs that may be in the process of being deprecated + """ + super().__init__( + api_key=api_key, + endpoint=endpoint, + include_user_agent=include_user_agent, + use_document_model=use_document_model, + session=session, + headers=headers, + mute_progress_bars=mute_progress_bars, + local_dataset_cache=local_dataset_cache, + force_renew=force_renew, + query_builder=query_builder, + ) + + self.base_endpoint = validate_endpoint(endpoint) + self.endpoint = validate_endpoint(endpoint, suffix=self.suffix) + + ( + self.db_version, + self.access_controlled_batch_ids, + ) = BaseRester._get_heartbeat_info(self.base_endpoint) + + self.timeout = timeout + self._s3_client = s3_client + + self._delta_tables: dict[str, DeltaTable] = {} + + @property + def s3_client(self): + if not self._s3_client: + self._s3_client = boto3.client( + "s3", + config=Config(signature_version=UNSIGNED), # type: ignore + ) + return self._s3_client + @staticmethod @cache def _get_heartbeat_info(endpoint) -> tuple[str, list[str]]: @@ -459,11 +559,85 @@ def _query_open_data( return decoded_data, len(decoded_data) # type: ignore + def _get_delta_table( + self, + bucket: str, + prefix: str, + connector: str = "s3a", + label: str | None = None, + ) -> tuple[str, DeltaTable]: + """Either create a new DeltaTable, or retrieve a cached one. + + If creating a new DeltaTable, will also register in self.query_builder + + Args: + bucket (str) : name of the bucket in S3 + prefix (str) : name of the prefix in S3 + connector (str) : s3, s3n, s3a (default), or other + valid Hadoop connector string. + label (str or None) : optional label for the table in QueryBuilder + If `None`, will be gleaned from the URI + + Returns: + str : the table name in QueryBuilder + DeltaTable : If one exists at the specified bucket / prefix, + will retrieve the cached instance. 
+ """ + delta_timeout = f"{self.timeout}s" + full_key = f"{bucket}/{prefix}" + qb_label = label or full_key.replace("/", "_").replace("-", "_") + if (uri := f"{connector}://{full_key}") not in self._delta_tables: + self._delta_tables[uri] = DeltaTable( + uri, + storage_options={ + "AWS_SKIP_SIGNATURE": "true", + "AWS_REGION": "us-east-1", + "timeout": delta_timeout, + "connect_timeout": delta_timeout, + "retry_delay": "3", + "max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}", + }, + ) + self.query_builder.register(qb_label, self._delta_tables[uri]) + + return qb_label, self._delta_tables[uri] + + def _query_delta_single(self, query: str) -> pa.Table: + """Execute a SQL query against a registered Delta table. + + Wraps the query execution in a try/except to provide a more + actionable error message when the underlying Delta query engine + fails (e.g., due to network timeouts, missing tables, or + malformed queries). + + Args: + query (str): A SQL query string compatible with the + QueryBuilder engine. + + Returns: + pa.Table: The query result as a PyArrow Table. + + Raises: + MPRestError: If query execution fails for any reason, + including network timeouts, connectivity issues, or + invalid queries. Inspect the chained exception for + the underlying cause. + """ + try: + return pa.table(self.query_builder.execute(query).read_all()) + except Exception as e: + raise MPRestError( + f"Failed to retrieve object due to: {e}. " + f"If this is a timeout error, try increasing the 'timeout' " + f"parameter on MPRester (current value: {self.timeout}s)." + ) from e + def _query_delta_backed( self, bucket: str, prefix: str, timeout: int | None = None, + label: str | None = None, ) -> dict[str, Any]: """Retrieve data from S3 backed by a DeltaTable. 
@@ -471,6 +645,7 @@ def _query_delta_backed( bucket (str) : S3 OpenData bucket prefix (str) : S3 object prefix timeout (int or None) : timeout on getting access-controlled groups + label (str or None) : label of the table in QueryBuilder Returns: dict of str to Any @@ -527,13 +702,7 @@ def _query_delta_backed( ) } - tbl = DeltaTable( - f"s3a://{bucket}/{prefix}", - storage_options={ - "AWS_SKIP_SIGNATURE": "true", - "AWS_REGION": "us-east-1", - }, - ) + tbl_lbl, tbl = self._get_delta_table(bucket, prefix, label=label) controlled_batch_str = ",".join( [f"'{tag}'" for tag in self.access_controlled_batch_ids] @@ -545,8 +714,6 @@ def _query_delta_backed( else "" ) - builder = QueryBuilder().register("tbl", tbl) - # Setup progress bar num_docs_needed: int = tbl.count() @@ -568,7 +735,7 @@ def _query_delta_backed( else None ) - iterator = builder.execute(f"SELECT * FROM tbl {predicate}") + iterator = self.query_builder.execute(f"SELECT * FROM {tbl_lbl} {predicate}") file_options = ds.ParquetFileFormat().make_write_options(compression="zstd") @@ -714,6 +881,9 @@ def _query_resource( if "tasks" in suffix: bucket_suffix, prefix = ("parsed", "core/tasks/") + elif suffix in STATIC_COLLECTIONS: + bucket_suffix = "build" + prefix = f"static-collections/{suffix}" else: bucket_suffix = "build" prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}" @@ -1267,12 +1437,7 @@ def _convert_to_model( ) return [ - data_model( - **{ - field: raw_doc[field] - for field in set_fields.intersection(raw_doc) - } - ) + data_model(**raw_doc) for raw_doc in (data if is_list else chain([first_doc], data)) ] @@ -1295,7 +1460,14 @@ def _generate_returned_model( set of str: set_fields, fields_not_requested) """ model_fields = self.document_model.model_fields - set_fields = set(doc).intersection(model_fields) + aliases = { + anno.alias: field for field, anno in model_fields.items() if anno.alias + } + set_fields = ( + set(doc) + .intersection(model_fields) + .union({aliases[k] for k in 
set(doc).intersection(aliases)}) + ) unset_fields = set(model_fields).difference(set_fields) user_requested_fields: list[str] = requested_fields or [] fields_not_requested = unset_fields.difference(user_requested_fields) @@ -1619,6 +1791,7 @@ def __getattr__(self, v: str): mute_progress_bars=self.mute_progress_bars, local_dataset_cache=self.local_dataset_cache, force_renew=self.force_renew, + query_builder=self._query_builder, ) return self.sub_resters[v] raise AttributeError(f"{self.__class__} has no attribute {v}") diff --git a/mp_api/client/core/utils.py b/mp_api/client/core/utils.py index 5c15f8b4..a6857ef1 100644 --- a/mp_api/client/core/utils.py +++ b/mp_api/client/core/utils.py @@ -116,10 +116,8 @@ def validate_ids(id_list: list[str]) -> list[str]: " data for all IDs and filter locally." ) - # TODO: after the transition to AlphaID in the document models, - # The following line should be changed to - # return [validate_identifier(idx,serialize=True) for idx in id_list] - return [str(validate_identifier(idx)) for idx in id_list] + validated = [validate_identifier(idx, serialize=False) for idx in id_list] + return [getattr(idx, "string", str(idx)) for idx in validated] def validate_endpoint(endpoint: str | None, suffix: str | None = None) -> str: @@ -243,6 +241,14 @@ def __getattr__(self, v: str) -> Any: if hasattr(self._imported, v): return getattr(self._imported, v) + raise AttributeError( + f"{self._module_name}{'.' 
+ self._class_name if self._class_name else ''} " + f"has no attribute {v}" + ) + + def __dir__(self) -> list[str]: + return self._obj.__dir__() + class MPDataset: """Convenience wrapper for pyarrow datasets stored on disk.""" diff --git a/mp_api/client/mprester.py b/mp_api/client/mprester.py index d1b5b069..22b5592c 100644 --- a/mp_api/client/mprester.py +++ b/mp_api/client/mprester.py @@ -21,9 +21,8 @@ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from requests import Session, get -from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env -from mp_api.client.core import BaseRester from mp_api.client.core._oxygen_evolution import OxygenEvolution +from mp_api.client.core.client import _Rester from mp_api.client.core.exceptions import ( MPRestError, MPRestWarning, @@ -33,7 +32,6 @@ from mp_api.client.core.utils import ( LazyImport, load_json, - validate_endpoint, validate_ids, ) from mp_api.client.routes import GENERIC_RESTERS @@ -45,10 +43,12 @@ from typing import Any, Literal import numpy as np + from deltalake import QueryBuilder from emmet.core.tasks import CoreTaskDoc from packaging.version import Version from pymatgen.analysis.phase_diagram import PDEntry from pymatgen.analysis.pourbaix_diagram import PourbaixEntry + from pymatgen.electronic_structure.dos import Dos from pymatgen.entries.compatibility import Compatibility from pymatgen.entries.computed_entries import ( ComputedEntry, @@ -85,14 +85,13 @@ ] -class MPRester: +class MPRester(_Rester): """Access the new Materials Project API.""" def __init__( self, api_key: str | None = None, endpoint: str | None = None, - notify_db_version: bool = False, include_user_agent: bool = True, use_document_model: bool = True, session: Session | None = None, @@ -102,6 +101,8 @@ def __init__( str | os.PathLike ) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE, force_renew: bool = False, + query_builder: QueryBuilder | None = None, + notify_db_version: bool = False, **kwargs, ): """Initialize 
the MPRester. @@ -118,13 +119,6 @@ def __init__( interface. Defaults to the standard Materials Project REST address at "https://api.materialsproject.org", but can be changed to other URLs implementing a similar interface. - notify_db_version (bool): If True, the current MP database version will - be retrieved and logged locally in the ~/.mprester.log.yaml. If the database - version changes, you will be notified. The current database version is - also printed on instantiation. These local logs are not sent to - materialsproject.org and are not associated with your API key, so be - aware that a notification may not be presented if you run MPRester - from multiple computing environments. include_user_agent (bool): If True, will include a user agent with the HTTP request including information on pymatgen and system version making the API request. This helps MP support pymatgen users, and @@ -139,25 +133,29 @@ def __init__( local_dataset_cache: Target directory for downloading full datasets. Defaults to "mp_datasets" in the user's home directory force_renew: Option to overwrite existing local dataset + query_builder : Instance of deltalake QueryBuilder to use in querying delta tables + notify_db_version (bool): If True, the current MP database version will + be retrieved and logged locally in the ~/.mprester.log.yaml. If the database + version changes, you will be notified. The current database version is + also printed on instantiation. These local logs are not sent to + materialsproject.org and are not associated with your API key, so be + aware that a notification may not be presented if you run MPRester + from multiple computing environments. 
**kwargs: access to legacy kwargs that may be in the process of being deprecated """ - self.api_key = get_user_api_key(api_key=api_key) - - self.endpoint = validate_endpoint(endpoint) - - self.headers = headers or get_consumer() - self.session = session or BaseRester._create_session( - api_key=self.api_key, + super().__init__( + api_key=api_key, + endpoint=endpoint, include_user_agent=include_user_agent, - headers=self.headers, + use_document_model=use_document_model, + session=session, + headers=headers, + mute_progress_bars=mute_progress_bars, + local_dataset_cache=local_dataset_cache, + force_renew=force_renew, + query_builder=query_builder, ) - if is_dev_env(): - self.session.headers["x-api-key"] = self.api_key or "" - self._include_user_agent = include_user_agent - self.use_document_model = use_document_model - self.mute_progress_bars = mute_progress_bars - self.local_dataset_cache = local_dataset_cache - self.force_renew = force_renew + self._contribs = None self._deprecated_attributes = [ @@ -190,14 +188,6 @@ def __init__( "chemenv", ] - if "monty_decode" in kwargs: - warnings.warn( - "Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`." 
- "The client by default returns results consistent with `monty_decode=True`.", - stacklevel=2, - category=MPRestWarning, - ) - # Check if emmet version of server is compatible if (emmet_version := MPRester.get_emmet_version(self.endpoint)) and ( version.parse(emmet_version.base_version) @@ -228,13 +218,14 @@ def __init__( lazy_rester( api_key=self.api_key, endpoint=self.endpoint, - include_user_agent=self._include_user_agent, + include_user_agent=self.include_user_agent, session=self.session, use_document_model=self.use_document_model, headers=self.headers, mute_progress_bars=self.mute_progress_bars, local_dataset_cache=self.local_dataset_cache, force_renew=self.force_renew, + query_builder=self._query_builder, ), ) @@ -269,14 +260,6 @@ def contribs(self): return self._contribs - def __enter__(self): - """Support for "with" context.""" - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """Support for "with" context.""" - self.session.close() - def __getattr__(self, attr): if attr in self._deprecated_attributes: warnings.warn( @@ -1141,18 +1124,34 @@ def get_bandstructure_by_material_id( material_id=material_id, path_type=path_type, line_mode=line_mode ) - def get_dos_by_material_id(self, material_id: str): - """Get the complete density of states pymatgen object associated with a Materials Project ID. + def get_dos_by_material_id(self, material_id: str) -> Dos: + """Get the density of states pymatgen object associated with a Materials Project ID. 
Arguments: material_id (str): Materials Project ID for a material Returns: - dos (CompleteDos): CompleteDos object + pymatgen Dos """ - return self.materials.electronic_structure_dos.get_dos_from_material_id( - material_id=material_id - ) # type: ignore + if ( + not ( + es_doc := self.materials.electronic_structure.search( + material_ids=material_id, fields=["dos"] + ) + ) + or not es_doc[0]["dos"] + ): + raise MPRestError(f"No DOS found for {material_id}") + + dos_data = es_doc[0]["dos"] + task_id = dos_data.task_id if self.use_document_model else dos_data["task_id"] + run_type = self.materials.tasks.search(task_ids=[task_id], fields=["run_type"])[ + 0 + ]["run_type"] + return self.materials.electronic_structure_dos.get_dos_from_task_id( + task_id, + run_type=run_type, + ) def get_phonon_dos_by_material_id(self, material_id: str): """Get phonon density of states data corresponding to a material_id. diff --git a/mp_api/client/routes/materials/doi.py b/mp_api/client/routes/materials/doi.py index c55e3758..26b268ca 100644 --- a/mp_api/client/routes/materials/doi.py +++ b/mp_api/client/routes/materials/doi.py @@ -12,6 +12,7 @@ class DOIRester(BaseRester): suffix = "doi" document_model = DOIDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/electronic_structure.py b/mp_api/client/routes/materials/electronic_structure.py index b0bee09e..87bb40b7 100644 --- a/mp_api/client/routes/materials/electronic_structure.py +++ b/mp_api/client/routes/materials/electronic_structure.py @@ -4,20 +4,23 @@ from collections import defaultdict from typing import TYPE_CHECKING -from emmet.core.band_theory import BSPathType -from emmet.core.electronic_structure import ( - DOSProjectionType, - ElectronicStructureDoc, -) +from emmet.core.band_theory import BSPathType, ElectronicBS, ElectronicDos +from emmet.core.electronic_structure import DOSProjectionType, ElectronicStructureDoc +from emmet.core.mpid import 
AlphaID +from emmet.core.vasp.calc_types.enums import RunType from pymatgen.analysis.magnetism.analyzer import Ordering from pymatgen.core.periodic_table import Element +from pymatgen.electronic_structure.bandstructure import ( + BandStructure, + BandStructureSymmLine, +) from pymatgen.electronic_structure.core import OrbitalType, Spin from mp_api.client.core import BaseRester, MPRestError -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import validate_ids if TYPE_CHECKING: - from pymatgen.electronic_structure.dos import CompleteDos + from pymatgen.electronic_structure.dos import Dos class ElectronicStructureRester(BaseRester): @@ -167,6 +170,7 @@ def es_rester(self) -> ElectronicStructureRester: class BandStructureRester(BaseESPropertyRester): suffix = "materials/electronic_structure/bandstructure" + delta_backed = False def search_bandstructure_summary(self, *args, **kwargs): # pragma: no cover """Deprecated.""" @@ -255,20 +259,51 @@ def search( **query_params, ) - def get_bandstructure_from_task_id(self, task_id: str): + def get_bandstructure_from_task_id( + self, + task_id: str, + run_type: str | RunType | None = None, + path_type: str | BSPathType | None = None, + ) -> BandStructure: """Get the band structure pymatgen object associated with a given task ID. Arguments: task_id (str): Task ID for the band structure calculation - + run_type (str, RunType, or None): Optional run type, + will speed up query due to delta table partitioning. 
+ path_type (str, BSPathType, or None) : Optional path type to + speed up query Returns: bandstructure (BandStructure): BandStructure or BandStructureSymmLine object """ - return self._query_open_data( # type: ignore[call-overload] - bucket="materialsproject-parsed", - key=f"bandstructures/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=True), - )[0][0]["data"] + bs_lbl, bs_tbl = self._get_delta_table( + "materialsproject-parsed", + "core/electronic-structure/bandstructures/", + label="bandstructure", + ) + + query = f""" + SELECT * + FROM {bs_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' + """ + + if run_type: + rt = RunType(run_type) if isinstance(run_type, str) else run_type + query += f"\nAND run_type='{rt.value}'" + if path_type: + query += f"\nAND path_convention='{path_type}'" + + table = self._query_delta_single(query) + if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: + emmet_bs = ElectronicBS(**deser[0]) + return emmet_bs.to_pmg( + pmg_cls=BandStructureSymmLine if emmet_bs.labels_dict else BandStructure + ) + raise MPRestError( + f"No bandstructure data found for {task_id=}" + + (f"run_type={rt}" if run_type else "") + ) def get_bandstructure_from_material_id( self, @@ -291,7 +326,9 @@ def get_bandstructure_from_material_id( material_ids=material_id, fields=["bandstructure"] ) if not bs_doc: - raise MPRestError("No electronic structure data found.") + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." + ) if (_bs_data := bs_doc[0]["bandstructure"]) is None: raise MPRestError( @@ -314,7 +351,9 @@ def get_bandstructure_from_material_id( material_ids=material_id, fields=["dos"] ) ): - raise MPRestError("No electronic structure data found.") + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." 
+ ) if (_bs_data := bs_doc[0]["dos"]) is None: raise MPRestError( @@ -329,7 +368,10 @@ def get_bandstructure_from_material_id( ) bs_task_id = bs_data["total"]["1"]["task_id"] - bs_obj = self.get_bandstructure_from_task_id(bs_task_id) + bs_obj = self.get_bandstructure_from_task_id( + bs_task_id, + path_type=path_type if line_mode else BSPathType.unknown, + ) if bs_obj: return bs_obj @@ -338,6 +380,7 @@ def get_bandstructure_from_material_id( class DosRester(BaseESPropertyRester): suffix = "materials/electronic_structure/dos" + delta_backed = False def search_dos_summary(self, *args, **kwargs): # pragma: no cover """Deprecated.""" @@ -451,42 +494,62 @@ def search( **query_params, ) - def get_dos_from_task_id(self, task_id: str) -> CompleteDos: + def get_dos_from_task_id( + self, task_id: str, run_type: str | RunType | None = None + ) -> Dos: """Get the density of states pymatgen object associated with a given calculation ID. Arguments: task_id (str): Task ID for the density of states calculation + run_type (str, RunType, or None): Optional run type to query by. + Will speed up query due to delta table partitioning. 
Returns: - bandstructure (CompleteDos): CompleteDos object + pymatgen Dos + """ + dos_lbl, dos_tbl = self._get_delta_table( + "materialsproject-parsed", + "core/electronic-structure/total-dos/", + label="total_dos", + ) + + query = f""" + SELECT * + FROM {dos_lbl} + WHERE identifier='{str(AlphaID(task_id.split("-")[-1],padlen=8))}' """ - return self._query_open_data( # type: ignore[call-overload] - bucket="materialsproject-parsed", - key=f"dos/{validate_ids([task_id])[0]}.json.gz", - decoder=lambda x: load_json(x, deser=True), - )[0][0]["data"] - def get_dos_from_material_id(self, material_id: str): + if run_type: + rt = RunType(run_type) if isinstance(run_type, str) else run_type + query += f"\nAND run_type='{rt.value}'" + + table = self._query_delta_single(query) + if len(deser := table.to_pylist(maps_as_pydicts="strict")) > 0: + return ElectronicDos(**deser[0]).to_pmg() + raise MPRestError( + f"No DOS data found for {task_id=}" + (f"run_type={rt}" if run_type else "") + ) + + def get_dos_from_material_id(self, material_id: str) -> Dos: """Get the complete density of states pymatgen object associated with a Materials Project ID. Arguments: material_id (str): Materials Project ID for a material Returns: - dos (CompleteDos): CompleteDos object + pymatgen Dos """ if not ( dos_doc := self.es_rester.search(material_ids=material_id, fields=["dos"]) ): - return None + raise MPRestError( + f"No electronic structure data found for material ID {material_id}." 
+ ) if not (dos_data := dos_doc[0].get("dos")): raise MPRestError(f"No density of states data found for {material_id}") dos_task_id = (dos_data.model_dump() if self.use_document_model else dos_data)[ - "total" - ]["1"]["task_id"] - if dos_obj := self.get_dos_from_task_id(dos_task_id): - return dos_obj - - raise MPRestError("No density of states object found.") + "task_id" + ] + return self.get_dos_from_task_id(dos_task_id) diff --git a/mp_api/client/routes/materials/eos.py b/mp_api/client/routes/materials/eos.py index 0182eb6f..2300db94 100644 --- a/mp_api/client/routes/materials/eos.py +++ b/mp_api/client/routes/materials/eos.py @@ -1,32 +1,34 @@ from __future__ import annotations +import warnings from collections import defaultdict from emmet.core.eos import EOSDoc -from mp_api.client.core import BaseRester +from mp_api.client.core import BaseRester, MPRestError, MPRestWarning from mp_api.client.core.utils import validate_ids class EOSRester(BaseRester): suffix = "materials/eos" document_model = EOSDoc # type: ignore - primary_key = "material_id" + primary_key = "task_id" def search( self, - material_ids: str | list[str] | None = None, + task_ids: str | list[str] | None = None, energies: tuple[float, float] | None = None, volumes: tuple[float, float] | None = None, num_chunks: int | None = None, chunk_size: int = 1000, all_fields: bool = True, fields: list[str] | None = None, + **kwargs, ) -> list[EOSDoc] | list[dict]: """Query equations of state docs using a variety of search criteria. Arguments: - material_ids (str, List[str]): Search for equation of states associated with the specified Material IDs + task_ids (str, List[str]): Search for equation of states associated with the specified task IDs energies (Tuple[float,float]): Minimum and maximum energy in eV/atom to consider for EOS plot range. volumes (Tuple[float,float]): Minimum and maximum volume in A³/atom to consider for EOS plot range. num_chunks (int): Maximum number of chunks of data to yield. 
None will yield all possible. @@ -34,17 +36,31 @@ def search( all_fields (bool): Whether to return all fields in the document. Defaults to True. fields (List[str]): List of fields in EOSDoc to return data for. Default is material_id only if all_fields is False. + **kwargs : used for handling deprecated kwargs Returns: ([EOSDoc], [dict]) List of equations of state docs or dictionaries. """ query_params: dict = defaultdict(dict) - if material_ids: - if isinstance(material_ids, str): - material_ids = [material_ids] + if "material_ids" in kwargs: + if task_ids: + raise MPRestError( + "You have specified both `task_ids` and the deprecated `material_ids` tag. " + "Please specify only `task_ids`." + ) + task_ids = kwargs.pop("material_ids") + warnings.warn( + "`material_id` has been replaced by `task_id` in the EOS endpoint. " + "Please migrate to using the newer field name.", + stacklevel=2, + category=MPRestWarning, + ) - query_params.update({"material_ids": ",".join(validate_ids(material_ids))}) + if task_ids: + query_params["material_ids"] = ",".join( + validate_ids([task_ids] if isinstance(task_ids, str) else task_ids) + ) if volumes: query_params.update({"volumes_min": volumes[0], "volumes_max": volumes[1]}) diff --git a/mp_api/client/routes/materials/grain_boundaries.py b/mp_api/client/routes/materials/grain_boundaries.py index 6949b9de..d9ac75c3 100644 --- a/mp_api/client/routes/materials/grain_boundaries.py +++ b/mp_api/client/routes/materials/grain_boundaries.py @@ -12,6 +12,7 @@ class GrainBoundaryRester(BaseRester): suffix = "materials/grain_boundaries" document_model = GrainBoundaryDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/phonon.py b/mp_api/client/routes/materials/phonon.py index 0373cd0d..c3d9db45 100644 --- a/mp_api/client/routes/materials/phonon.py +++ b/mp_api/client/routes/materials/phonon.py @@ -18,6 +18,7 @@ class PhononRester(BaseRester): suffix = 
"materials/phonon" document_model = PhononBSDOSDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/similarity.py b/mp_api/client/routes/materials/similarity.py index aa6cab71..0ba8c5b7 100644 --- a/mp_api/client/routes/materials/similarity.py +++ b/mp_api/client/routes/materials/similarity.py @@ -26,6 +26,7 @@ class SimilarityRester(BaseRester): suffix = "materials/similarity" document_model = SimilarityDoc # type: ignore primary_key = "material_id" + delta_backed = False _fingerprinter: SimilarityScorer | None = None diff --git a/mp_api/client/routes/materials/substrates.py b/mp_api/client/routes/materials/substrates.py index 62eaa676..6f1096b1 100644 --- a/mp_api/client/routes/materials/substrates.py +++ b/mp_api/client/routes/materials/substrates.py @@ -11,6 +11,7 @@ class SubstratesRester(BaseRester): suffix = "materials/substrates" document_model = SubstratesDoc # type: ignore primary_key = "film_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/surface_properties.py b/mp_api/client/routes/materials/surface_properties.py index 76d9e60c..3a36d5f9 100644 --- a/mp_api/client/routes/materials/surface_properties.py +++ b/mp_api/client/routes/materials/surface_properties.py @@ -12,6 +12,7 @@ class SurfacePropertiesRester(BaseRester): suffix = "materials/surface_properties" document_model = SurfacePropDoc # type: ignore primary_key = "material_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/synthesis.py b/mp_api/client/routes/materials/synthesis.py index 6788814c..4567c51f 100644 --- a/mp_api/client/routes/materials/synthesis.py +++ b/mp_api/client/routes/materials/synthesis.py @@ -12,6 +12,7 @@ class SynthesisRester(BaseRester): suffix = "materials/synthesis" document_model = SynthesisSearchResultModel # type: ignore + delta_backed = False def search( self, diff --git a/mp_api/client/routes/materials/tasks.py 
b/mp_api/client/routes/materials/tasks.py index 66a03758..c55aeb85 100644 --- a/mp_api/client/routes/materials/tasks.py +++ b/mp_api/client/routes/materials/tasks.py @@ -3,8 +3,6 @@ from datetime import datetime from typing import TYPE_CHECKING -import pyarrow as pa -from deltalake import DeltaTable, QueryBuilder from emmet.core.mpid import MPID, AlphaID from emmet.core.tasks import CoreTaskDoc from emmet.core.trajectory import RelaxTrajectory @@ -40,23 +38,23 @@ def get_trajectory( dict representing emmet.core.trajectory.RelaxTrajectory """ as_alpha = str(AlphaID(task_id, padlen=8)).split("-")[-1] - predicate = ( - f"WHERE run_type='{str(run_type)}' AND identifier='{as_alpha}'" - if run_type - else f"WHERE identifier='{as_alpha}'" - ) + predicate = ( + f"WHERE run_type='{str(run_type)}' AND " if run_type else "WHERE " + ) + f"identifier='{as_alpha}'" - traj_tbl = DeltaTable( - "s3a://materialsproject-parsed/core/trajectories/", - storage_options={"AWS_SKIP_SIGNATURE": "true", "AWS_REGION": "us-east-1"}, + traj_lbl, traj_tbl = self._get_delta_table( + "materialsproject-parsed", + "core/trajectories/", + label="traj", ) - traj_data = pa.table(QueryBuilder().register("traj", traj_tbl).execute(f""" - SELECT * - FROM traj - {predicate}; - """).read_all()).to_pylist(maps_as_pydicts="strict") + query = f""" + SELECT * + FROM {traj_lbl} + {predicate}; + """ + + traj_data = self._query_delta_single(query).to_pylist(maps_as_pydicts="strict") if not traj_data: raise MPRestError(f"No trajectory data for {task_id} found") diff --git a/mp_api/client/routes/materials/thermo.py b/mp_api/client/routes/materials/thermo.py index a2088a7f..9f2f82b5 100644 --- a/mp_api/client/routes/materials/thermo.py +++ b/mp_api/client/routes/materials/thermo.py @@ -5,12 +5,13 @@ import numpy as np from emmet.core.thermo import ThermoDoc from emmet.core.types.enums import ThermoType +from emmet.core.types.pymatgen_types.phase_diagram_adapter import PhaseDiagramType +from pydantic import TypeAdapter from
pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.core import Element -from pymatgen.core import __version__ as __pmg_version__ from mp_api.client.core import BaseRester -from mp_api.client.core.utils import load_json, validate_ids +from mp_api.client.core.utils import validate_ids class ThermoRester(BaseRester): @@ -164,24 +165,26 @@ def get_phase_diagram_from_chemsys( ) sorted_chemsys = "-".join(sorted(chemsys.split("-"))) - phdiag_id = f"thermo_type={t_type}/chemsys={sorted_chemsys}" - version = self.db_version.replace(".", "-") - obj_key = f"objects/{version}/phase-diagrams/{phdiag_id}.jsonl.gz" - pd_dct = self._query_open_data( # type: ignore[union-attr] - bucket="materialsproject-build", - key=obj_key, - decoder=lambda x: load_json(x, deser=False), - )[0][0].get("phase_diagram") - - pd = PhaseDiagram.from_dict( - { # type: ignore[arg-type] - k: v if k != "elements" else [e.get("element", e) for e in v] - for k, v in pd_dct.items() # type: ignore[union-attr] - } # post pymatgen/-core split, different serialization behavior - if int(__pmg_version__.split(".", 1)[0]) >= 2026 - else pd_dct # pymatgen<=2025.10.7 + version = "2026-04-13" # self.db_version.replace(".", "-") + + pd_lbl, pd_tbl = self._get_delta_table( + "materialsproject-build", "objects/phase-diagrams", label="phase_diagrams" ) + query = f""" + SELECT phase_diagram + FROM {pd_lbl} + WHERE chemsys='{sorted_chemsys}' + AND version='{version}' + AND thermo_type='{thermo_type}' + """ + table = self._query_delta_single(query) + as_py = table["phase_diagram"].to_pylist(maps_as_pydicts="strict") + + pd: PhaseDiagram | None = None + if len(pds := TypeAdapter(list[PhaseDiagramType]).validate_python(as_py)) > 0: + pd = pds[0] + # Ensure el_ref keys are Element objects for PDPlotter. 
# Ensure qhull_data is a numpy array # This should be fixed in pymatgen diff --git a/mp_api/client/routes/materials/xas.py b/mp_api/client/routes/materials/xas.py index a4f164f8..8accb853 100644 --- a/mp_api/client/routes/materials/xas.py +++ b/mp_api/client/routes/materials/xas.py @@ -17,6 +17,7 @@ class XASRester(BaseRester): suffix = "materials/xas" document_model = XASDoc # type: ignore primary_key = "spectrum_id" + delta_backed = False def search( self, diff --git a/mp_api/client/routes/molecules/jcesr.py b/mp_api/client/routes/molecules/jcesr.py index 2d462c19..24d3f5e6 100644 --- a/mp_api/client/routes/molecules/jcesr.py +++ b/mp_api/client/routes/molecules/jcesr.py @@ -15,6 +15,7 @@ class JcesrMoleculesRester(BaseRester): suffix = "molecules/jcesr" document_model = MoleculesDoc # type: ignore primary_key = "task_id" + delta_backed = False def __init__(self, **kwargs): """Throw deprecation warning when JCESR client is initialized.""" diff --git a/mp_api/client/routes/molecules/molecules.py b/mp_api/client/routes/molecules/molecules.py index b7600328..3171b55c 100644 --- a/mp_api/client/routes/molecules/molecules.py +++ b/mp_api/client/routes/molecules/molecules.py @@ -20,3 +20,4 @@ class MoleculeRester(CoreRester): primary_key = "molecule_id" suffix = "molecules/core" _sub_resters = MOLECULES_RESTERS + delta_backed = False diff --git a/mp_api/client/routes/molecules/summary.py b/mp_api/client/routes/molecules/summary.py index 4be3aab5..2f91677e 100644 --- a/mp_api/client/routes/molecules/summary.py +++ b/mp_api/client/routes/molecules/summary.py @@ -12,6 +12,7 @@ class MoleculesSummaryRester(BaseRester): suffix = "molecules/summary" document_model = MoleculeSummaryDoc # type: ignore primary_key = "molecule_id" + delta_backed = False def search( self, diff --git a/pyproject.toml b/pyproject.toml index 45d7a813..3742f3bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "typing-extensions>=3.7.4.1", "requests>=2.23.0", 
"monty>=2024.12.10", - "emmet-core>=0.86.4rc1,<0.86.5", + "emmet-core>=0.87.0.dev,<0.87.2", "boto3", "orjson >= 3.10,<4", "pyarrow >= 20.0.0", @@ -37,7 +37,7 @@ mcp = ["fastmcp"] server = ["flask"] all = [ "custodian", - "emmet-core[all]>=0.86.4rc1,<0.86.5", + "emmet-core[all]>=0.87.0.dev,<0.87.2", "fastmcp", "flask", "mpcontribs-client>=5.10", diff --git a/tests/client/materials/test_electronic_structure.py b/tests/client/materials/test_electronic_structure.py index a89cc730..6b4b37c6 100644 --- a/tests/client/materials/test_electronic_structure.py +++ b/tests/client/materials/test_electronic_structure.py @@ -104,7 +104,7 @@ def test_bs_client(): with pytest.raises(MPRestError, match="No electronic structure data found."): _ = bs_rester.get_bandstructure_from_material_id("mp-0") - with pytest.raises(MPRestError, match="No object found"): + with pytest.raises(MPRestError, match="No bandstructure data found"): _ = bs_rester.get_bandstructure_from_task_id("mp-0") diff --git a/tests/client/materials/test_eos.py b/tests/client/materials/test_eos.py index 3e633e49..e71fc010 100644 --- a/tests/client/materials/test_eos.py +++ b/tests/client/materials/test_eos.py @@ -4,6 +4,7 @@ from mp_api._test_utils import client_search_testing, requires_api_key +from mp_api.client.core.exceptions import MPRestError, MPRestWarning from mp_api.client.routes.materials.eos import EOSRester @@ -26,9 +27,9 @@ def rester(): sub_doc_fields: list = [] -alt_name_dict: dict = {"material_ids": "material_id"} +alt_name_dict: dict = {"task_ids": "task_id"} -custom_field_tests: dict = {"material_ids": ["mp-149"]} +custom_field_tests: dict = {"task_ids": ["mp-149"]} @requires_api_key @@ -42,3 +43,15 @@ def test_client(rester): custom_field_tests=custom_field_tests, sub_doc_fields=sub_doc_fields, ) + + +@requires_api_key +def test_warnings_errors(rester): + + with pytest.warns( + MPRestWarning, match="`material_id` has been replaced by `task_id`" + ): + rester.search(material_ids=["mp-149"], 
num_chunks=1, chunk_size=1) + + with pytest.raises(MPRestError, match="You have specified both"): + rester.search(material_ids=["mp-149"], task_ids=["mp-1"])