add gradient checkpointing for the final_layernorm module.

#77
by zhaoqf123

Without this, when tuning with LoRA + gradient checkpointing, the LoRA weights of the last transformer layer (layer 27) won't be updated!
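
For reference, here is a minimal sketch of the kind of change the title describes. It assumes a forward pass that ends with a call to self.final_layernorm and a self.gradient_checkpointing flag; the names and surrounding structure are illustrative, not copied from the actual modeling file.

from torch.utils.checkpoint import checkpoint

# Inside the model's forward, after the loop over transformer layers
# (illustrative sketch, not the exact diff of this PR):
if self.gradient_checkpointing and self.training:
    # Checkpoint the final layer norm as well, so that (per this PR) the
    # LoRA weights of the last transformer layer also get updated.
    hidden_states = checkpoint(self.final_layernorm, hidden_states)
else:
    hidden_states = self.final_layernorm(hidden_states)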

For example, if we use the callback below to log statistics of the LoRA weights in each layer, we will see in TensorBoard that the weights of the last layer are never updated.

from transformers import Trainer
from transformers.integrations import TensorBoardCallback


class ParamsTensorBoardCallback(TensorBoardCallback):
    """Logs summary statistics of selected parameters to TensorBoard."""

    def __init__(self, tb_writer=None, params=None, process_name=lambda x: x):
        super().__init__(tb_writer)
        self.params = params                # names of the parameters to track
        self._process_name = process_name   # shortens names into TensorBoard tags

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % args.logging_steps == 0:
            dict_ = {}
            model = kwargs["model"]
            for name in self.params:
                # Detach and cast to float so quantile/median work for any dtype
                # and no autograd graph is built for the logging ops.
                param = model.get_parameter(name).detach().float().flatten()
                name_p = self._process_name(name)
                dict_tmp = {
                    f"{name_p}_mean": param.mean().item(),
                    f"{name_p}_max": param.max().item(),
                    f"{name_p}_q75": param.quantile(0.75).item(),
                    f"{name_p}_q25": param.quantile(0.25).item(),
                    f"{name_p}_min": param.min().item(),
                    f"{name_p}_median": param.median().item(),
                    f"{name_p}_std": param.std().item(),
                }
                dict_.update(dict_tmp)
            self.on_log(args, state, control, logs=dict_, **kwargs)

def get_params_for_logging(model):
    # Collect the names of all trainable parameters; with LoRA these are
    # exactly the LoRA weights.
    ls_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            ls_params.append(name)
    return ls_params

ls_params = get_params_for_logging(model)
# process_name drops a fixed-length prefix from each parameter name so the
# TensorBoard tags stay short.
tb_cb = ParamsTensorBoardCallback(
    None, ls_params, process_name=lambda x: x[36:]
)

trainer = Trainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=args,
    data_collator=data_collator,
    callbacks=[tb_cb],
)
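
For context, this is the kind of training setup under which the problem shows up: a rough sketch assuming peft's LoRA wrappers and transformers' built-in gradient checkpointing. BASE_MODEL_ID and the LoRA hyperparameters are placeholders, not taken from this PR.

from transformers import AutoModel
from peft import LoraConfig, get_peft_model

# Placeholder id for the base model being tuned.
BASE_MODEL_ID = "..."

model = AutoModel.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
model.gradient_checkpointing_enable()   # uses torch.utils.checkpoint in forward
model.enable_input_require_grads()      # lets gradients flow into checkpointed blocks

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["query_key_value"],  # placeholder; depends on the model
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

With this setup and the callback above, the per-layer LoRA statistics in TensorBoard make the missing updates for the last layer easy to spot.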

I have made a similar PR for the LLaMA model in the transformers repo.
