Source code for gt4sd.training_pipelines.diffusion.core

#
# MIT License
#
# Copyright (c) 2022 GT4SD team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Diffusion training utilities. Code adapted from: https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py"""

import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from rdkit.Chem import Descriptors
from torchvision.transforms import (
    CenterCrop,
    Compose,
    InterpolationMode,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from tqdm.auto import tqdm

from ..core import TrainingPipeline, TrainingPipelineArguments

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# NOTE: the rdkit Descriptors import above must precede the torchvision imports,
# since importing torchvision before rdkit can cause segmentation faults. The
# bare reference below keeps the otherwise "unused" import from being removed.
Descriptors


class DiffusionForVisionTrainingPipeline(TrainingPipeline):
    """Diffusion training pipelines for image generation."""
    def train(  # type: ignore
        self,
        training_args: Dict[str, Any],
        model_args: Dict[str, Any],
        dataset_args: Dict[str, Any],
    ) -> None:
        """Generic training function for diffusion models.

        Args:
            training_args: training arguments passed to the configuration.
            model_args: model arguments passed to the configuration.
            dataset_args: dataset arguments passed to the configuration.
        """
        params = {**training_args, **dataset_args, **model_args}

        env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
        if env_local_rank != -1 and env_local_rank != params["local_rank"]:
            params["local_rank"] = env_local_rank

        logging_dir = os.path.join(params["output_dir"], params["logging_dir"])
        model_path = params["model_path"]
        training_name = params["training_name"]

        accelerator = Accelerator(
            mixed_precision=params["mixed_precision"],
            log_with="tensorboard",
            logging_dir=logging_dir,
        )

        logger.info(f"Model with name {training_name} starts.")
        model_dir = os.path.join(model_path, training_name)
        log_path = os.path.join(model_dir, "logs")
        val_dir = os.path.join(log_path, "val_logs")
        os.makedirs(os.path.join(model_dir, "weights"), exist_ok=True)
        os.makedirs(os.path.join(model_dir, "results"), exist_ok=True)
        os.makedirs(log_path, exist_ok=True)
        os.makedirs(val_dir, exist_ok=True)

        # UNet model used to predict the noise residual
        model = UNet2DModel(
            sample_size=params["resolution"],
            in_channels=params["in_channels"],
            out_channels=params["out_channels"],
            layers_per_block=params["layers_per_block"],
            block_out_channels=(128, 128, 256, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
            ),
            up_block_types=(
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
        # DDPM noise schedule
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=params["num_train_timesteps"]
        )
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=params["learning_rate"],
            betas=(params["adam_beta1"], params["adam_beta2"]),
            weight_decay=params["adam_weight_decay"],
            eps=params["adam_epsilon"],
        )
        augmentations = Compose(
            [
                Resize(params["resolution"], interpolation=InterpolationMode.BILINEAR),
                CenterCrop(params["resolution"]),
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize([0.5], [0.5]),
            ]
        )

        # specify the dataset by name (hub) or path (local image folder)
        if params["dataset_name"]:
            dataset = load_dataset(
                params["dataset_name"],
                params["dataset_config_name"],
                cache_dir=params["cache_dir"],
                use_auth_token=True if params["use_auth_token"] else None,
                split="train",
            )
            logger.info("dataset name: " + params["dataset_name"])
        else:
            if not params["train_data_dir"]:
                raise ValueError(
                    "You must specify either a dataset name from the hub or a train data directory."
                )
            dataset = load_dataset(
                "imagefolder",
                data_dir=params["train_data_dir"],
                cache_dir=params["cache_dir"],
                split="train",
            )
            logger.info("dataset path: " + params["train_data_dir"])

        def transforms(examples):
            # hub datasets store the image column as either "img" or "image"
            try:
                images = [
                    augmentations(image.convert("RGB")) for image in examples["img"]
                ]
            except KeyError:
                images = [
                    augmentations(image.convert("RGB")) for image in examples["image"]
                ]
            return {"input": images}

        dataset.set_transform(transforms)  # type: ignore
        train_dataloader = torch.utils.data.DataLoader(  # type: ignore
            dataset, batch_size=params["train_batch_size"], shuffle=True  # type: ignore
        )

        # specify the learning rate scheduler
        lr_scheduler = get_scheduler(
            params["lr_scheduler"],
            optimizer=optimizer,
            num_warmup_steps=params["lr_warmup_steps"],
            num_training_steps=(len(train_dataloader) * params["num_epochs"])
            // params["gradient_accumulation_steps"],
        )

        # prepare for distributed training if needed
        model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, lr_scheduler
        )

        # initialize the EMA model
        ema_model = EMAModel(
            model,
            inv_gamma=params["ema_inv_gamma"],
            power=params["ema_power"],
            max_value=params["ema_max_decay"],
        )

        if accelerator.is_main_process:
            run = os.path.split(__file__)[-1].split(".")[0]
            accelerator.init_trackers(run)

        global_step = 0
        # start training
        for epoch in range(params["num_epochs"]):
            model.train()
            # progress bar visualization
            progress_bar = tqdm(
                total=len(train_dataloader),
                disable=not accelerator.is_local_main_process,
            )
            progress_bar.set_description(f"Epoch {epoch}")
            for _, batch in enumerate(train_dataloader):
                clean_images = batch["input"]
                # sample noise that we'll add to the images
                noise = torch.randn(clean_images.shape).to(clean_images.device)
                bsz = clean_images.shape[0]
                # sample a random timestep for each image
                timesteps = torch.randint(
                    0,
                    noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=clean_images.device,
                ).long()

                # add noise to the clean images according to the noise magnitude
                # at each timestep (this is the forward diffusion process)
                noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

                with accelerator.accumulate(model):
                    # predict the noise residual
                    noise_pred = model(noisy_images, timesteps)["sample"]
                    loss = F.mse_loss(noise_pred, noise)
                    accelerator.backward(loss)

                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    if params["use_ema"]:
                        ema_model.step(model)
                    optimizer.zero_grad()

                progress_bar.update(1)
                logs = {
                    "loss": loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                }
                if params["use_ema"]:
                    logs["ema_decay"] = ema_model.decay
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                global_step += 1
                if params["dummy_training"]:
                    break
            progress_bar.close()

            # wait for all the processes to finish
            accelerator.wait_for_everyone()

            # generate sample images for visual inspection
            if accelerator.is_main_process and params["is_sampling"]:
                if (
                    epoch % params["save_images_epochs"] == 0
                    or epoch == params["num_epochs"] - 1
                ):
                    # inference/sampling pipeline
                    pipeline = DDPMPipeline(
                        unet=accelerator.unwrap_model(
                            ema_model.averaged_model if params["use_ema"] else model
                        ),
                        scheduler=noise_scheduler,
                    )

                    generator = torch.manual_seed(0)
                    # run the pipeline in inference (sample random noise and denoise)
                    images = pipeline(
                        generator=generator,
                        batch_size=params["eval_batch_size"],
                        output_type="numpy",
                    )["sample"]

                    # denormalize the images and save them to tensorboard
                    images_processed = (images * 255).round().astype("uint8")
                    accelerator.trackers[0].writer.add_images(
                        "test_samples",
                        images_processed.transpose(0, 3, 1, 2),
                        epoch,
                    )

                if (
                    epoch % params["save_model_epochs"] == 0
                    or epoch == params["num_epochs"] - 1
                ):
                    pipeline.save_pretrained(params["output_dir"])
            accelerator.wait_for_everyone()
            if params["dummy_training"]:
                break

        accelerator.end_training()
        logger.info("Training done, shutting down.")


@dataclass
class DiffusionDataArguments(TrainingPipelineArguments):
    """Data arguments related to the diffusion trainer."""

    __name__ = "dataset_args"

    dataset_name: str = field(default="", metadata={"help": "Dataset name."})
    dataset_config_name: str = field(
        default="", metadata={"help": "Dataset config name."}
    )
    train_data_dir: str = field(default="", metadata={"help": "Train data directory."})
    resolution: int = field(default=64, metadata={"help": "Resolution."})
    train_batch_size: int = field(default=16, metadata={"help": "Train batch size."})
    eval_batch_size: int = field(default=16, metadata={"help": "Eval batch size."})
    num_epochs: int = field(default=100, metadata={"help": "Number of epochs."})


@dataclass
class DiffusionModelArguments(TrainingPipelineArguments):
    """Model arguments related to the diffusion trainer."""

    __name__ = "model_args"

    model_path: str = field(metadata={"help": "Path to the model file."})
    training_name: str = field(metadata={"help": "Name of the training run."})
    num_train_timesteps: int = field(
        default=1000, metadata={"help": "Number of noise steps."}
    )
    learning_rate: float = field(default=1e-4, metadata={"help": "Learning rate."})
    lr_scheduler: str = field(
        default="cosine", metadata={"help": "Learning rate scheduler."}
    )
    lr_warmup_steps: int = field(
        default=500, metadata={"help": "Learning rate warm-up steps."}
    )
    adam_beta1: float = field(default=0.95, metadata={"help": "Adam beta1."})
    adam_beta2: float = field(default=0.999, metadata={"help": "Adam beta2."})
    adam_weight_decay: float = field(
        default=1e-6, metadata={"help": "Adam weight decay."}
    )
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Adam epsilon."})
    gradient_accumulation_steps: int = field(
        default=1, metadata={"help": "Gradient accumulation steps."}
    )
    in_channels: int = field(default=3, metadata={"help": "Input channels."})
    out_channels: int = field(default=3, metadata={"help": "Output channels."})
    layers_per_block: int = field(default=2, metadata={"help": "Layers per block."})


@dataclass
class DiffusionTrainingArguments(TrainingPipelineArguments):
    """Training arguments related to the diffusion trainer."""

    __name__ = "training_args"

    local_rank: int = field(
        default=-1, metadata={"help": "Local rank of the process."}
    )
    output_dir: str = field(
        default="ddpm-cifar10-32", metadata={"help": "Output directory."}
    )
    logging_dir: str = field(default="logs/", metadata={"help": "Logging directory."})
    overwrite_output_dir: bool = field(
        default=False, metadata={"help": "Overwrite output directory."}
    )
    cache_dir: str = field(default=".cache/", metadata={"help": "Cache directory."})
    save_images_epochs: int = field(
        default=10, metadata={"help": "Save images every n epochs."}
    )
    save_model_epochs: int = field(
        default=10, metadata={"help": "Save model every n epochs."}
    )
    use_ema: bool = field(default=True, metadata={"help": "Use EMA."})
    ema_inv_gamma: float = field(default=1.0, metadata={"help": "EMA inverse gamma."})
    ema_power: float = field(default=0.75, metadata={"help": "EMA power."})
    ema_max_decay: float = field(default=0.9999, metadata={"help": "EMA max decay."})
    mixed_precision: str = field(
        default="no",
        metadata={"help": "Mixed precision. Choose from 'no', 'fp16', 'bf16'."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Use the token generated when using huggingface-hub (necessary to use this script with private models)."
        },
    )
    dummy_training: bool = field(
        default=False,
        metadata={"help": "Run a dummy training to test the pipeline."},
    )
    is_sampling: bool = field(
        default=True,
        metadata={"help": "Run sampling."},
    )


@dataclass
class DiffusionSavingArguments(TrainingPipelineArguments):
    """Saving arguments related to the diffusion trainer."""

    __name__ = "saving_args"
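

# Usage sketch (not part of the upstream module): a minimal example of how the
# three argument dataclasses above feed ``train``, which expects plain
# dictionaries. The dataset, config name, paths, and training name below are
# illustrative assumptions; ``dummy_training=True`` stops after a single batch.
if __name__ == "__main__":
    from dataclasses import asdict

    pipeline = DiffusionForVisionTrainingPipeline()
    pipeline.train(
        training_args=asdict(DiffusionTrainingArguments(dummy_training=True)),
        model_args=asdict(
            DiffusionModelArguments(model_path="models", training_name="ddpm-demo")
        ),
        dataset_args=asdict(
            DiffusionDataArguments(
                dataset_name="cifar10",
                dataset_config_name="plain_text",
                num_epochs=1,
            )
        ),
    )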