Merge branch 'master' into hooks_part2

This commit is contained in:
Jedrzej Kosinski
2025-01-07 01:01:53 -06:00
24 changed files with 137 additions and 94 deletions

View File

@@ -402,7 +402,20 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def get_model_object(self, name):
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
Args:
name (str): The attribute path using dot notation (e.g. "model.layer.weight")
Returns:
The value of the requested attribute
Example:
patcher = ModelPatcher()
weight = patcher.get_model_object("layer1.conv.weight")
"""
if name in self.object_patches:
return self.object_patches[name]
else: