diff --git a/utils.py b/utils.py index 328ad09d..9a994328 100644 --- a/utils.py +++ b/utils.py @@ -89,7 +89,7 @@ def load_pretrained(config, model, logger): absolute_pos_embed_current = model.state_dict()[k] _, L1, C1 = absolute_pos_embed_pretrained.size() _, L2, C2 = absolute_pos_embed_current.size() - if C1 != C1: + if C1 != C2: logger.warning(f"Error in loading {k}, passing......") else: if L1 != L2: