# Source code for gt4sd.training_pipelines.torchdrug.unpatch

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import os
import shutil
import subprocess
import tempfile
import typing
from subprocess import CalledProcessError
from typing import List

import importlib_metadata
import torch
from packaging import version
from torch.optim.lr_scheduler import (  # type: ignore
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    SequentialLR,
    StepLR,
    _LRScheduler,
)
from torch.utils.cpp_extension import IS_WINDOWS
from torch.utils.data.dataset import (
    ChainDataset,
    ConcatDataset,
    Dataset,
    IterableDataset,
    Subset,
    TensorDataset,
)

# References to the dataset classes exactly as they were imported at the top of
# this module. fix_datasets() writes them back onto torch.utils.data.dataset by
# list index (0-5 for the entries below), so the order here must not change.
sane_datasets = [
    Dataset,  # index 0: also used by fix_datasets() as the expected base class
    ChainDataset,
    ConcatDataset,
    IterableDataset,
    Subset,
    TensorDataset,
]

# Installed torch version, used to decide which optional dataset classes exist.
torch_version = version.parse(importlib_metadata.version("torch"))

# DFIterDataPipe is only imported for 1.10 <= torch < 1.12. It lands at index 6
# of sane_datasets, which is the position fix_datasets() reads it from.
if version.parse("1.10") <= torch_version < version.parse("1.12"):
    from torch.utils.data.dataset import DFIterDataPipe  # type: ignore

    sane_datasets += [DFIterDataPipe]

# IterDataPipe / MapDataPipe are only imported for 1.11 <= torch < 1.12. They
# land at indices 7 and 8 because the 1.10 branch above has already run.
if version.parse("1.11") <= torch_version < version.parse("1.12"):
    from torch.utils.data.dataset import IterDataPipe, MapDataPipe  # type: ignore

    sane_datasets += [IterDataPipe, MapDataPipe]


# References to the LR scheduler classes exactly as they were imported at the
# top of this module. fix_schedulers() writes them back onto
# torch.optim.lr_scheduler by list index (0-13), so the order must not change.
sane_schedulers = [
    _LRScheduler,  # index 0: also used by fix_schedulers() as the expected base class
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    MultiStepLR,
    MultiplicativeLR,
    OneCycleLR,
    SequentialLR,
    StepLR,
]


@typing.no_type_check
def fix_datasets(sane_datasets: List[Dataset]) -> None:
    """
    Helper function to revert TorchDrug dataset handling (which breaks core
    pytorch functionalities). For details see:
    https://github.com/DeepGraphLearning/torchdrug/issues/96

    The classes are written back onto ``torch.utils.data.dataset`` by position,
    and only afterwards is every entry checked against the first one.

    Args:
        sane_datasets: A list of pytorch datasets, ordered as Dataset,
            ChainDataset, ConcatDataset, IterableDataset, Subset, TensorDataset,
            followed (depending on the torch version) by DFIterDataPipe and
            IterDataPipe, MapDataPipe.

    Raises:
        AttributeError: If a passed dataset was not sane.
    """
    dataset_module = torch.utils.data.dataset
    base_dataset = sane_datasets[0]

    # Attributes that are restored regardless of the torch version, listed in
    # the positional order expected for ``sane_datasets``.
    unconditional = (
        "Dataset",
        "ChainDataset",
        "ConcatDataset",
        "IterableDataset",
        "Subset",
        "TensorDataset",
    )
    for position, attribute in enumerate(unconditional):
        setattr(dataset_module, attribute, sane_datasets[position])  # type: ignore

    # The data-pipe classes are only restored for the torch versions gated below.
    older_than_1_12 = torch_version < version.parse("1.12")
    if older_than_1_12 and torch_version >= version.parse("1.10"):
        dataset_module.DFIterDataPipe = sane_datasets[6]  # type: ignore
    if older_than_1_12 and torch_version >= version.parse("1.11"):
        dataset_module.IterDataPipe = sane_datasets[7]  # type: ignore
        dataset_module.MapDataPipe = sane_datasets[8]  # type: ignore

    # Validation happens after patching: every entry must derive from entry 0.
    for candidate in sane_datasets[1:]:
        if issubclass(candidate, base_dataset):
            continue
        raise AttributeError(
            f"Reverting silent TorchDrug overwriting failed, {candidate} is not a subclass"
            f" of {base_dataset}."
        )
