# Source code for pycmtensor.config

# config.py
"""PyCMTensor config module"""
import configparser
import glob
import multiprocessing
import os
import subprocess
import sys

import numpy as np
from watermark import watermark

from .logger import info, warning
from .scheduler import *

# Names exported by a star-import of this module: only ``Config``.
__all__ = ["Config"]


def about():
    """Return the result of ``watermark`` for this environment.

    ``watermark`` is called with the ``python``, ``datename``, ``updated``
    and ``machine`` options enabled and with the package list
    ``pycmtensor,aesara,numpy,scipy``.
    """
    report_options = {
        "python": True,
        "datename": True,
        "updated": True,
        "packages": "pycmtensor,aesara,numpy,scipy",
        "machine": True,
    }
    return watermark(**report_options)


def generate_blas_flags():
    """Finds and generates the blas_flags config option for .aesararc

    The library directory is ``$CONDA_PREFIX/Library/bin`` on Windows and
    ``$CONDA_PREFIX/lib`` elsewhere. It is searched for files whose name
    contains ``mkl_rt``.

    Returns:
        list[str]: ``["-L<ld_dir>"]`` followed by a single ``-l<name>`` flag
        when an ``mkl_rt`` library is present. When none is found only the
        ``-L`` flag is returned.

    Raises:
        NameError: if ``CONDA_PREFIX`` is not set in the environment.
    """
    conda_prefix = os.environ.get("CONDA_PREFIX")
    if conda_prefix is None:
        raise NameError("CONDA_PREFIX not found in environment variables")

    is_windows = sys.platform == "win32"
    if is_windows:
        ld_dir = os.path.join(conda_prefix, "Library", "bin")
    else:
        ld_dir = os.path.join(conda_prefix, "lib")

    flags = [f"-L{ld_dir}"]

    # glob returns matches in arbitrary order; sort so that the library that
    # ends up on the link line is the same on every run.
    mkl_rt_bins = sorted(glob.glob(os.path.join(ld_dir, "*mkl_rt*")))
    if not mkl_rt_bins:
        # No MKL runtime found: return the search path only instead of
        # failing on an unassigned variable.
        return flags

    if is_windows:
        # Link against the DLL's base name, e.g. ``mkl_rt.2.dll`` -> ``mkl_rt.2``.
        lib_name = os.path.basename(mkl_rt_bins[-1])
        if lib_name.endswith(".dll"):
            lib_name = lib_name[:-4]
    else:
        # Every match maps to the plain ``mkl_rt`` name on non-Windows platforms.
        lib_name = "mkl_rt"

    flags.append(f"-l{lib_name}")
    return flags


def generate_cxx_flags_macos():
    """Build the gcc__cxxflags include option from the Xcode SDK path.

    Runs ``xcrun --show-sdk-path`` and appends ``usr/include`` to its output.

    Returns:
        list[str]: a single-element list holding the ``-I<path>`` flag.
    """
    include_dir = os.path.join(
        subprocess.getoutput("xcrun --show-sdk-path"), "usr", "include"
    )
    return ["-I" + include_dir]


def init_aesara_rc():
    """Write a fresh ``.aesararc`` into the user's home directory.

    The file gets three sections: ``global`` (fixed options), ``blas``
    (``ldflags`` from ``generate_blas_flags``) and ``gcc`` (``cxxflags`` from
    ``generate_cxx_flags_macos`` on macOS, empty otherwise). Any existing
    file at that path is overwritten.

    Returns:
        configparser.ConfigParser: the parser holding the written settings.
    """
    rc_path = os.path.join(os.path.expanduser("~"), ".aesararc")

    blas_ldflags = " ".join(map(str, generate_blas_flags()))

    if sys.platform == "darwin":
        gcc_cxxflags = " ".join(map(str, generate_cxx_flags_macos()))
    else:
        # No extra compiler flags are set on other platforms.
        gcc_cxxflags = ""

    settings = {
        "global": {
            "device": "cpu",
            "floatX": "float64",
            "allow_gc": "False",
            "on_unused_input": "ignore",
            "openmp": "True",
            "optimizer": "fast_compile",
            "cycle_detection": "fast",
            "optimizer_including": "local_remove_all_assert",
        },
        "blas": {"ldflags": blas_ldflags},
        "gcc": {"cxxflags": gcc_cxxflags},
    }

    aesara_rc = configparser.ConfigParser()
    aesara_rc.read_dict(settings)

    with open(rc_path, "w") as rc_file:
        aesara_rc.write(rc_file)

    return aesara_rc


def init_environment_variables():
    """Export the CPU count as the MKL and OpenMP thread-count variables.

    Sets ``MKL_NUM_THREADS`` and ``OMP_NUM_THREADS`` in ``os.environ`` to the
    value reported by ``multiprocessing.cpu_count()``.
    """
    thread_count = str(multiprocessing.cpu_count())
    for variable in ("MKL_NUM_THREADS", "OMP_NUM_THREADS"):
        os.environ[variable] = thread_count


class Config:
    """Class object to store config and hyperparameters

    Hyperparameters live in the ``hyperparameters`` dict and can be read and
    written with item access (``config["seed"]``). Calling the instance
    returns that dict.
    """

    def __init__(self):
        """Create the default configuration.

        Besides filling in the default hyperparameters, this writes the
        ``.aesararc`` file via ``init_aesara_rc`` and sets the thread-count
        environment variables via ``init_environment_variables``.
        """
        self.rng = np.random.default_rng()
        self.info = {
            "python_version": sys.version,
            "directory": os.getcwd(),
        }
        self.hyperparameters = {
            "seed": int(self.rng.integers(1, 9000)),
            "patience": 4000,
            "patience_increase": 2,
            "validation_threshold": 1.005,
            "base_learning_rate": 0.01,
            "max_learning_rate": None,
            "batch_size": 250,
            "max_steps": 1000,
            "batch_shuffle": False,
            "lr_scheduler": ConstantLR(),
            "clr_gamma": None,
            "clr_cycle_steps": 16,
        }
        self.aesara_rc = init_aesara_rc()
        init_environment_variables()

    def __str__(self):
        """Return the hyperparameters as an aligned ``name: value`` table."""
        header = "\nhyperparameters\n---------------\n"
        rows = [f"{key:<24}: {val}\n" for key, val in self.hyperparameters.items()]
        return header + "".join(rows)

    def __print__(self):
        """Return the same table as ``str(self)``.

        ``__print__`` is not a Python special method, so it is never called
        by ``print``; it is kept for callers that invoke it explicitly.
        """
        return self.__str__()

    def __getitem__(self, name: str):
        """Return the hyperparameter ``name``, or ``None`` if it is unknown."""
        return self.hyperparameters.get(name)

    def __setitem__(self, name: str, val):
        """Set an existing hyperparameter.

        Raises:
            NameError: if ``name`` is not a key of ``hyperparameters``.
        """
        if name not in self.hyperparameters:
            raise NameError(f"hyperparameter {name} is not a valid option")
        self.hyperparameters[name] = val

    def __call__(self):
        """Return the ``hyperparameters`` dict itself (not a copy)."""
        return self.hyperparameters

    def set_hyperparameter(self, **kwargs):
        """Set one or more hyperparameters from keyword arguments.

        Unknown names are skipped with a warning; each accepted value is
        logged. ``check_values`` is run once all arguments are processed.
        """
        for key, value in kwargs.items():
            if key not in self.hyperparameters:
                warning(f"{key} not a valid training hyperparameter, skipping.")
                continue
            self.hyperparameters[key] = value
            info(f"Hyperparameter set {key}={value}")
        self.check_values()

    def set_lr_scheduler(self, scheduler):
        """Sets the config options ``lr_scheduler`` and ``base_learning_rate``

        ``max_learning_rate``, ``clr_cycle_steps`` and ``clr_gamma`` are
        copied from the scheduler's ``max_lr``, ``cycle_steps`` and ``gamma``
        attributes, or reset to ``None`` when the scheduler lacks them.

        Raises:
            TypeError: if ``scheduler`` is not a ``Scheduler`` instance or
                has no ``lr`` attribute.
        """
        if not isinstance(scheduler, Scheduler):
            raise TypeError(
                f"{type(scheduler)} is not a {Scheduler} instance, perhaps missing arguments?"
            )
        if not hasattr(scheduler, "lr"):
            raise TypeError(f"{scheduler} is not a valid learning rate scheduler")

        self["lr_scheduler"] = scheduler
        self["base_learning_rate"] = scheduler.lr
        self["max_learning_rate"] = getattr(scheduler, "max_lr", None)
        self["clr_cycle_steps"] = getattr(scheduler, "cycle_steps", None)
        self["clr_gamma"] = getattr(scheduler, "gamma", None)

    def check_values(self):
        """Checks validity of hyperparameter values

        Raises:
            AssertionError: if ``seed`` is not an integer, ``batch_shuffle``
                is not a bool, or ``clr_gamma`` is neither ``None`` nor a
                float.
        """
        # np.integer covers every NumPy integer width, not only np.int64.
        assert isinstance(
            self["seed"], (int, np.integer)
        ), "seed must be an integer"
        assert isinstance(self["batch_shuffle"], bool), "batch_shuffle must be a bool"
        gamma = self["clr_gamma"]
        assert gamma is None or isinstance(
            gamma, float
        ), "clr_gamma must be None or a float"