diff --git a/senza/components/elastigroup.py b/senza/components/elastigroup.py index 59e015a9..38e87ffe 100644 --- a/senza/components/elastigroup.py +++ b/senza/components/elastigroup.py @@ -51,7 +51,7 @@ def component_elastigroup(definition, configuration, args, info, force, account_ ensure_default_product(elastigroup_config) ensure_instance_monitoring(elastigroup_config) - extract_subnets(definition, elastigroup_config, account_info) + extract_subnets(configuration, elastigroup_config, account_info) extract_user_data(configuration, elastigroup_config, info, force, account_info) extract_load_balancer_name(configuration, elastigroup_config) extract_public_ips(configuration, elastigroup_config) @@ -264,7 +264,7 @@ def fill_standard_tags(definition, elastigroup_config): elastigroup_config["name"] = full_name -def extract_subnets(definition, elastigroup_config, account_info): +def extract_subnets(configuration, elastigroup_config, account_info): """ This fills in the subnetIds and region attributes of the Spotinst elastigroup, in case they're not defined already The subnetIds are discovered by Senza::StupsAutoConfiguration and the region is provided by the AccountInfo object @@ -273,9 +273,9 @@ def extract_subnets(definition, elastigroup_config, account_info): subnet_ids = elastigroup_config["compute"].get("subnetIds", []) target_region = elastigroup_config.get("region", account_info.Region) if not subnet_ids: - subnet_ids = definition["Mappings"]["ServerSubnets"].get(target_region, {}).get("Subnets", []) + subnet_set = "LoadBalancerSubnets" if configuration.get("AssociatePublicIpAddress", False) else "ServerSubnets" + elastigroup_config["compute"]["subnetIds"] = {"Fn::FindInMap": [subnet_set, {"Ref": "AWS::Region"}, "Subnets"]} elastigroup_config["region"] = target_region - elastigroup_config["compute"]["subnetIds"] = subnet_ids def extract_user_data(configuration, elastigroup_config, info: dict, force, account_info): diff --git a/tests/test_elastigroup.py b/tests/test_elastigroup.py index 41f98f58..30344be7 100644 --- a/tests/test_elastigroup.py +++ b/tests/test_elastigroup.py @@ -62,7 +62,7 @@ def test_component_elastigroup_defaults(monkeypatch): assert {'tagKey': 'StackName', 'tagValue': 'foobar'} in tags assert {'tagKey': 'StackVersion', 'tagValue': '0.1'} in tags assert properties["group"]["compute"]["product"] == ELASTIGROUP_DEFAULT_PRODUCT - assert properties["group"]["compute"]["subnetIds"] == subnets + assert properties["group"]["compute"]["subnetIds"] == {"Fn::FindInMap": ["ServerSubnets", {"Ref": "AWS::Region"}, "Subnets"]} assert properties["group"]["region"] == "reg1" assert properties["group"]["strategy"] == ELASTIGROUP_DEFAULT_STRATEGY @@ -381,21 +381,22 @@ def test_standard_tags(): def test_extract_subnets(): test_cases = [ - { # use auto discovered subnets from auto-discovered region - "definition": {"Mappings": {"ServerSubnets": {"reg1": {"Subnets": ["sn1", "sn2"]}}}}, + { # use auto discovered server subnets + "component_config": {}, "given_config": {}, - "expected_config": {"compute": {"subnetIds": ["sn1", "sn2"]}, "region": "reg1"}, + "expected_config": { + "compute": {"subnetIds": {"Fn::FindInMap": ["ServerSubnets", {"Ref": "AWS::Region"}, "Subnets"]}}, + "region": "reg1"}, }, - { # use auto discovered subnets from specified region - "definition": {"Mappings": {"ServerSubnets": { - "reg1": {"Subnets": ["sn1", "sn2"]}, - "reg2": {"Subnets": ["n1", "n2"]} - }}}, - "given_config": {"region": "reg2"}, - "expected_config": {"compute": {"subnetIds": ["n1", "n2"]}, "region": "reg2"}, + { # use auto discovered DMZ subnets + "component_config": {"AssociatePublicIpAddress": True}, + "given_config": {}, + "expected_config": { + "compute": {"subnetIds": {"Fn::FindInMap": ["LoadBalancerSubnets", {"Ref": "AWS::Region"}, "Subnets"]}}, + "region": "reg1"}, }, { # leave subnetIds untouched - "definition": {"Mappings": {"ServerSubnets": {"reg1": {"Subnets": ["sn1", "sn2"]}}}}, + "component_config": {}, "given_config": {"compute": {"subnetIds": ["subnet01"]}}, "expected_config": {"compute": {"subnetIds": ["subnet01"]}, "region": "reg1"}, }, @@ -403,9 +404,10 @@ def test_extract_subnets(): account_info = MagicMock() account_info.Region = "reg1" for test_case in test_cases: - got = test_case["given_config"] - extract_subnets(test_case["definition"], got, account_info) - assert test_case["expected_config"] == got + input = test_case["given_config"] + config = test_case["component_config"] + extract_subnets(config, input, account_info) + assert test_case["expected_config"] == input def test_load_balancers():