@typing.no_type_check
def fix_schedulers(sane_schedulers: List[_LRScheduler]) -> None:
    """
    Helper function to revert TorchDrug LR scheduler handling (which breaks core
    pytorch functionalities). For details see:
    https://github.com/DeepGraphLearning/torchdrug/issues/96

    The classes are written back onto ``torch.optim.lr_scheduler`` by position,
    and only afterwards is every entry checked against the first one.

    NOTE(review): the subclass check assumes every concrete scheduler derives
    from ``_LRScheduler``; newer torch releases appear to root them at
    ``LRScheduler`` instead - confirm against the supported torch range.

    Args:
        sane_schedulers: A list of pytorch lr_schedulers, ordered as
            _LRScheduler, ChainedScheduler, ConstantLR, CosineAnnealingLR,
            CosineAnnealingWarmRestarts, CyclicLR, ExponentialLR, LambdaLR,
            LinearLR, MultiStepLR, MultiplicativeLR, OneCycleLR, SequentialLR,
            StepLR.

    Raises:
        AttributeError: If a passed lr_scheduler was not sane.
    """
    scheduler_module = torch.optim.lr_scheduler
    base_scheduler = sane_schedulers[0]

    # Attribute names in the positional order expected for ``sane_schedulers``.
    attribute_names = (
        "_LRScheduler",
        "ChainedScheduler",
        "ConstantLR",
        "CosineAnnealingLR",
        "CosineAnnealingWarmRestarts",
        "CyclicLR",
        "ExponentialLR",
        "LambdaLR",
        "LinearLR",
        "MultiStepLR",
        "MultiplicativeLR",
        "OneCycleLR",
        "SequentialLR",
        "StepLR",
    )
    for position, attribute in enumerate(attribute_names):
        setattr(scheduler_module, attribute, sane_schedulers[position])  # type: ignore

    # Validation happens after patching: every entry must derive from entry 0.
    for candidate in sane_schedulers[1:]:
        if issubclass(candidate, base_scheduler):
            continue
        raise AttributeError(
            f"Reverting silent TorchDrug overwriting failed, {candidate} is not a subclass"
            f" of {base_scheduler}."
        )
# Minimal C++ program used to probe whether the compiler accepts ``-fopenmp``.
CHECK_CODE = "int main(){return 0;}"


def check_openmp_availabilty() -> bool:
    """
    Check if OpenMP is available at runtime.

    A trivial C++ source file is written to a temporary folder and compiled
    with ``-fopenmp``. The compiler is taken from the ``CXX`` environment
    variable, defaulting to ``cl`` on Windows and ``c++`` elsewhere.

    Returns:
        True if OpenMP is available, False otherwise. A compiler that exits
        with a non-zero status or that cannot be launched at all (e.g. it is
        not installed) both count as "not available".
    """
    compiler = os.environ.get("CXX", "cl" if IS_WINDOWS else "c++")

    # TemporaryDirectory removes the folder on every exit path, including a
    # failure while writing the source file.
    with tempfile.TemporaryDirectory() as tempfolder:
        source_path = os.path.join(tempfolder, "main.cpp")
        with open(source_path, "w") as f:
            f.write(CHECK_CODE)

        try:
            subprocess.check_call(
                [
                    compiler,
                    "-fopenmp",
                    source_path,
                    "-o",
                    os.path.join(tempfolder, "main.o"),
                ]
            )
        except (CalledProcessError, OSError):
            # CalledProcessError: the compiler ran but returned non-zero.
            # OSError (incl. FileNotFoundError): the compiler could not be
            # started; this function runs at import time, so it must not crash.
            return False

    return True
# Remember the value torch reported before it is overwritten below.
TORCH_HAS_OPENMP = torch._C.has_openmp
# NOTE: ensuring this variable indicates OpenMP availability in current compiler
torch._C.has_openmp = check_openmp_availabilty()