Skip to content

key_enforcer

KeyEnforcer

Makes sure containers have the specified keys.

Source code in pytorch_adapt\containers\key_enforcer.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class KeyEnforcer:
    """
    Validates that named containers hold the expected set of keys.

    The expectations are supplied once at construction time and are
    applied to a group of containers each time ```check``` is called.
    """

    def __init__(self, **kwargs: List[str]):
        """
        Arguments:
            **kwargs: Each keyword is a container name, and its value is
                the list of keys that container is required to have.
        """
        # Stored as-is: a dict of container name -> list of required keys.
        self.requirements = kwargs

    def check(self, containers: MultipleContainers):
        """
        Validates each container named in ```self.requirements```
        against its list of required keys.

        The comparison is done in both directions with
        ```c_f.list_diff```, so required keys that are absent and
        present keys that are not required are both reported.
        Checking stops at the first container that has a problem.

        Arguments:
            containers: The containers to check. It is indexed by
                container name, and each entry must provide ```keys()```.
        Raises:
            KeyError: If a container has missing or unallowed keys.
        """
        for name, required in self.requirements.items():
            # The container is looked up before any diffing happens, so a
            # failed lookup surfaces first.
            present = list(containers[name].keys())
            missing = c_f.list_diff(required, present)
            unallowed = c_f.list_diff(present, required)

            problems = []
            if len(missing) > 0:
                problems.append(
                    f"The {name} container is missing the following keys: {missing}. "
                )
            if len(unallowed) > 0:
                problems.append(
                    f"The {name} container has the following unallowed keys: {unallowed}."
                )

            if problems:
                raise KeyError("".join(problems))

__init__(**kwargs)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**kwargs` | `List[str]` | A mapping from container name to a list of required keys for that container. | `{}` |
Source code in pytorch_adapt\containers\key_enforcer.py
12
13
14
15
16
17
18
def __init__(self, **kwargs: List[str]):
    """
    Arguments:
        **kwargs: Each keyword is a container name, and its value is
            the list of keys that container is required to have.
    """
    # Stored as-is: a dict of container name -> list of required keys.
    self.requirements = kwargs

check(containers)

Compares the input containers' keys to self.requirements. Raises KeyError if there is a mismatch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `containers` | `MultipleContainers` | The containers to check. | required |
Source code in pytorch_adapt\containers\key_enforcer.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def check(self, containers: MultipleContainers):
    """
    Validates each container named in ```self.requirements```
    against its list of required keys.

    The comparison is done in both directions with
    ```c_f.list_diff```, so required keys that are absent and
    present keys that are not required are both reported.
    Checking stops at the first container that has a problem.

    Arguments:
        containers: The containers to check. It is indexed by
            container name, and each entry must provide ```keys()```.
    Raises:
        KeyError: If a container has missing or unallowed keys.
    """
    for name, required in self.requirements.items():
        # The container is looked up before any diffing happens, so a
        # failed lookup surfaces first.
        present = list(containers[name].keys())
        missing = c_f.list_diff(required, present)
        unallowed = c_f.list_diff(present, required)

        problems = []
        if len(missing) > 0:
            problems.append(
                f"The {name} container is missing the following keys: {missing}. "
            )
        if len(unallowed) > 0:
            problems.append(
                f"The {name} container has the following unallowed keys: {unallowed}."
            )

        if problems:
            raise KeyError("".join(problems))