diff --git a/capirca/lib/nsxt.py b/capirca/lib/nsxt.py index dc55ddb4..70db1f23 100644 --- a/capirca/lib/nsxt.py +++ b/capirca/lib/nsxt.py @@ -18,12 +18,13 @@ import datetime import json -from typing import Literal, TypedDict, Optional, Union, Tuple +from typing import Optional, Union, Tuple, List from absl import logging from capirca.lib import aclgenerator from capirca.lib import nacaddr from capirca.lib import policy # for typing information +from typing_extensions import Literal, TypedDict # pylint: disable=g-multiple-import _ACTION_TABLE = { 'accept': 'ALLOW', @@ -90,9 +91,9 @@ class NsxtUnsupportedManyPoliciesError(Error): class ServiceEntries: """Represents service entries for a rule.""" - def __init__(self, protocol: int, source_ports: list[Tuple[str, str]], - destination_ports: list[Tuple[str, str]], - icmp_types: list[int]): + def __init__(self, protocol: int, source_ports: List[Tuple[str, str]], + destination_ports: List[Tuple[str, str]], + icmp_types: List[int]): """Setting things up. Args: @@ -239,8 +240,8 @@ def __str__(self): af_list = [self.af] # There can be many source and destination addresses. - source_address: list[nacaddr.IPType] = [] - destination_address: list[nacaddr.IPType] = [] + source_address: List[nacaddr.IPType] = [] + destination_address: List[nacaddr.IPType] = [] source_addr = [] destination_addr = [] @@ -257,19 +258,19 @@ def __str__(self): # cannot be a part of a netblock passed into NSX-T API. Currently only # addressing IPv4 as that's where the issue has been identified. # https://github.com/google/capirca/issues/348 - zero_ip_address: list[nacaddr.IPType] = [] + zero_ip_address: List[nacaddr.IPType] = [] if af == 4: - zero_ip_address: list[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')] + zero_ip_address: List[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')] # source address if self.term.source_address: - source_address: list[nacaddr.IPType] = self.term.GetAddressOfVersion( + source_address: List[nacaddr.IPType] = self.term.GetAddressOfVersion( 'source_address', af) - source_address_exclude: list[nacaddr.IPType] = ( + source_address_exclude: List[nacaddr.IPType] = ( self.term.GetAddressOfVersion('source_address_exclude', af)) if source_address_exclude: - source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs( + source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs( source_address, source_address_exclude + zero_ip_address) else: @@ -277,12 +278,12 @@ def __str__(self): '0.0.0.0/0' not in [str(a) for a in source_address]): # Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However, # do so only if we would not already have 'ANY' in the list. - source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs( + source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs( source_address, zero_ip_address) if source_address: if af == 4: - source_address: list[nacaddr.IPv4] - source_v4_addr: list[nacaddr.IPv4] = source_address + source_address: List[nacaddr.IPv4] + source_v4_addr: List[nacaddr.IPv4] = source_address if (source_v4_addr and '0.0.0.0/0' in [str(a) for a in source_address]): # Once we make the address list empty, it'll be set to ANY later @@ -292,22 +293,22 @@ def __str__(self): # later, we'll correctly not use ANY.) # # See https://github.com/google/capirca/issues/348 - source_v4_addr: list[nacaddr.IPv4] = [] + source_v4_addr: List[nacaddr.IPv4] = [] else: - source_address: list[nacaddr.IPv6] - source_v6_addr: list[nacaddr.IPv6] = source_address + source_address: List[nacaddr.IPv6] + source_v6_addr: List[nacaddr.IPv6] = source_address source_addr = source_v4_addr + source_v6_addr # destination address if self.term.destination_address: - destination_address: list[ + destination_address: List[ nacaddr.IPType] = self.term.GetAddressOfVersion( 'destination_address', af) - destination_address_exclude: list[nacaddr.IPType] = ( + destination_address_exclude: List[nacaddr.IPType] = ( self.term.GetAddressOfVersion('destination_address_exclude', af)) if destination_address_exclude: - destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs( + destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs( destination_address, destination_address_exclude + zero_ip_address) else: @@ -315,12 +316,12 @@ def __str__(self): '0.0.0.0/0' not in [str(a) for a in source_address]): # Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However, # do so only if we would not already have 'ANY' in the list. - destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs( + destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs( destination_address, zero_ip_address) if destination_address: if af == 4: - destination_address: list[nacaddr.IPv4] - dest_v4_addr: list[nacaddr.IPv4] = destination_address + destination_address: List[nacaddr.IPv4] + dest_v4_addr: List[nacaddr.IPv4] = destination_address if (dest_v4_addr and '0.0.0.0/0' in [str(a) for a in destination_address]): # Once we make the address list empty, it'll be set to ANY later @@ -330,10 +331,10 @@ def __str__(self): # later, we'll correctly not use ANY.) # # See https://github.com/google/capirca/issues/348 - dest_v4_addr: list[nacaddr.IPv4] = [] + dest_v4_addr: List[nacaddr.IPv4] = [] else: - destination_address: list[nacaddr.IPv6] - dest_v6_addr: list[nacaddr.IPv6] = destination_address + destination_address: List[nacaddr.IPv6] + dest_v6_addr: List[nacaddr.IPv6] = destination_address destination_addr = dest_v4_addr + dest_v6_addr # Check for mismatch IP for source and destination address for mixed filter @@ -420,13 +421,13 @@ class Nsxt(aclgenerator.ACLGenerator): _FILTER_OPTIONS_DICT = {} def _TranslatePolicy(self, pol: policy.Policy, exp_info: int): - self.nsxt_policies: list[Tuple[policy.Header, str, list[Term]]] = [] + self.nsxt_policies: List[Tuple[policy.Header, str, List[Term]]] = [] current_date = datetime.datetime.utcnow().date() # Warn about policies that will expire in less than exp_info weeks. exp_info_date = current_date + datetime.timedelta(weeks=exp_info) - filters: list[Tuple[policy.Header, list[policy.Term]]] = pol.filters + filters: List[Tuple[policy.Header, List[policy.Term]]] = pol.filters for header, terms in filters: if self._PLATFORM not in header.platforms: continue diff --git a/requirements.txt b/requirements.txt index fbf3f505..14b345de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ absl-py ply PyYAML six>=1.12.0 +typing_extensions diff --git a/tests/lib/nsxt_test.py b/tests/lib/nsxt_test.py index ca1b0e36..79277ae1 100644 --- a/tests/lib/nsxt_test.py +++ b/tests/lib/nsxt_test.py @@ -16,7 +16,7 @@ import copy import json -from typing import Any, Literal, Tuple, Union +from typing import Any, Tuple, Union, Dict, List from unittest import mock from absl.testing import absltest @@ -25,6 +25,7 @@ from capirca.lib import naming from capirca.lib import nsxt from capirca.lib import policy +from typing_extensions import Literal ICMPV6_TERM = """\ @@ -864,14 +865,14 @@ class TestTrafficKindGrid(parameterized.TestCase): # Which address set should be put into the policy, based on the type of policy # we're testing? - KIND_TO_ADDRESS: dict[_TRAFFIC_KIND, _ADDRESSES] = { + KIND_TO_ADDRESS: Dict[_TRAFFIC_KIND, _ADDRESSES] = { 'mixed': 'GOOGLE_DNS', 'v4': 'INTERNAL_V4', 'v6': 'INTERNAL_V6'} # Which expanded address group (e.g. netblocks) is expected, based on the type # of policy we're testing? - KIND_TO_ADDRESS_GROUPS: dict[ + KIND_TO_ADDRESS_GROUPS: Dict[ _TRAFFIC_KIND, Union[nacaddr.IPv4, nacaddr.IPv6, Literal['ANY']]] = { # 'GOOGLE_DNS' 'mixed': [nacaddr.IP('8.8.4.4/32'), nacaddr.IP('8.8.8.8/32'), @@ -961,11 +962,11 @@ def test_generator_works(self): ' destination-address:: INTERNAL_V6', '}'])) - def get_source_dest_addresses(self, nsxt_json: dict[str, Any]) -> ( - Tuple[list[str], list[str]]): - rules: list[dict[str, Any]] = nsxt_json['rules'] - src: list[str] = [] - dst: list[str] = [] + def get_source_dest_addresses(self, nsxt_json: Dict[str, Any]) -> ( + Tuple[List[str], List[str]]): + rules: List[Dict[str, Any]] = nsxt_json['rules'] + src: List[str] = [] + dst: List[str] = [] for rule in rules: src.extend(i for i in rule['source_groups'])