Skip to content

# validate

`validate_hook(hook, available_keys=None, depth=0, model_counts=None)`

**Parameters:**

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `hook` | | the hook to validate | *required* |
| `available_keys` | | a list of keys that the context will start with. | `None` |

**Returns:**

| Type | Description |
|------|-------------|
| `Dict[str, int]` | A dictionary with each model's `forward` call count. |

Source code in `pytorch_adapt\hooks\validate.py` (listing begins at line 30):
def validate_hook(
    hook, available_keys=None, depth=0, model_counts=None
) -> Dict[str, int]:
    """
    Recursively walk ``hook`` and check that the keys each leaf hook
    needs are available in the context at the point where it runs.

    Arguments:
        hook: the hook to validate
        available_keys: the keys that the context will start with.
            Any iterable of keys is accepted (list, tuple, set, ...).
            A ``set`` is used as-is and is updated in place with the
            ``out_keys`` of the hooks validated against it; any other
            iterable is first copied into a new set.
        depth: recursion depth; only used to indent the debug log line.
        model_counts: accumulator for the per-model ``forward`` call
            counts. A new ``defaultdict(int)`` is created when ``None``.
            The same object is passed to every recursive call and is
            filled in via ``update_model_counts``.
    Returns:
        A dictionary with each model's ```forward``` call count.
    """
    c_f.LOGGER.debug(f"VALIDATE: {'  '*depth}{c_f.cls_name(hook)}")
    available_keys = c_f.default(available_keys, [])
    model_counts = c_f.default(model_counts, defaultdict(int))

    # Normalize to a set. An existing set is kept (not copied) so that
    # recursive calls share it and later ChainHook children see the keys
    # produced by earlier ones. Converting every other iterable (not only
    # lists) avoids an AttributeError at ``available_keys.update`` below
    # when e.g. a tuple or frozenset is passed.
    if not isinstance(available_keys, set):
        available_keys = set(available_keys)

    if isinstance(hook, ChainHook):
        # Sequential: every sub-hook is validated against the same set,
        # which accumulates the outputs of the sub-hooks before it.
        for sub_hook in hook.hooks:
            validate_hook(sub_hook, available_keys, depth + 1, model_counts)

    elif isinstance(hook, ParallelHook):
        # Each sub-hook is validated against its own copy of the incoming
        # keys, so one sub-hook's outputs are not visible to its siblings.
        # NOTE(review): the copies are discarded, so keys produced inside a
        # ParallelHook are not added to ``available_keys`` for hooks that
        # run afterwards — confirm this matches ParallelHook at runtime.
        for sub_hook in hook.hooks:
            curr_available_keys = copy.deepcopy(available_keys)
            validate_hook(sub_hook, curr_available_keys, depth + 1, model_counts)

    elif isinstance(hook, RepeatHook):
        # Validate the wrapped hook once per repetition, each time from a
        # fresh copy of the incoming keys; the same ``model_counts`` is
        # passed to every repetition.
        for _ in range(hook.n):
            curr_available_keys = copy.deepcopy(available_keys)
            validate_hook(hook.hook, curr_available_keys, depth + 1, model_counts)

    else:
        # Leaf hook: its ``in_keys`` and the source keys of its ``key_map``
        # are checked against what is currently available (presumably
        # check_keys_are_present raises when any are missing).
        check_keys_are_present(
            hook, hook.in_keys, list(available_keys), "in_keys", "available_keys"
        )
        check_keys_are_present(
            hook,
            list(hook.key_map.keys()),
            list(available_keys),
            "key_map",
            "available_keys",
        )
        # Hooks stored as attributes of this hook are validated against
        # the same (shared, mutable) set of keys.
        all_hooks = c_f.attrs_of_type(hook, BaseHook)
        for h in all_hooks.values():
            validate_hook(h, available_keys, depth + 1, model_counts)
        update_model_counts(hook, available_keys, model_counts)
        # This hook's outputs become available to whatever runs after it.
        available_keys.update(set(hook.out_keys))

    return model_counts