Source code for libai.optim.build

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
from collections import defaultdict
from typing import Any, Dict, List

import oneflow as flow

from libai.config import instantiate
from libai.layers import LayerNorm
from libai.utils.registry import Registry

OPTIMIZER_REGISTRY = Registry("Optimizer")
OPTIMIZER_REGISTRY.__doc__ = """
Registry for optimizer, i.e. SGD, AdamW

The registered object will be called with `obj(cfg)`
and expected to return a `flow.optim.Optimizer` object.
"""


def register_optimizer():
    flow_optimizers = []
    for module_name in dir(flow.optim):
        if module_name.startswith("__"):
            continue
        _optim = getattr(flow.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim, flow.optim.Optimizer):
            OPTIMIZER_REGISTRY.register(_optim)
            flow_optimizers.append(module_name)
    return flow_optimizers


FLOW_OPTIMIZERS = register_optimizer()


[docs]def build_optimizer(cfg, model): """ Build an optimizer from config. """ if "_target_" in cfg: cfg.params.model = model optim = instantiate(cfg) else: optim_name = cfg.optim_name optim = OPTIMIZER_REGISTRY.get(optim_name)( get_default_optimizer_params(model, **cfg.param_cfg), **cfg.optim_cfg ) return optim
[docs]def get_default_optimizer_params( model, base_lr=None, weight_decay=None, weight_decay_norm=None, weight_decay_bias=None, clip_grad_max_norm=None, clip_grad_norm_type=None, overrides=None, ): """ Get default param list for optimizer, with suport for a few types of overrides. If no overrides needed, this is equivalent to `model.parameters()`. Arguments: base_lr: lr for every group by default. Can be omitted to use the one in optimizer. weight_decay: weight decay for every group by default. Can be omitted to use the one in optimizer. weight_decay_norm: override weight decay for params in normalization layers weight_decay_bias: override weight decay for bias parameters overrides: if not `None`, provides values for optimizer hyperparameters (LR, weight decay) for module parameters with a given name; e.g. ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and weight decay values for all module parameters named `embedding`. For common transformer models, ``weight_decay_norm`` and ``weight_decay_bias`` are usually set to 0. Example: :: flow.optim.AdamW( get_default_optimizer_params(model, weight_decay_norm=0, weight_decay_bias=0), lr=0.01, weight_decay=1e-4 ) """ if overrides is None: overrides = {} defaults = {} if base_lr is not None: defaults["lr"] = base_lr if weight_decay is not None: defaults["weight_decay"] = weight_decay if clip_grad_max_norm is not None and clip_grad_norm_type is not None: defaults["clip_grad_max_norm"] = clip_grad_max_norm defaults["clip_grad_norm_type"] = clip_grad_norm_type bias_overrides = {} if weight_decay_bias is not None: bias_overrides["weight_decay"] = weight_decay_bias if len(bias_overrides): if "bias" in overrides: raise ValueError("Conflicting overrides for 'bias'") overrides["bias"] = bias_overrides norm_module_types = ( LayerNorm, flow.nn.BatchNorm1d, flow.nn.BatchNorm2d, flow.nn.BatchNorm3d, flow.nn.GroupNorm, flow.nn.InstanceNorm1d, flow.nn.InstanceNorm2d, flow.nn.InstanceNorm3d, flow.nn.FusedBatchNorm1d, flow.nn.FusedBatchNorm2d, flow.nn.FusedBatchNorm3d, ) params = [] memo = set() for module in model.modules(): for model_param_name, value in module.named_parameters(recurse=False): if not value.requires_grad: continue # Avoid duplicating parameters if value in memo: continue memo.add(value) hyperparams = copy.copy(defaults) if isinstance(module, norm_module_types) and weight_decay_norm is not None: hyperparams["weight_decay"] = weight_decay_norm hyperparams.update(overrides.get(model_param_name, {})) params.append({"params": [value], **hyperparams}) return reduce_param_groups(params)
def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Transform parameter groups into per-parameter structure. Later items in `params` can overwrite parameters set in previous items. """ ret = defaultdict(dict) for item in params: assert "params" in item cur_params = {x: y for x, y in item.items() if x != "params"} for param in item["params"]: ret[param].update({"params": [param], **cur_params}) return list(ret.values()) def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Reorganize the parameter groups and merge duplicated groups. The number of parameter groups needs to be as small as possible in order to efficiently use the OneFlow multi-tensor optimizer. Therefore instead of using a parameter_group per single parameter, we reorganize the parameter groups and merge duplicated groups. This approach speeds up multi-tensor optimizer significantly. """ params = _expand_param_groups(params) groups = defaultdict(list) # re-group all parameter groups by their hyperparams for item in params: cur_params = tuple((x, y) for x, y in item.items() if x != "params") groups[cur_params].extend(item["params"]) ret = [] for param_keys, param_values in groups.items(): cur = {kv[0]: kv[1] for kv in param_keys} cur["params"] = param_values ret.append(cur) return ret