# Source code for matbench_genmetrics.mp_time_split.splitter

"""Core functionality for Materials Project time-based train/test splitting"""

import argparse
import logging
import sys
from hashlib import md5
from os import environ, path
from pathlib import Path
from shutil import move
from typing import List, Optional, Tuple, Union
from urllib.request import urlretrieve

import pandas as pd
import pybtex.errors
from matminer.utils.io import load_dataframe_from_json
from typing_extensions import Literal

from matbench_genmetrics.mp_time_split import __version__
from matbench_genmetrics.mp_time_split.utils.data import (
    DUMMY_SNAPSHOT_NAME,
    SNAPSHOT_NAME,
)
from matbench_genmetrics.mp_time_split.utils.split import (
    AVAILABLE_MODES,
    mp_time_splitter,
)

# NOTE(review): switches pybtex out of strict mode as an import-time side
# effect; presumably so that bibliography-parsing problems in upstream dataset
# metadata do not raise -- confirm against the pybtex documentation.
pybtex.errors.set_strict_mode(False)

__author__ = "sgbaird"
__copyright__ = "sgbaird"
__license__ = "MIT"
# Module-level logger; used by the CLI entry point `main`.
_logger = logging.getLogger(__name__)


# Indices of the cross-validation folds. `len(FOLDS)` is passed to
# `mp_time_splitter` as `n_cv_splits`, and `get_train_and_val_data` only
# accepts a fold taken from this list.
FOLDS = [0, 1, 2, 3, 4]
# Expected MD5 hex digests of the default snapshot files downloaded by
# `MPTimeSplit.load`: the dummy (testing) snapshot and the full snapshot.
dummy_checksum_frozen = "6bf42266bd71477a06b24153d4ff7889"
full_checksum_frozen = "57da7fa4d96ffbbc0dd359b1b7423f31"


def get_data_home(data_home=None):
    """Resolve the directory in which dataset snapshots are looked up.

    Resolution order when ``data_home`` is None: the ``MP_TIME_DATA``
    environment variable, then the ``utils`` folder located next to this
    module. A leading ``~`` in the result is expanded to the user's home
    directory. The directory itself is not created here.

    Modified from source:
    https://github.com/hackingmaterials/matminer/blob/76a529b769055c729d62f11a419d319d8e2f838e/matminer/datasets/utils.py#L26-L43 # noqa:E501

    Args:
        data_home (str): folder to look in, if None a default is selected

    Returns (str)
    """
    # An explicit argument always wins over the environment and the default.
    if data_home is not None:
        return path.expanduser(data_home)

    package_default = path.join(path.dirname(path.abspath(__file__)), "utils")
    return path.expanduser(environ.get("MP_TIME_DATA", package_default))
class MPTimeSplit:
    """Time-based train/validation/test splitting of Materials Project data."""

    def __init__(
        self,
        num_sites: Optional[Tuple[int, int]] = None,
        elements: Optional[List[str]] = None,
        exclude_elements: Optional[
            Union[List[str], Literal["noble", "radioactive", "noble+radioactive"]]
        ] = None,
        use_theoretical: bool = False,
        mode: str = "TimeSeriesSplit",
        target: str = "energy_above_hull",
        save_dir=None,
    ) -> None:
        """Top-level class for Materials Project time-based train/test splitting.

        Parameters
        ----------
        num_sites : Optional[Tuple[int, int]], optional
            Range of number of sites fetched from MP. If None, then no
            restrictions on sites, by default None
        elements : Optional[List[str]], optional
            Allowed elements for data fetching from MP. If None, no
            restrictions, by default None
        exclude_elements : Optional[Union[List[str], Literal["noble", "radioactive", "noble+radioactive"]]], optional
            Elements to be excluded. Options are "noble" (noble gases),
            "radioactive" (radioactive elements), and "noble+radioactive" (both
            noble gases and radioactive elements), by default None
        use_theoretical : bool, optional
            Whether to include theoretical compounds from MP, by default False
        mode : str, optional
            The splitter type; must be one of ``AVAILABLE_MODES``,
            by default "TimeSeriesSplit"
        target : str, optional
            A property to also include in the DataFrame, by default
            "energy_above_hull"
        save_dir : str, optional
            The directory to save the data to. If None, then uses
            `get_data_home()`. The directory is created if it does not exist.
            By default None.

        Raises
        ------
        NotImplementedError
            mode={mode} not implemented. Use one of {AVAILABLE_MODES}

        Examples
        --------
        >>> mpts = MPTimeSplit(
        ...     num_sites=None,
        ...     elements=None,
        ...     exclude_elements=None,
        ...     use_theoretical=False,
        ...     mode="TimeSeriesSplit",
        ...     target="energy_above_hull",
        ...     save_dir=None,
        ... )
        >>> data = mpts.load()  # doctest: +SKIP
        """
        if mode not in AVAILABLE_MODES:
            raise NotImplementedError(
                f"mode={mode} not implemented. Use one of {AVAILABLE_MODES}"
            )

        self.num_sites = num_sites
        self.elements = elements
        self.exclude_elements = exclude_elements
        self.use_theoretical = use_theoretical
        self.mode = mode
        self.folds = FOLDS
        self.target = target

        self.save_dir = get_data_home() if save_dir is None else save_dir
        Path(self.save_dir).mkdir(exist_ok=True, parents=True)

        # Populated by `fetch_data()` / `load()`. Initialised here so the
        # "must be run first" guards see None instead of a missing attribute.
        self.data = None
        self.inputs = None
        self.outputs = None
        self.trainval_splits = None
        self.test_split = None

    def _set_data(self, data):
        """Store ``data`` and derive the time-based splits, inputs and outputs."""
        self.data = data
        self.trainval_splits, self.test_split = mp_time_splitter(
            self.data, n_cv_splits=len(FOLDS), mode=self.mode
        )
        self.inputs = self.data.structure
        self.outputs = getattr(self.data, self.target)
        return self.data

    def _require_data(self):
        """Raise ``NameError`` unless `fetch_data()` or `load()` has populated data."""
        # getattr also covers instances that were created without __init__.
        if getattr(self, "data", None) is None:
            raise NameError("`fetch_data()` or `load()` must be run first.")

    def fetch_data(self, one_by_one=False):
        """Fetch data directly from Materials Project and split into train/test sets.

        Parameters
        ----------
        one_by_one : bool, optional
            Whether to retrieve data one-by-one instead of in bulk; forwarded
            unchanged to the API helper. By default False.

        Returns
        -------
        df : pd.DataFrame
            Dataframe of Materials Project data containing `structure` and
            `target` columns. `structure` is of type
            `pymatgen.core.structure.Structure`.

        Raises
        ------
        ImportError
            Failed to import `fetch_data()`. Try `pip install
            mp_time_split[api]` or `pip install mp-api` to install the optional
            `mp-api` dependency. Note that this requires Python >=3.8
        ValueError
            The fetched data is not a `pd.DataFrame`

        Examples
        --------
        >>> mpts = MPTimeSplit()
        >>> data = mpts.fetch_data()  # doctest: +SKIP
        """
        # Imported lazily because it depends on the optional `mp-api` package.
        try:
            from matbench_genmetrics.mp_time_split.utils.api import fetch_data
        except ImportError as e:
            raise ImportError(
                "Failed to import `fetch_data()`. Try `pip install mp_time_split[api]` or `pip install mp-api` to install the optional `mp-api` dependency. Note that this requires Python >=3.8"  # noqa: E501
            ) from e

        fetched = fetch_data(
            num_sites=self.num_sites,
            elements=self.elements,
            exclude_elements=self.exclude_elements,
            use_theoretical=self.use_theoretical,
            one_by_one=one_by_one,
        )
        # Validate before storing so a bad result never becomes `self.data`.
        if not isinstance(fetched, pd.DataFrame):
            raise ValueError("`self.data` is not a `pd.DataFrame`")

        return self._set_data(fetched)

    def load(self, url=None, checksum=None, dummy=False, force_download=False):
        """Load data from an existing snapshot, downloading it when needed.

        The snapshot is downloaded when it is not yet present in
        ``self.save_dir`` or when ``force_download`` is True. The download goes
        to a temporary file, is verified against the expected MD5 digest (when
        one is known), and is only then moved into place.

        Parameters
        ----------
        url : str, optional
            URL to download the data from. If None (and ``checksum`` is None),
            the built-in figshare URL for the dummy or full snapshot is used.
            By default None
        checksum : str, optional
            Expected MD5 hex digest of the file downloaded from ``url``. Only
            valid together with ``url``; if None, a file downloaded from a
            custom ``url`` is not verified. By default None
        dummy : bool, optional
            Whether to load a dummy snapshot or not, by default False
        force_download : bool, optional
            Whether to force download, regardless of whether the data has
            already been downloaded, by default False

        Returns
        -------
        pd.DataFrame
            DataFrame of Materials Project data containing `structure` and
            `target` columns. `structure` is of type
            `pymatgen.core.structure.Structure`.

        Raises
        ------
        ValueError
            A download is required and ``checksum`` was given without ``url``.
        ValueError
            The MD5 digest of the downloaded file does not match the expected
            checksum.

        Examples
        --------
        >>> mpts = MPTimeSplit()
        >>> data = mpts.load(dummy=True)  # doctest: +SKIP
        """
        name = (DUMMY_SNAPSHOT_NAME if dummy else SNAPSHOT_NAME) + ".gz"
        data_path = path.join(self.save_dir, name)

        if force_download or not Path(data_path).is_file():
            if url is None:
                if checksum is not None:
                    raise ValueError(
                        f"url should not be None at this point. url: {url}, type: {type(url)}"  # noqa: E501
                    )
                if dummy:
                    # dummy data from figshare for testing
                    url = "https://figshare.com/ndownloader/files/35592005"
                    expected_checksum = dummy_checksum_frozen
                else:
                    # full dataset from figshare for production
                    url = "https://figshare.com/ndownloader/files/35592011"
                    expected_checksum = full_checksum_frozen
            else:
                # Custom URL: verify against the caller's checksum, if any.
                expected_checksum = checksum

            # download to temp file in case interrupted partway
            data_path_tmp = data_path + "tmp"
            urlretrieve(url, data_path_tmp)

            # Verify before moving into place so that a corrupt download is
            # never left where a later `load()` would pick it up unchecked.
            if expected_checksum is not None:
                actual_checksum = md5(Path(data_path_tmp).read_bytes()).hexdigest()
                if actual_checksum != expected_checksum:
                    Path(data_path_tmp).unlink()
                    raise ValueError(
                        f"checksum from {url} ({actual_checksum}) does not match what was expected ({expected_checksum})"  # noqa: E501
                    )
            move(data_path_tmp, data_path)

        return self._set_data(load_dataframe_from_json(data_path))

    def get_train_and_val_data(self, fold):
        """Get training and validation data for a given fold.

        Parameters
        ----------
        fold : int
            The cross-validation fold to get the data for.

        Returns
        -------
        pd.Series, pd.Series, pd.Series, pd.Series
            Input training data, input validation data, output training data,
            and output validation data. The inputs come from the `structure`
            column and the outputs from the `target` column.

        Raises
        ------
        NameError
            `fetch_data()` or `load()` must be run first.
        ValueError
            fold={fold} should be one of {FOLDS}

        Examples
        --------
        >>> mpts = MPTimeSplit()
        >>> data = mpts.load(dummy=True)  # doctest: +SKIP
        >>> splits = mpts.get_train_and_val_data(0)  # doctest: +SKIP
        """
        self._require_data()
        if fold not in FOLDS:
            raise ValueError(f"fold={fold} should be one of {FOLDS}")

        train_idx, val_idx = self.trainval_splits[fold]
        return (
            self.inputs.iloc[train_idx],
            self.inputs.iloc[val_idx],
            self.outputs.iloc[train_idx],
            self.outputs.iloc[val_idx],
        )

    def get_final_test_data(self):
        """Get the final train/test split.

        The 'for real life' test split, i.e., what gets touched only once
        before submitting a manuscript.

        Returns
        -------
        pd.Series, pd.Series, pd.Series, pd.Series
            Input training data, input test data, output training data, and
            output test data.

        Raises
        ------
        NameError
            `fetch_data()` or `load()` must be run first.
        """
        self._require_data()

        train_idx, test_idx = self.test_split
        return (
            self.inputs.iloc[train_idx],
            self.inputs.iloc[test_idx],
            self.outputs.iloc[train_idx],
            self.outputs.iloc[test_idx],
        )
