From d90b0302edb4e14a7aa0e58a62a79bab681107f7 Mon Sep 17 00:00:00 2001
From: wert310 <310wert@gmail.com>
Date: Mon, 23 Nov 2020 11:52:32 +0100
Subject: [PATCH] Fix Compiler.

---
 frontend/src/App.vue                |  25 +++-
 fwsynthesizer/compile/__init__.py   |   1 +
 fwsynthesizer/synthesis/__init__.py |  31 ++++-
 fwsynthesizer/synthesis/query.py    |  11 +-
 fwsynthesizer/web/__init__.py       | 207 +++++++++++++++++++++++++++-
 5 files changed, 262 insertions(+), 13 deletions(-)

diff --git a/frontend/src/App.vue b/frontend/src/App.vue
index 9462c8c..425a08e 100644
--- a/frontend/src/App.vue
+++ b/frontend/src/App.vue
@@ -31,7 +31,7 @@
               <b-icon pack="fa" :icon="active ? 'angle-up' : 'angle-down'"></b-icon>
             </button>
 
-            <b-dropdown-item v-for="target in frontends" v-bind:key="target" aria-role="listitem" @click="compilerCompile(target)">{{ target }}</b-dropdown-item>
+            <b-dropdown-item v-for="target in targets" v-bind:key="target" aria-role="listitem" @click="compilerCompile(target)">{{ target }}</b-dropdown-item>
         </b-dropdown>
 
 
@@ -81,7 +81,7 @@
              <h1 v-if="isWorking && query_progress > 0" class="is-size-6 has-text-centered has-text-weight-bold is-family-monospace">Synthesizing policy...</h1>
              <b-progress class="mt-3 ml-3 mr-3 mb-5" :value="query_progress" show-value format="percent" v-if="isWorking && query_progress > 0"></b-progress>
 
-             <div v-for="mode in Object.keys(fwspolicy)" v-bind:key="mode">
+             <div v-for="mode in Object.keys(fwspolicy).filter(n => n != 'locals')" v-bind:key="mode">
                <h1 class="is-size-5 has-text-weight-bold is-family-monospace">{{ mode.toUpperCase() }}</h1>
 
                 <table class="fws-table singleline" v-if="mode == 'aliases'">
@@ -379,16 +379,18 @@ export default {
             if (this.getCurrentMode() != 'compiler') return
             this.fwspolicy = {}
             var query_code_backup = this.query_code
-            this.query_code = `table_style json\naliases(${policy})\nsynthesis(${policy})\n`
+            this.query_code = `table_style json\nlocals(${policy})\naliases(${policy})\nsynthesis(${policy})\n`
             this.queryRun().then(() => {
                 const sregex = /FORWARD\n\n(\{.*\})\n?(\{.*\}?)\n\nINPUT\n\n(\{.*\})(\n\{.*\})?\n\nOUTPUT\n\n(\{.*\})(\n\{.*\})?\n\nLOOPBACK\n\n(\{.*\})(\n\{.*\})?/
                 const aregex = /([a-zA-Z0-9_-]+): ([0-9./]+)/g
+                const lregex = /local ([0-9./]+)/g
                 console.log(this.query_output)
                 const match = this.query_output.match(sregex)
                 if (!match)
                     this.showError(this.query_output.replaceAll("<", "&lt;"))
                 else {
                     this.fwspolicy = {
+                        'locals':   [...this.query_output.matchAll(lregex)].map(x => x[1]),
                         'aliases':  [...this.query_output.matchAll(aregex)].map(x => [x[1], x[2]]),
                         'forward':  Array.prototype.concat([JSON.parse(match[1])], (match[2] ? [JSON.parse(match[2])] : [])),
                         'input':    Array.prototype.concat([JSON.parse(match[3])], (match[4] ? [JSON.parse(match[4])] : [])),
@@ -440,8 +442,15 @@ export default {
                 }).then(b => b.json())
                 .then(res => {
                     console.log(res)
-                    this.isWorking = false
-                    // TODO
+                    const blob = new Blob([res.value], {type: 'text/plain'})
+                    const e = document.createEvent('MouseEvents'),
+                          a = document.createElement('a');
+                    a.download = `${target}_policy_${new Date().toJSON()}.rules`;
+                    a.href = window.URL.createObjectURL(blob);
+                    a.dataset.downloadurl = ['text/plain', a.download, a.href].join(':');
+                    e.initEvent('click', true, false, window, 0, 0, 0, 0, 0, false, false, false, false, 0, null);
+                    a.dispatchEvent(e);
+                    this.isWorking = false;
                 }).catch(this.showError)
 
         },
@@ -485,6 +494,11 @@ export default {
                     console.log(res)
                     this.frontends = res
                 }).catch(this.showError)
+            fetch(`${FWS_URI}/compiler/targets`).then(b => b.json())
+                .then(res => {
+                    console.log(res)
+                    this.targets = res
+                }).catch(this.showError)
         },
 
     },
