File: distributed_image_generation.py

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Originally by jiwooya1000, put together by sayakpaul.
Documentation: https://huggingface.co/docs/diffusers/main/en/training/distributed_inference

Run:

accelerate launch distributed_image_generation.py --batch_size 8

# Enable memory optimizations for large models like SD3
accelerate launch distributed_image_generation.py --batch_size 8 --low_mem
"""

import os
import time

import fire
import torch
from datasets import load_dataset
from diffusers import DiffusionPipeline
from tqdm import tqdm

from accelerate import PartialState
from accelerate.utils import gather_object


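# Timestamp suffix keeps output directories from successive runs distinct.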
START_TIME = time.strftime("%Y%m%d_%H%M%S")
DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}


def get_batches(items, batch_size):
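    """Split `items` into consecutive batches of `batch_size`; the last batch may be shorter."""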
    num_batches = (len(items) + batch_size - 1) // batch_size
    batches = []

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = min((i + 1) * batch_size, len(items))
        batch = items[start_index:end_index]
        batches.append(batch)

    return batches


def main(
    ckpt_id: str = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    save_dir: str = "./evaluation/examples",
    seed: int = 1,
    batch_size: int = 4,
    num_inference_steps: int = 20,
    guidance_scale: float = 4.5,
    dtype: str = "fp16",
    low_mem: bool = False,
):
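    # Each process launched by `accelerate launch` loads its own copy of the
    # pipeline in the precision requested via --dtype.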
    pipeline = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=DTYPE_MAP[dtype])

    save_dir = save_dir + f"_{START_TIME}"

    parti_prompts = load_dataset("nateraw/parti-prompts", split="train")
    data_loader = get_batches(items=parti_prompts["Prompt"], batch_size=batch_size)

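    # PartialState exposes this process's rank and device without the full Accelerator API.
    # With --low_mem, model components stay on CPU and are moved to this rank's GPU only
    # while they run, trading speed for a much smaller peak memory footprint.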
    distributed_state = PartialState()
    if low_mem:
        pipeline.enable_model_cpu_offload(gpu_id=distributed_state.device.index)
    else:
        pipeline = pipeline.to(distributed_state.device)

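    # Only the main process touches the filesystem here, avoiding a creation race.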
    if distributed_state.is_main_process:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            print(f"Directory '{save_dir}' created successfully.")
        else:
            print(f"Directory '{save_dir}' already exists.")

    count = 0
    for prompts_raw in tqdm(data_loader):
        input_prompts = []

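        # Shard this batch of prompts across processes: each rank receives a
        # roughly equal slice and runs the pipeline only on its own prompts.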
        with distributed_state.split_between_processes(prompts_raw) as prompts:
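            # torch.manual_seed re-seeds the global RNG and returns it as a Generator;
            # every rank uses the same seed, so outputs are reproducible for a given --seed.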
            generator = torch.manual_seed(seed)
            images = pipeline(
                prompts, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, generator=generator
            ).images
            input_prompts.extend(prompts)

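        # Barrier: make sure every rank has finished generating before gathering.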
        distributed_state.wait_for_everyone()

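        # gather_object collects the per-rank lists so every process ends up with
        # the full set of images and prompts for this batch.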
        images = gather_object(images)
        input_prompts = gather_object(input_prompts)

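        # Rank 0 alone writes results, one `example_<n>` subdirectory per image.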
        if distributed_state.is_main_process:
            for image, prompt in zip(images, input_prompts):
                count += 1
                temp_dir = os.path.join(save_dir, f"example_{count}")

                os.makedirs(temp_dir)
                prompt = "_".join(prompt.split())
                # Save inside the per-example directory so outputs land under save_dir.
                image.save(os.path.join(temp_dir, f"image_{prompt}.png"))

    if distributed_state.is_main_process:
        print(f">>> Image Generation Finished. Saved in {save_dir}")


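# python-fire maps main()'s keyword arguments to CLI flags (e.g. --batch_size, --dtype).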
if __name__ == "__main__":
    fire.Fire(main)