Source code for gt4sd.frameworks.gflownet.util

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import queue
import threading

import torch
import torch.multiprocessing as mp
import torch.nn as nn


class MPModelPlaceholder:
    """This class can be used as a Model in a worker process, and
    translates calls to queries to the main process."""

    def __init__(self, in_queues, out_queues):
        self.qs = in_queues, out_queues
        self.device = torch.device("cpu")
        self._is_init = False

    def _check_init(self):
        if self._is_init:
            return
        info = torch.utils.data.get_worker_info()
        self.in_queue = self.qs[0][info.id]
        self.out_queue = self.qs[1][info.id]
        self._is_init = True

    def log_z(self, *a):
        self._check_init()
        self.in_queue.put(("log_z", *a))
        return self.out_queue.get()

    def __call__(self, *a):
        self._check_init()
        self.in_queue.put(("__call__", *a))
        return self.out_queue.get()
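
# Message protocol (orientation note, matching the code above and below): a
# worker-side call such as `placeholder.log_z(x)` puts the tuple ("log_z", x)
# on that worker's in_queue; MPModelProxy.run dispatches it via getattr on the
# real model and returns the result, as detached CPU tensors, on the matching
# out_queue.
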
class MPModelProxy:
    """This class maintains a reference to an in-cuda-memory model, and
    creates a `placeholder` attribute which can be safely passed to
    multiprocessing DataLoader workers.

    This placeholder model sends messages across multiprocessing queues,
    which are received by this proxy instance, which calls the model and
    sends the return value back to the worker. Starts its own (daemon)
    thread. Always passes CPU tensors between processes.
    """

    def __init__(self, model: torch.nn.Module, num_workers: int, cast_types: tuple):
        """Construct a multiprocessing model proxy for torch DataLoaders.

        Args:
            model: a torch model which lives in the main process to which method calls are passed.
            num_workers: number of workers.
            cast_types: types that will be cast to cuda when received as arguments of method calls.
        """
        self.in_queues = [mp.Queue() for i in range(num_workers)]  # type: ignore
        self.out_queues = [mp.Queue() for i in range(num_workers)]  # type: ignore
        self.placeholder = MPModelPlaceholder(self.in_queues, self.out_queues)
        self.model = model
        self.device = next(model.parameters()).device
        self.cuda_types = (torch.Tensor,) + cast_types
        self.stop = threading.Event()
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()

    def __del__(self):
        self.stop.set()
    def run(self):
        while not self.stop.is_set():
            for qi, q in enumerate(self.in_queues):
                # Poll each worker's queue with a tiny timeout so the loop
                # keeps cycling and can notice the stop event.
                try:
                    r = q.get(True, 1e-5)
                except queue.Empty:
                    continue
                except ConnectionError:
                    break
                # Messages are (method_name, *args) tuples sent by the placeholder.
                attr, *args = r
                f = getattr(self.model, attr)
                # Move tensor-like arguments onto the model's device before calling.
                args = [
                    i.to(self.device) if isinstance(i, self.cuda_types) else i
                    for i in args
                ]
                result = f(*args)
                # Always ship results back to the worker as detached CPU tensors.
                if isinstance(result, (list, tuple)):
                    msg = [
                        i.detach().to(torch.device("cpu"))
                        if isinstance(i, self.cuda_types)
                        else i
                        for i in result
                    ]
                    self.out_queues[qi].put(msg)
                else:
                    msg = (
                        result.detach().to(torch.device("cpu"))
                        if isinstance(result, self.cuda_types)
                        else result
                    )
                    self.out_queues[qi].put(msg)
def wrap_model_mp(
    model: nn.Module, num_workers: int, cast_types: tuple
) -> MPModelPlaceholder:
    """Construct a multiprocessing model proxy for torch DataLoaders so
    that only one process ends up making cuda calls and holding cuda
    tensors in memory.

    Args:
        model: a torch model which lives in the main process to which method calls are passed.
        num_workers: number of DataLoader workers.
        cast_types: types that will be cast to cuda when received as arguments of method calls. torch.Tensor is cast by default.

    Returns:
        placeholder: a placeholder model whose method calls route arguments to the main process.
    """
    return MPModelProxy(model, num_workers, cast_types).placeholder
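
# A minimal usage sketch, not part of the original module: `ToyDataset` and the
# toy `nn.Linear` model below are illustrative assumptions, showing how the
# placeholder returned by `wrap_model_mp` can be handed to DataLoader workers
# while the real model (and any cuda tensors) stays in the main process.
if __name__ == "__main__":

    class ToyDataset(torch.utils.data.Dataset):
        """Hypothetical dataset whose items are produced by querying the model."""

        def __init__(self, model):
            self.model = model  # an MPModelPlaceholder; safe to send to workers

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            # Runs in a worker process: this call is routed over the worker's
            # queue and executed by the proxy thread in the main process.
            return self.model(torch.randn(4))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    toy_model = nn.Linear(4, 2).to(device)
    placeholder = wrap_model_mp(toy_model, num_workers=2, cast_types=())

    loader = torch.utils.data.DataLoader(ToyDataset(placeholder), num_workers=2)
    for batch in loader:
        print(batch.shape)  # torch.Size([1, 2]): CPU tensors shipped back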