# ---- CLI ----
# The functions defined in this section are wrappers around the main Python
# API, allowing them to be called directly from the terminal as a CLI
# executable/script.
def parse_args(args):
    """Parse command line parameters

    Recognised options: ``--version``, ``-s/--save-dir`` (default ``"."``),
    ``-v/--verbose`` (sets ``loglevel`` to ``logging.INFO``) and
    ``-vv/--very-verbose`` (sets ``loglevel`` to ``logging.DEBUG``). When
    neither verbosity flag is given, ``loglevel`` is None.

    Args:
      args (List[str]): command line parameters as list of strings
          (for example  ``["--help"]``).

    Returns:
      :obj:`argparse.Namespace`: command line parameters namespace
    """
    parser = argparse.ArgumentParser(
        description="For downloading mp-time-split snapshot."
    )
    parser.add_argument(
        "--version",
        action="version",
        # Plain program name: the former leading "$" was a leftover template
        # placeholder character that was printed verbatim.
        version=f"mp_time_split {__version__}",
    )
    parser.add_argument(
        "-s",
        "--save-dir",
        dest="save_dir",
        default=".",
        help="Directory in which to save json.gz snapshot.",
        type=str,
        metavar="STRING",
    )
    # Both verbosity flags write to the same destination; the last one given
    # on the command line wins.
    parser.add_argument(
        "-v",
        "--verbose",
        dest="loglevel",
        help="set loglevel to INFO",
        action="store_const",
        const=logging.INFO,
    )
    parser.add_argument(
        "-vv",
        "--very-verbose",
        dest="loglevel",
        help="set loglevel to DEBUG",
        action="store_const",
        const=logging.DEBUG,
    )
    return parser.parse_args(args)
