Skip to content

Commit

Permalink
Merge pull request #353 from ivucica:issue345_nsxt_typing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597521837
  • Loading branch information
Capirca Team committed Jan 11, 2024
2 parents 102b025 + badaa97 commit d7fbfa9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 36 deletions.
57 changes: 29 additions & 28 deletions capirca/lib/nsxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

import datetime
import json
from typing import Literal, TypedDict, Optional, Union, Tuple
from typing import Optional, Union, Tuple, List

from absl import logging
from capirca.lib import aclgenerator
from capirca.lib import nacaddr
from capirca.lib import policy # for typing information
from typing_extensions import Literal, TypedDict # pylint: disable=g-multiple-import

_ACTION_TABLE = {
'accept': 'ALLOW',
Expand Down Expand Up @@ -90,9 +91,9 @@ class NsxtUnsupportedManyPoliciesError(Error):
class ServiceEntries:
"""Represents service entries for a rule."""

def __init__(self, protocol: int, source_ports: list[Tuple[str, str]],
destination_ports: list[Tuple[str, str]],
icmp_types: list[int]):
def __init__(self, protocol: int, source_ports: List[Tuple[str, str]],
destination_ports: List[Tuple[str, str]],
icmp_types: List[int]):
"""Setting things up.
Args:
Expand Down Expand Up @@ -239,8 +240,8 @@ def __str__(self):
af_list = [self.af]

# There can be many source and destination addresses.
source_address: list[nacaddr.IPType] = []
destination_address: list[nacaddr.IPType] = []
source_address: List[nacaddr.IPType] = []
destination_address: List[nacaddr.IPType] = []
source_addr = []
destination_addr = []

Expand All @@ -257,32 +258,32 @@ def __str__(self):
# cannot be a part of a netblock passed into NSX-T API. Currently only
# addressing IPv4 as that's where the issue has been identified.
# https://github.com/google/capirca/issues/348
zero_ip_address: list[nacaddr.IPType] = []
zero_ip_address: List[nacaddr.IPType] = []
if af == 4:
zero_ip_address: list[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')]
zero_ip_address: List[nacaddr.IPType] = [nacaddr.IP('0.0.0.0/32')]

# source address
if self.term.source_address:
source_address: list[nacaddr.IPType] = self.term.GetAddressOfVersion(
source_address: List[nacaddr.IPType] = self.term.GetAddressOfVersion(
'source_address', af)
source_address_exclude: list[nacaddr.IPType] = (
source_address_exclude: List[nacaddr.IPType] = (
self.term.GetAddressOfVersion('source_address_exclude', af))

if source_address_exclude:
source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address,
source_address_exclude + zero_ip_address)
else:
if (af == 4 and source_address and
'0.0.0.0/0' not in [str(a) for a in source_address]):
# Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However,
# do so only if we would not already have 'ANY' in the list.
source_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
source_address, zero_ip_address)
if source_address:
if af == 4:
source_address: list[nacaddr.IPv4]
source_v4_addr: list[nacaddr.IPv4] = source_address
source_address: List[nacaddr.IPv4]
source_v4_addr: List[nacaddr.IPv4] = source_address
if (source_v4_addr and
'0.0.0.0/0' in [str(a) for a in source_address]):
# Once we make the address list empty, it'll be set to ANY later
Expand All @@ -292,35 +293,35 @@ def __str__(self):
# later, we'll correctly not use ANY.)
#
# See https://github.com/google/capirca/issues/348
source_v4_addr: list[nacaddr.IPv4] = []
source_v4_addr: List[nacaddr.IPv4] = []
else:
source_address: list[nacaddr.IPv6]
source_v6_addr: list[nacaddr.IPv6] = source_address
source_address: List[nacaddr.IPv6]
source_v6_addr: List[nacaddr.IPv6] = source_address
source_addr = source_v4_addr + source_v6_addr

# destination address
if self.term.destination_address:
destination_address: list[
destination_address: List[
nacaddr.IPType] = self.term.GetAddressOfVersion(
'destination_address', af)
destination_address_exclude: list[nacaddr.IPType] = (
destination_address_exclude: List[nacaddr.IPType] = (
self.term.GetAddressOfVersion('destination_address_exclude', af))

if destination_address_exclude:
destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address,
destination_address_exclude + zero_ip_address)
else:
if (af == 4 and source_address and
'0.0.0.0/0' not in [str(a) for a in source_address]):
# Exclude 0.0.0.0/32, removing 0.0.0.0/anything netblocks. However,
# do so only if we would not already have 'ANY' in the list.
destination_address: list[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address: List[nacaddr.IPType] = nacaddr.ExcludeAddrs(
destination_address, zero_ip_address)
if destination_address:
if af == 4:
destination_address: list[nacaddr.IPv4]
dest_v4_addr: list[nacaddr.IPv4] = destination_address
destination_address: List[nacaddr.IPv4]
dest_v4_addr: List[nacaddr.IPv4] = destination_address
if (dest_v4_addr and
'0.0.0.0/0' in [str(a) for a in destination_address]):
# Once we make the address list empty, it'll be set to ANY later
Expand All @@ -330,10 +331,10 @@ def __str__(self):
# later, we'll correctly not use ANY.)
#
# See https://github.com/google/capirca/issues/348
dest_v4_addr: list[nacaddr.IPv4] = []
dest_v4_addr: List[nacaddr.IPv4] = []
else:
destination_address: list[nacaddr.IPv6]
dest_v6_addr: list[nacaddr.IPv6] = destination_address
destination_address: List[nacaddr.IPv6]
dest_v6_addr: List[nacaddr.IPv6] = destination_address
destination_addr = dest_v4_addr + dest_v6_addr

# Check for mismatch IP for source and destination address for mixed filter
Expand Down Expand Up @@ -420,13 +421,13 @@ class Nsxt(aclgenerator.ACLGenerator):
_FILTER_OPTIONS_DICT = {}

def _TranslatePolicy(self, pol: policy.Policy, exp_info: int):
self.nsxt_policies: list[Tuple[policy.Header, str, list[Term]]] = []
self.nsxt_policies: List[Tuple[policy.Header, str, List[Term]]] = []
current_date = datetime.datetime.utcnow().date()

# Warn about policies that will expire in less than exp_info weeks.
exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

filters: list[Tuple[policy.Header, list[policy.Term]]] = pol.filters
filters: List[Tuple[policy.Header, List[policy.Term]]] = pol.filters
for header, terms in filters:
if self._PLATFORM not in header.platforms:
continue
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ absl-py
ply
PyYAML
six>=1.12.0
typing_extensions
17 changes: 9 additions & 8 deletions tests/lib/nsxt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import copy
import json
from typing import Any, Literal, Tuple, Union
from typing import Any, Tuple, Union, Dict, List
from unittest import mock

from absl.testing import absltest
Expand All @@ -25,6 +25,7 @@
from capirca.lib import naming
from capirca.lib import nsxt
from capirca.lib import policy
from typing_extensions import Literal


ICMPV6_TERM = """\
Expand Down Expand Up @@ -864,14 +865,14 @@ class TestTrafficKindGrid(parameterized.TestCase):

# Which address set should be put into the policy, based on the type of policy
# we're testing?
KIND_TO_ADDRESS: dict[_TRAFFIC_KIND, _ADDRESSES] = {
KIND_TO_ADDRESS: Dict[_TRAFFIC_KIND, _ADDRESSES] = {
'mixed': 'GOOGLE_DNS',
'v4': 'INTERNAL_V4',
'v6': 'INTERNAL_V6'}

# Which expanded address group (e.g. netblocks) is expected, based on the type
# of policy we're testing?
KIND_TO_ADDRESS_GROUPS: dict[
KIND_TO_ADDRESS_GROUPS: Dict[
_TRAFFIC_KIND, Union[nacaddr.IPv4, nacaddr.IPv6, Literal['ANY']]] = {
# 'GOOGLE_DNS'
'mixed': [nacaddr.IP('8.8.4.4/32'), nacaddr.IP('8.8.8.8/32'),
Expand Down Expand Up @@ -961,11 +962,11 @@ def test_generator_works(self):
' destination-address:: INTERNAL_V6',
'}']))

def get_source_dest_addresses(self, nsxt_json: dict[str, Any]) -> (
Tuple[list[str], list[str]]):
rules: list[dict[str, Any]] = nsxt_json['rules']
src: list[str] = []
dst: list[str] = []
def get_source_dest_addresses(self, nsxt_json: Dict[str, Any]) -> (
Tuple[List[str], List[str]]):
rules: List[Dict[str, Any]] = nsxt_json['rules']
src: List[str] = []
dst: List[str] = []

for rule in rules:
src.extend(i for i in rule['source_groups'])
Expand Down

0 comments on commit d7fbfa9

Please sign in to comment.