Skip to content

Refactor load_state_dict method to accept state_dict directly#7

Open
JackMcKechnie wants to merge 1 commit into
Parry-Parry:mainfrom
JackMcKechnie:continued-training
Open

Refactor load_state_dict method to accept state_dict directly#7
JackMcKechnie wants to merge 1 commit into
Parry-Parry:mainfrom
JackMcKechnie:continued-training

Conversation

@JackMcKechnie
Copy link
Copy Markdown
Contributor

Refactor load_state_dict method to accept state_dict directly and handle string input.

Refactor load_state_dict method to accept state_dict directly and handle string input.
@JackMcKechnie
Copy link
Copy Markdown
Contributor Author

This seems to work, helping with validation

Copy link
Copy Markdown
Owner

@Parry-Parry Parry-Parry left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comments

Comment thread rankers/modelling/base.py
return self.model.load_state_dict(
self.architecture_class.from_pretrained(model_dir).state_dict()
)
def load_state_dict(self, state_dict, strict=False, **kwargs):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming here gets a bit weird,

Comment thread rankers/modelling/base.py
self.architecture_class.from_pretrained(model_dir).state_dict()
)
def load_state_dict(self, state_dict, strict=False, **kwargs):
if isinstance(state_dict, str):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make it clearer that we are saying this is a valid path

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants