#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import os
import shutil
import subprocess
import tempfile
import typing
from subprocess import CalledProcessError
from typing import List
import importlib_metadata
import torch
from packaging import version
from torch.optim.lr_scheduler import (  # type: ignore
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    SequentialLR,
    StepLR,
    _LRScheduler,
)
from torch.utils.cpp_extension import IS_WINDOWS
from torch.utils.data.dataset import (
    ChainDataset,
    ConcatDataset,
    Dataset,
    IterableDataset,
    Subset,
    TensorDataset,
)

sane_datasets = [
    Dataset,
    ChainDataset,
    ConcatDataset,
    IterableDataset,
    Subset,
    TensorDataset,
]

torch_version = version.parse(importlib_metadata.version("torch"))
if torch_version < version.parse("1.12") and torch_version >= version.parse("1.10"):
    from torch.utils.data.dataset import DFIterDataPipe  # type: ignore
    sane_datasets.append(DFIterDataPipe)
if torch_version < version.parse("1.12") and torch_version >= version.parse("1.11"):
    from torch.utils.data.dataset import IterDataPipe, MapDataPipe  # type: ignore
    sane_datasets.extend([IterDataPipe, MapDataPipe])
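
# NOTE: with the gates above, `sane_datasets` has 6 entries on torch >= 1.12,
# 7 on 1.10.x and 9 on 1.11.x; `fix_datasets` below relies on these
# positional indices.
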
sane_schedulers = [
    _LRScheduler,
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiStepLR,
    MultiplicativeLR,
    OneCycleLR,
    SequentialLR,
    StepLR,
]


@typing.no_type_check
def fix_datasets(sane_datasets: List[Dataset]) -> None:
    """
    Helper function to revert TorchDrug dataset handling (which breaks core
    pytorch functionalities). For details see:
    https://github.com/DeepGraphLearning/torchdrug/issues/96

    Args:
        sane_datasets: A list of pytorch datasets.

    Raises:
        AttributeError: If a passed dataset was not sane.
    """
    dataset = sane_datasets[0]
    torch.utils.data.dataset.Dataset = dataset  # type: ignore
    torch.utils.data.dataset.ChainDataset = sane_datasets[1]  # type: ignore
    torch.utils.data.dataset.ConcatDataset = sane_datasets[2]  # type: ignore
    torch.utils.data.dataset.IterableDataset = sane_datasets[3]  # type: ignore
    torch.utils.data.dataset.Subset = sane_datasets[4]  # type: ignore
    torch.utils.data.dataset.TensorDataset = sane_datasets[5]  # type: ignore
    if torch_version < version.parse("1.12") and torch_version >= version.parse("1.10"):
        torch.utils.data.dataset.DFIterDataPipe = sane_datasets[6]  # type: ignore
    if torch_version < version.parse("1.12") and torch_version >= version.parse("1.11"):
        torch.utils.data.dataset.IterDataPipe = sane_datasets[7]  # type: ignore
        torch.utils.data.dataset.MapDataPipe = sane_datasets[8]  # type: ignore
    for ds in sane_datasets[1:]:
        if not issubclass(ds, dataset):
            raise AttributeError(
                f"Reverting silent TorchDrug overwriting failed, {ds} is not a subclass"
                f" of {dataset}."
            )
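

# Minimal usage sketch (hypothetical; assumes TorchDrug is installed).
# Importing torchdrug silently replaces the classes captured in
# `sane_datasets` above; calling the helper afterwards restores the vanilla
# pytorch classes:
#
#   import torchdrug  # noqa: F401  # triggers the silent overwriting
#   fix_datasets(sane_datasets)
#   assert torch.utils.data.dataset.Dataset is sane_datasets[0]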


@typing.no_type_check
def fix_schedulers(sane_schedulers: List[_LRScheduler]) -> None:
    """
    Helper function to revert TorchDrug LR scheduler handling (which breaks core
    pytorch functionalities). For details see:
    https://github.com/DeepGraphLearning/torchdrug/issues/96

    Args:
        sane_schedulers: A list of pytorch lr_schedulers.

    Raises:
        AttributeError: If a passed lr_scheduler was not sane.
    """
    scheduler = sane_schedulers[0]
    torch.optim.lr_scheduler._LRScheduler = scheduler  # type: ignore
    torch.optim.lr_scheduler.ChainedScheduler = sane_schedulers[1]  # type: ignore
    torch.optim.lr_scheduler.ConstantLR = sane_schedulers[2]  # type: ignore
    torch.optim.lr_scheduler.CosineAnnealingLR = sane_schedulers[3]  # type: ignore
    torch.optim.lr_scheduler.CosineAnnealingWarmRestarts = sane_schedulers[4]  # type: ignore
    torch.optim.lr_scheduler.CyclicLR = sane_schedulers[5]  # type: ignore
    torch.optim.lr_scheduler.ExponentialLR = sane_schedulers[6]  # type: ignore
    torch.optim.lr_scheduler.LambdaLR = sane_schedulers[7]  # type: ignore
    torch.optim.lr_scheduler.LinearLR = sane_schedulers[8]  # type: ignore
    torch.optim.lr_scheduler.MultiStepLR = sane_schedulers[9]  # type: ignore
    torch.optim.lr_scheduler.MultiplicativeLR = sane_schedulers[10]  # type: ignore
    torch.optim.lr_scheduler.OneCycleLR = sane_schedulers[11]  # type: ignore
    torch.optim.lr_scheduler.SequentialLR = sane_schedulers[12]  # type: ignore
    torch.optim.lr_scheduler.StepLR = sane_schedulers[13]  # type: ignore
    for lrs in sane_schedulers[1:]:
        if not issubclass(lrs, scheduler):
            raise AttributeError(
                f"Reverting silent TorchDrug overwriting failed, {lrs} is not a subclass"
                f" of {scheduler}."
            )
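

# The scheduler helper is used the same way (hypothetical sketch):
#
#   import torchdrug  # noqa: F401
#   fix_schedulers(sane_schedulers)
#   assert torch.optim.lr_scheduler.StepLR is sane_schedulers[13]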


# Trivial C++ program compiled to probe whether the compiler accepts -fopenmp.
CHECK_CODE = "int main(){return 0;}"


def check_openmp_availability() -> bool:
    """
    Check if OpenMP is available at runtime.

    Returns:
        True if OpenMP is available, False otherwise.
    """
    # NOTE: the -fopenmp / -o flags used below assume a GCC/Clang-style
    # driver; MSVC's cl (the Windows default here) spells the flag /openmp,
    # so on Windows this check may wrongly report OpenMP as unavailable.
    if IS_WINDOWS:
        compiler = os.environ.get("CXX", "cl")
    else:
        compiler = os.environ.get("CXX", "c++")
    tempfolder = tempfile.mkdtemp()
    with open(os.path.join(tempfolder, "main.cpp"), "w") as f:
        f.write(CHECK_CODE)
    is_openmp_available = True
    try:
        subprocess.check_call(
            [
                compiler,
                "-fopenmp",
                f"{tempfolder}/main.cpp",
                "-o",
                f"{tempfolder}/main.o",
            ]
        )
    except (CalledProcessError, OSError):
        # compilation failed, or no compiler was found on the PATH
        is_openmp_available = False
    finally:
        shutil.rmtree(tempfolder)
    return is_openmp_available


TORCH_HAS_OPENMP = torch._C.has_openmp
# NOTE: overwrite the flag so that it reflects OpenMP availability with the
# compiler in the current environment; the build-time value is kept above in
# TORCH_HAS_OPENMP.
torch._C.has_openmp = check_openmp_availability()
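
# Restore sketch (hypothetical usage, not part of the module's behavior):
# since TORCH_HAS_OPENMP preserves the value torch was built with, callers
# can undo the override once it is no longer needed:
#
#   torch._C.has_openmp = TORCH_HAS_OPENMP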