# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import logging
from typing import Dict, List, TYPE_CHECKING

from torch.distributed.checkpoint.planner import SavePlan

if TYPE_CHECKING:
    from torch.distributed.checkpoint.metadata import MetadataIndex

__all__ = ["dedup_tensors"]


def init_logger() -> logging.Logger:
    """Create a module-level logger that emits INFO records to stderr."""
    logger = logging.getLogger(__name__)
    level = logging.INFO
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    console.setFormatter(formatter)
    console.setLevel(level)
    logger.addHandler(console)
    logger.propagate = False
    return logger


logger = init_logger()


def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
    """
    Remove duplicate write items from a list of ``SavePlan``s.

    Write items are keyed by their ``MetadataIndex``. When the same index is
    scheduled for writing by more than one plan (e.g. a tensor replicated
    across ranks), only the first plan that contains it keeps the write item;
    every later plan drops it, so each entry is written exactly once.
    """
    all_plans = list(all_plans)
    key_to_plan: Dict[MetadataIndex, List[int]] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            key_to_plan.setdefault(write_item.index, []).append(plan_idx)

    # Indices that appear in more than one plan are replicated.
    replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}

    # Remove duplicates by always keeping the first entry: compute, per plan,
    # the list of keys whose write items must be dropped.
    plan_to_keys: Dict[int, List[MetadataIndex]] = {}
    for key, plans in replicated_items.items():
        for plan_idx in plans[1:]:
            plan_to_keys.setdefault(plan_idx, []).append(key)
    if len(plan_to_keys) > 0:
        logger.info("Duplicate keys to remove: %s", plan_to_keys)

    for plan_idx, keys in plan_to_keys.items():
        key_set = set(keys)
        # Rewrite the plan, keeping only the write items that were not
        # scheduled for removal. SavePlan is a frozen dataclass, so build a
        # new instance via dataclasses.replace instead of mutating in place.
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in key_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
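

# Illustrative usage sketch (an assumption-laden example, not part of this
# module): two plans that both carry a write item for the same MetadataIndex
# are deduplicated so that only the first plan keeps it. WriteItem,
# WriteItemType, and SavePlan come from torch.distributed.checkpoint.planner;
# the exact constructor fields shown below are assumptions that may differ
# across torch versions.
#
#     from torch.distributed.checkpoint.metadata import MetadataIndex
#     from torch.distributed.checkpoint.planner import (
#         SavePlan,
#         WriteItem,
#         WriteItemType,
#     )
#
#     item = WriteItem(
#         index=MetadataIndex(fqn="shared.weight"),
#         type=WriteItemType.TENSOR,
#     )
#     plans = [SavePlan(items=[item]), SavePlan(items=[item])]
#     deduped = dedup_tensors(plans)
#     assert [len(p.items) for p in deduped] == [1, 0]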