def setup_logging(loglevel):
    """Configure the root logger to write timestamped records to stdout.

    Args:
      loglevel (int): minimum loglevel for emitting messages
    """
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s:%(name)s:%(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        stream=sys.stdout,
        level=loglevel,
    )
def main(args):
    """Download the mp-time-split snapshot from CLI-style string arguments.

    Parses ``args``, configures logging, then creates an :class:`MPTimeSplit`
    that saves into the requested directory and calls its ``load()`` method.

    Args:
      args (List[str]): command line parameters as list of strings
          (for example  ``["--verbose", "--save-dir", "./data"]``).
    """
    options = parse_args(args)
    setup_logging(options.loglevel)

    _logger.debug("Beginning download of mp-time-split snapshot")
    splitter = MPTimeSplit(save_dir=options.save_dir)
    splitter.load()
    _logger.info(f"The snapshot is saved at {options.save_dir}")
def run():
    """Invoke :func:`main` with the arguments found in :obj:`sys.argv`.

    The program name (``sys.argv[0]``) is dropped. This function can be used
    as entry point to create console scripts with setuptools.
    """
    cli_args = sys.argv[1:]
    main(cli_args)
if __name__ == "__main__":
    # Guard statement: the CLI only runs when this file is executed as a
    # script, not when it is imported as a module.
    # https://docs.python.org/3/library/__main__.html
    #
    # After installing the project with pip, the module can also be run via
    # the ``-m`` flag, as defined in PEP 338::
    #
    #     python -m matbench_genmetrics.mp_time_split.splitter --save-dir ./data
    #
    run()