@@ -513,6 +527,7 @@ export default {
             loaded_policies: [],
             isWorking: true,
             frontends: [],
+            targets: [],
             fws_instance: null,
             query_code: "",
             query_output: "",
diff --git a/fwsynthesizer/compile/__init__.py b/fwsynthesizer/compile/__init__.py
index f4761b5..a2901aa 100644
--- a/fwsynthesizer/compile/__init__.py
+++ b/fwsynthesizer/compile/__init__.py
@@ -4,6 +4,7 @@
 from compile_ipfw import *
 from compile_pf import *
 
+TARGETS = ['iptables', 'ipfw', 'pf']
 
 def fw_compile(semantics, target):
     """
diff --git a/fwsynthesizer/synthesis/__init__.py b/fwsynthesizer/synthesis/__init__.py
index af00555..f5cf596 100644
--- a/fwsynthesizer/synthesis/__init__.py
+++ b/fwsynthesizer/synthesis/__init__.py
@@ -149,6 +149,19 @@ def any_protocol(self):
     def any_mark(self):
         return self.mark[0][0] == Any_tag[0][0] and self.mark[0][1] == Any_tag[0][1]
 
+    def to_mrule_packet(self):
+        return [ #  srcIp, srcPort, dstIp, dstPort, srcMac, dstMac, protocol, state, mark
+            [ [struct.unpack(">I", ip.packed)[0] for ip in ips] for ips in self.srcIp ],
+            self.srcPort,
+            [ [struct.unpack(">I", ip.packed)[0] for ip in ips] for ips in self.dstIp ],
+            self.dstPort,
+            [ [mac._mac for mac in macs] for macs in self.srcMac ],
+            [ [mac._mac for mac in macs] for macs in self.dstMac ],
+            self.protocol,
+            self.state,
+            self.mark
+        ]
+
 
 class Rule(object):
     "FWS Rule Object"
@@ -178,17 +191,24 @@ def __init__(self, packet_in, packet_out):
     def __repr__(self):
         return "#<Rule {} {} {}>".format(self.type, self.packet_in, self.packet_out)
 
+    def to_mrule(self):
+        return [self.packet_in.to_mrule_packet(), self.packet_out.to_mrule_packet()]
+
 
 class SynthesisOutput:
     "Firewall synthesis output"
 
-    def __init__(self, fw, rules):
+    def __init__(self, fw, rules, mrules_precomputed=False):
         self.firewall = fw
         self.__rules = rules
+        self.mrules_precomputed = mrules_precomputed
 
     def get_rules(self):
         "Get the rules as lists of Rule objects"
-        rules = [ Synthesis.mrule_list(r) for r in self.__rules ]
+        if self.mrules_precomputed:
+            rules = self.__rules
+        else:
+            rules = [ Synthesis.mrule_list(r) for r in self.__rules ]
         return [ Rule(Packet(*pin), Packet(*pout)) for pin,pout in rules ]
 
     def print_table(self, table_style=TableStyle.UNICODE, local_src=LocalFlag.BOTH,
@@ -209,12 +229,15 @@ def print_table(self, table_style=TableStyle.UNICODE, local_src=LocalFlag.BOTH,
         hide_nats = nat == NatFlag.FILTER
         hide_filters = nat == NatFlag.NAT
         table_printer.print_table(
-            rules, table_style, [ipaddr.IPv4Address(a) for a in self.firewall.locals],
+            rules, table_style, [ipaddr.IPv4Address(a) for a in self.firewall.locals] if self.firewall else [],
             hide_src, hide_dst, hide_nats, hide_filters,
             projection, aliases=aliases)
 
     def get_rules_no_duplicates(self):
-        rules = [Synthesis.mrule_list(r) for r in self.__rules]
+        if self.mrules_precomputed:
+            rules = self.__rules
+        else:
+            rules = [Synthesis.mrule_list(r) for r in self.__rules]
 
         for rule in rules:
             for pkt in rule:
diff --git a/fwsynthesizer/synthesis/query.py b/fwsynthesizer/synthesis/query.py
index c98ddef..f2d8c2b 100644
--- a/fwsynthesizer/synthesis/query.py
+++ b/fwsynthesizer/synthesis/query.py
@@ -296,6 +296,14 @@ def eval(self, fws):
             print "{}: {}".format(a, aliases[a])
         print
 
+class Locals(FWSCmd, namedtuple('Locals', ['p'])):
+    def eval(self, fws):
+        policy = fws.get_variable(self.p)
+        locals_ = policy.firewall.locals
+        for ip in locals_:
+            print "local {}".format(ip)
+        print
+        
 class Porting(FWSCmd, namedtuple('Porting', ['p', 'target', 'file'])):
     def eval(self, fws):
         policy = fws.get_variable(self.p)
@@ -369,6 +377,7 @@ def eval(self, fws):
 echo = (sym('echo') >> litstr).parsecmap(Echo)
 ifcl = sym('ifcl') >> parens(identifier).parsecmap(lambda p: Ifcl(p))
 aliases = sym('aliases') >> parens(identifier).parsecmap(lambda p: Aliases(p))
+locals_ = sym('locals') >> parens(identifier).parsecmap(lambda p: Locals(p))
 setting = (sym('help').parsecmap(lambda _: Echo(help_message)) ^
            sym('show_time').parsecmap(lambda _: ShowTime()) ^
            sym('verbose_mode').parsecmap(lambda _: VerboseMode()) ^
@@ -423,7 +432,7 @@ def load_policy():
 
 @generate
 def fws_command():
-    cmd = yield ( echo ^ setting ^ aliases ^ porting ^ comparison ^ synthesis ^
+    cmd = yield ( echo ^ setting ^ aliases ^ locals_ ^ porting ^ comparison ^ synthesis ^
                   related ^ ifcl ^ load_policy ^ comment.parsecmap(lambda _: Nop()) ^
                   identifier.parsecmap(lambda n: ShowIdentifier(n)) )
     preturn ( cmd )
diff --git a/fwsynthesizer/web/__init__.py b/fwsynthesizer/web/__init__.py
index f348388..bdd5e10 100644
--- a/fwsynthesizer/web/__init__.py
+++ b/fwsynthesizer/web/__init__.py
@@ -2,6 +2,7 @@
 
 import fwsynthesizer
 from fwsynthesizer.frontends import FRONTENDS
+from fwsynthesizer.compile import TARGETS
 
 from parsec import *
 from fwsynthesizer.parsers.utils import *
@@ -17,6 +18,7 @@
 import time
 import threading
 import logging
+import json
 
 libc = ctypes.CDLL(None)
 c_stdout = ctypes.c_void_p.in_dll(libc, 'stdout')
@@ -75,6 +77,10 @@ def after_request(response):
 def list_frontends():
     return jsonify(FRONTENDS)
 
+@app.route('/compiler/targets')
+def list_targets():
+    return jsonify(TARGETS)
+
 @app.route('/new_repl')
 def new_interpreter():
     global active_interpreters
@@ -143,11 +149,14 @@ def generate():
 def translate_tables():
     args = request.json
     target = args['target']
-    fwspolicy = args['fwspolicy']
+    fwspolicy = json.loads(args['fwspolicy'])
+    
 
-    # TODO
+    mrules = policy_to_mrules(fwspolicy)
+    semantics = fwsynthesizer.SynthesisOutput(None, mrules, mrules_precomputed=True)
+    configuration = fw_compile(semantics, target)
     
-    return jsonify({'value': None})
+    return jsonify({'value': configuration})
 
 @app.route('/')
 def index():
@@ -167,3 +176,195 @@ def start_app(host="localhost", port="5095"):
         threaded=True,
         processes=1,
     )
+
+################################################################################
+## Compiler
+
+def policy_to_mrules(fwspolicy):
+    aliases = fwspolicy['aliases']
+    local_addresses = [ [x,x] for x in
+                        [ struct.unpack(">I", ipaddr.IPv4Address(ip).packed)[0] for ip in fwspolicy['locals'] ]]
+
+    def remove_docs(intervals):
+        if isinstance(intervals, DOC):
+            new_ints = intervals.to_cubes()
+        else:
+            new_ints = []
+            for i in intervals:
+                if isinstance(i, DOC):
+                    new_ints.extend(i.to_cubes())
+                else:
+                    new_ints.append(i)
+        return new_ints
+    
+    def remove_locals_if(mode, modes, intervals):
+        if mode not in modes:
+            return intervals
+
+        if isinstance(intervals, DOC):
+            intervals.diffs.extend(local_addresses)
+            return intervals
+        new_ints = []
+        for i in intervals:
+            new_ints.append(DOC(i, local_addresses))
+        return new_ints
+
+
+    def expand_and_parse(tables, mode):
+        mrules = []
+        
+        for rules in tables:
+            is_snat = not all("SNAT" not in f for f in rules['field_names'])
+            is_dnat = not all("DNAT" not in f for f in rules['field_names'])
+    
+            for r in rules['table']:
+                for rep in aliases:
+                    for field in r:
+                        r[field] = r[field].replace(*rep)
+
+                pin = map(remove_docs, [
+                    # srcIp, srcPort, dstIp, dstPort, srcMac, dstMac, protocol, state[, mark]   
+                    remove_locals_if(
+                        mode, ['forward', 'input'],
+                        interval_parser(ip_range_parser(), [0, 2**32-1]).parse_strict(r['srcIp'].encode())),
+                    interval_parser(port_parser, [0, 2**16-1]).parse_strict(r['srcPort'].encode()),
+                    remove_locals_if(
+                        mode, ['forward', 'output'] if not ("dstIp'" in r and r["dstIp'"].strip() != '-') else [],
+                        interval_parser(ip_range_parser(), [0, 2**32-1]).parse_strict(r['dstIp'].encode())),
+                    interval_parser(port_parser, [0, 2**16-1]).parse_strict(r['dstPort'].encode()),
+                    interval_parser(mac_parser, [0, 2**48-1]).parse_strict(r['srcMac'].encode()),
+                    interval_parser(mac_parser, [0, 2**48-1]).parse_strict(r['dstMac'].encode()),
+                    interval_parser(protocol_parser, [0, 255]).parse_strict(r['protocol'].encode()),
+                    interval_parser(state_parser, [0,1]).parse_strict(r['state'].encode()),
+                ])
+                pout = [[], [], [], [], [], [], [], []]
+                
+                if is_snat or is_dnat:
+                    # nat
+                    pout = map(remove_docs, [
+                        interval_parser(ip_range_parser(), [0, 2**32-1]).parse_strict(r["srcIp'"].encode())
+                        if "srcIp'" in r and r["srcIp'"].strip() != "-" else [],
+                        interval_parser(port_parser, [0, 2**16-1]).parse_strict(r["srcPort'"].encode())
+                        if "srcPort'" in r and r["srcPort'"].strip() != "-" else [],
+                        remove_locals_if(mode, ['forward', 'output'], interval_parser(ip_range_parser(), [0, 2**32-1]).parse_strict(r["dstIp'"].encode()))
+                        if "dstIp'" in r and r["dstIp'"].strip() != "-" else [],
+                        interval_parser(port_parser, [0, 2**16-1]).parse_strict(r["dstPort'"].encode())
+                        if "dstPort'" in r and r["dstPort'"].strip() != "-" else [],
+                        [], [], [], []])
+
+                mrules.append([pin, pout])
+        return mrules
+
+
+    rules = []
+    rules.extend( expand_and_parse(fwspolicy['forward'], 'forward') )
+    rules.extend( expand_and_parse(fwspolicy['input'], 'input') )
+    rules.extend( expand_and_parse(fwspolicy['output'], 'output') )
+    rules.extend( expand_and_parse(fwspolicy['loopback'], 'loopback') )
+    
+    return rules
+
+
+def ip_range_parser():
+
+    def to_interval(b):
+        if isinstance(b, ipaddr.IPv4Network):
+            return b.ip, ipaddr.IPv4Address(b._ip | (0xffffffff >> b.prefixlen))
+        if isinstance(b, ipaddr_ext.IPv4Range):
+            return b.ip_from, b.ip_to
+        if isinstance(b, ipaddr.IPv4Address):
+            return b, b
+        raise NotImplemented
+
+    def make_int_interval(elm):
+        a, b = to_interval(elm)
+        return [ struct.unpack(">I", a.packed)[0], struct.unpack(">I", b.packed)[0] ]
+    
+    return (ip_range ^ ip_subnet ^ ip_addr).parsecmap(make_int_interval)
+
+protos = {name: int(proto) for name, proto in utils.protocols().items()}
+states = {'NEW':0, 'ESTABLISHED': 1}
+
+def hr_parser(mappings):
+    
+    def name_to_interval(e):
+        v = mappings[e]
+        return [v,v]
+
+    return regex('[a-zA-Z]+').parsecmap(name_to_interval)
+
+protocol_parser = hr_parser(protos)
+state_parser = hr_parser(states)
+port_parser = (port_range ^ port).parsecmap(
+    lambda p:
+    [int(p.value), int(p.value)] if isinstance(p, Port) else \
+    [int(p.bottom), int(p.top)] if isinstance(p, PortRange) else [])
+
+mac_parser = mac_addr.parsecmap(lambda mac: [mac._mac, mac._mac])
+
+class DOC(object):
+    def __init__(self, rng, diffs):
+        self.range = rng
+        self.diffs = diffs
+
+    def __repr__(self):
+        return 'DOC(range={}, diffs={})'.format(self.range, self.diffs)
+    
+    def to_cubes(self):
+        min_, max_ = self.range
+        sgaps = sorted(filter(lambda x: min(*x) >= min_ and max(*x) <= max_, self.diffs),
+                      key=lambda x: x[0])
+        # gaps need to be mutually exclusive
+
+        gaps = []
+        for gap in sgaps:
+            (gb, gt) = gap
+
+            if gaps == []:
+                gaps.append(gap)
+            else:
+                (nb, nt) = gaps[-1]
+                if gb >= nb  and gb <= nt and gt <= nt:
+                    # included, do nothing
+                    pass
+                elif gb >= nb and gb <= nt and gt >= nt:
+                    # overlapping, replace nt
+                    gaps[-1] = [nb, gt]
+                else:
+                    gaps.append(gap)
+            
+        ints = [[min_, max_]]
+        for (bottom, top) in gaps:
+            last_min, last_max = ints[-1]
+
+            if bottom == last_min:
+                assert top+1 < last_max
+                ints[-1] = [top+1, last_max]
+            elif top == last_max:
+                assert bottom-1 > last_min
+                ints[-1] = [last_min, bottom-1]
+            else:
+                ints[-1] = [last_min, bottom-1]
+                ints.append([top+1, last_max])
+
+        for (bottom, top) in gaps:
+            assert (bottom <= top)
+        for (bottom, top) in ints:
+            assert (bottom <= top)
+        return ints
+                
+
+def interval_parser(elem_parser, elem_all): # -> Parser[ Union[DOC, List[Intervals]] ]
+    token_endl = lambda p: many(space_endls) >> p << many(space_endls)
+    
+    star = string('*').parsecmap(lambda _: [elem_all])
+    multiple = sepBy(elem_parser, space_endls)
+    negate = (
+        (token_endl(regex('\*\s+\\\\').parsecmap(lambda _: elem_all)
+                    ^ (token_endl(elem_parser) << token_endl(string("\\")))) +
+         between(token_endl(string("{")), token_endl(string("}")), multiple))).parsecmap(
+             lambda (rng, dffs): DOC(rng, dffs))
+
+    return (token_endl(negate ^ star ^ multiple) )
+
+