Add /address-pair for address pair outside of owner network. (#3)

Add /address-pair method to allow cross-owner network address pair binding.

Reviewed-by: Rico Lin <ricolin@ricolky.com>
diff --git a/neutron_policy_server/tests/test_neutron_address_pair.py b/neutron_policy_server/tests/test_neutron_address_pair.py
index 99f49c0..f9d2e08 100644
--- a/neutron_policy_server/tests/test_neutron_address_pair.py
+++ b/neutron_policy_server/tests/test_neutron_address_pair.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from copy import deepcopy
+
 import pytest
 from neutron.tests.unit.db import test_allowedaddresspairs_db as base_test
 from neutron_lib.api.definitions import allowedaddresspairs as addr_apidef
@@ -106,6 +108,47 @@
         self.update_port["target"]["attributes_to_update"] = ["mac_address"]
         self.update_port["target"]["mac_address"] = "52:54:00:41:a4:97"
 
+        self.allowed_address_pairs = {
+            "rule": "allowed_address_pairs",
+            "target": self.port_dep["port"].copy(),
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port_dep["port"]["project_id"],
+            },
+        }
+        self.allowed_address_pairs["target"]["attributes_to_update"] = [
+            "allowed_address_pairs"
+        ]
+        self.allowed_address_pairs["target"]["allowed_address_pairs"] = [
+            {"mac_address": "00:00:00:00:00:01", "ip_address": "10.0.0.1"}
+        ]
+        self.allowed_address_pairs_not_found = deepcopy(self.allowed_address_pairs)
+        self.allowed_address_pairs_not_found["target"]["allowed_address_pairs"] = [
+            {"ip_address": "10.96.250.203", "mac_address": "fa:16:3e:da:ed:0b"}
+        ]
+        self.allowed_address_pairs_address_not_found = deepcopy(
+            self.allowed_address_pairs
+        )
+        self.allowed_address_pairs_address_not_found["target"]["allowed_address_pairs"][
+            0
+        ]["ip_address"] = "10.96.250.203"
+        self.allowed_address_pairs_no_attribute = deepcopy(self.allowed_address_pairs)
+        self.allowed_address_pairs_no_attribute["target"]["attributes_to_update"] = []
+        self.allowed_address_pairs_not_in_attribute = deepcopy(
+            self.allowed_address_pairs
+        )
+        self.allowed_address_pairs_not_in_attribute["target"][
+            "attributes_to_update"
+        ] = ["mac_address"]
+
+        self.allowed_address_pairs_empty = deepcopy(self.allowed_address_pairs)
+        self.allowed_address_pairs_empty["target"]["allowed_address_pairs"] = []
+
+        self.allowed_address_pairs_target_not_found = deepcopy(
+            self.allowed_address_pairs
+        )
+        self.allowed_address_pairs_target_not_found["target"]["id"] = "foo"
+
     @pytest.fixture()
     def app(
         self,
@@ -126,14 +169,14 @@
 class TestAddressPairCasesFlask(TestAddressPairCasesFlaskBase):
 
     def test_port_delete_success(self):
-        response = self.client.post(
+        response = self.client.post(  # pylint: disable=E1101
             "/port-delete", json=self.delete_port_dep_json
         )  # pylint: disable=E1101
         self.assertEqual(b"True", response.data)
         self.assertEqual(200, response.status_code)
 
     def test_port_delete_fail_with_dep(self):
-        response = self.client.post(
+        response = self.client.post(  # pylint: disable=E1101
             "/port-delete", json=self.delete_port_json
         )  # pylint: disable=E1101
         self.assertEqual(
@@ -149,30 +192,30 @@
         self.assertEqual(403, response.status_code)
 
     def test_port_update_success(self):
-        response = self.client.post(
+        response = self.client.post(  # pylint: disable=E1101
             "/port-update", json=self.update_port_dep
         )  # pylint: disable=E1101
         self.assertEqual(b"True", response.data)
         self.assertEqual(200, response.status_code)
 
     def test_port_update_success_no_address_change(self):
-        response = self.client.post(
+        response = self.client.post(  # pylint: disable=E1101
             "/port-update", json=self.update_port_no_address
         )  # pylint: disable=E1101
         self.assertEqual(b"True", response.data)
         self.assertEqual(200, response.status_code)
 
     def test_port_update_fail_no_match_port(self):
-        response = self.client.post(
+        response = self.client.post(  # pylint: disable=E1101
             "/port-update", json=self.update_port_not_exist
         )  # pylint: disable=E1101
         self.assertEqual(b"True", response.data)
         self.assertEqual(200, response.status_code)
 
     def test_port_update_fail(self):
-        response = self.client.post(
+        response = self.client.post(  # pylint: disable=E1101
             "/port-update", json=self.update_port
-        )  # pylint: disable=E1101
+        )
         self.assertEqual(
             bytes(
                 (
@@ -188,3 +231,78 @@
     def test_health_check_success(self):
         response = self.client.get("/health")  # pylint: disable=E1101
         self.assertEqual(200, response.status_code)
+
+    def test_address_pair_success(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs
+        )
+        self.assertEqual(
+            b"True",
+            response.data,
+        )
+        self.assertEqual(200, response.status_code)
+
+    def test_address_pair_success_no_attributes_to_update(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs_no_attribute
+        )
+        self.assertEqual(
+            b"True",
+            response.data,
+        )
+        self.assertEqual(200, response.status_code)
+
+    def test_address_pair_success_empty(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs_empty
+        )
+        self.assertEqual(
+            b"True",
+            response.data,
+        )
+        self.assertEqual(200, response.status_code)
+
+    def test_address_pair_success_not_in_attributes(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs_not_in_attribute
+        )
+        self.assertEqual(
+            b"True",
+            response.data,
+        )
+        self.assertEqual(200, response.status_code)
+
+    def test_address_pair_fail_target_not_found(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs_target_not_found
+        )
+        portname = self.allowed_address_pairs_target_not_found["target"]["id"]
+        self.assertEqual(
+            f"Can't fetch port {portname} with current context, skip this check.".encode(
+                "utf-8"
+            ),
+            response.data,
+        )
+        self.assertEqual(403, response.status_code)
+
+    def test_address_pair_fail_mac_not_found(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs_not_found
+        )
+        self.assertEqual(
+            b"Zero or Multiple match port found with MAC address fa:16:3e:da:ed:0b.",
+            response.data,
+        )
+        self.assertEqual(403, response.status_code)
+
+    def test_address_pair_fail_address_not_found(self):
+        response = self.client.post(  # pylint: disable=E1101
+            "/address-pair", json=self.allowed_address_pairs_address_not_found
+        )
+        self.assertEqual(
+            f"IP address not exists in network from project {self.port['port']['project_id']}.".encode(
+                "utf-8"
+            ),
+            response.data,
+        )
+        self.assertEqual(403, response.status_code)
diff --git a/neutron_policy_server/wsgi.py b/neutron_policy_server/wsgi.py
index 0e2561c..dc27baa 100644
--- a/neutron_policy_server/wsgi.py
+++ b/neutron_policy_server/wsgi.py
@@ -6,7 +6,6 @@
 from flask import Flask, Response, g, request
 from neutron.common import config
 from neutron.db.models import allowed_address_pair as models
-from neutron.objects import network as network_obj
 from neutron.objects import ports as port_obj
 from neutron.objects.port.extensions import allowedaddresspairs as aap_obj
 from neutron_lib import context
@@ -46,46 +45,106 @@
     )
 
 
-# TODO(rlin): Only enable this after neutron bug/2069071 fixed.
-# @app.route("/address-pair", methods=["POST"])
+@app.route("/address-pair", methods=["POST"])
 def enforce_address_pair():
-    # TODO(mnaser): Validate this logic, ideally we should limit this policy
-    #               check only if its a provider network
-    with db_api.CONTEXT_READER.using(g.ctx):
-        network = network_obj.Network.get_object(g.ctx, id=g.target["network_id"])
-    if network["shared"] is False:
-        return Response("Not shared network", status=403, mimetype="text/plain")
+    if "attributes_to_update" not in g.target:
+        LOG.info("No attributes_to_update found, skip check.")
+        return Response("True", status=200, mimetype="text/plain")
+    elif "allowed_address_pairs" not in g.target["attributes_to_update"]:
+        LOG.info(
+            "No allowed_address_pairs in update targets "
+            f"for port {g.target['id']}, skip check."
+        )
+        return Response("True", status=200, mimetype="text/plain")
+    if g.target.get("allowed_address_pairs", []) == []:
+        LOG.info("Empty address pair to check on, skip check.")
+        return Response("True", status=200, mimetype="text/plain")
 
-    for allowed_address_pair in g.target.get("allowed_address_pairs", []):
-        with db_api.CONTEXT_READER.using(g.ctx):
-            ports = port_obj.Port.get_objects(
-                g.ctx,
-                network_id=g.target["network_id"],
-                project_id=g.target["project_id"],
-                mac_address=allowed_address_pair["mac_address"],
-            )
-        if len(ports) != 1:
+    # TODO(rlin): Ideally we should limit this policy check only if its a provider network
+
+    ports = port_obj.Port.get_objects(g.ctx, id=[g.target["id"]])
+    if len(ports) == 0:
+        # Note(ricolin): This happens with ports that are not well defined
+        # and missing context factors like project_id.
+        # Which port usually created by services and design for internal
+        # uses. We can skip this check and avoid blocking services.
+        msg = (
+            f"Can't fetch port {g.target['id']} with current "
+            "context, skip this check."
+        )
+        LOG.info(msg)
+        return Response(msg, status=403, mimetype="text/plain")
+
+    verify_address_pairs = []
+    target_port = ports[0]
+    db_pairs = (
+        target_port.allowed_address_pairs if target_port.allowed_address_pairs else []
+    )
+    target_pairs = g.target.get("allowed_address_pairs", [])
+    db_pairs_dict = {str(p.ip_address): str(p.mac_address) for p in db_pairs}
+    for pair in target_pairs:
+        if pair.get("ip_address") not in db_pairs_dict:
+            verify_address_pairs.append(pair)
+        elif pair.get("mac_address") and db_pairs_dict[
+            pair.get("ip_address")
+        ] != pair.get("mac_address"):
+            verify_address_pairs.append(pair)
+
+    for allowed_address_pair in verify_address_pairs:
+        if "mac_address" in allowed_address_pair:
+            with db_api.CONTEXT_READER.using(g.ctx):
+                ports = port_obj.Port.get_objects(
+                    g.ctx,
+                    network_id=g.target["network_id"],
+                    project_id=g.target["project_id"],
+                    mac_address=allowed_address_pair["mac_address"],
+                )
+            if len(ports) != 1:
+                msg = (
+                    "Zero or Multiple match port found with "
+                    f"MAC address {allowed_address_pair['mac_address']}."
+                )
+                LOG.info(f"{msg} Fail check.")
+                return Response(msg, status=403, mimetype="text/plain")
+        else:
+            with db_api.CONTEXT_READER.using(g.ctx):
+                ports = port_obj.Port.get_objects(
+                    g.ctx,
+                    network_id=g.target["network_id"],
+                    project_id=g.target["project_id"],
+                )
+        if "ip_address" in allowed_address_pair:
+            found_match = False
+            for port in ports:
+                fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in port.fixed_ips]
+                if allowed_address_pair["ip_address"] in fixed_ips:
+                    found_match = True
+                    break
+            if found_match:
+                LOG.debug("Valid address pair.")
+                continue
+            msg = f"IP address not exists in network from project {g.target['project_id']}."
+            LOG.info(f"{msg} Fail check.")
             return Response(
-                "Zero or Multiple match port found.", status=403, mimetype="text/plain"
+                msg,
+                status=403,
+                mimetype="text/plain",
             )
-        fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in ports[0].fixed_ips]
-        if allowed_address_pair["ip_address"] not in fixed_ips:
-            return Response(
-                "IP address not exists in ports.", status=403, mimetype="text/plain"
-            )
+    LOG.info("Valid port for address pairs, passed check.")
     return Response("True", status=200, mimetype="text/plain")
 
 
 @app.route("/port-update", methods=["POST"])
 def enforce_port_update():
-    if (
-        "attributes_to_update" in g.target
-        and ("mac_address" not in g.target["attributes_to_update"])
-        and ("fixed_ips" not in g.target["attributes_to_update"])
+    if "attributes_to_update" not in g.target:
+        LOG.info("No attributes_to_update found, skip check.")
+        return Response("True", status=200, mimetype="text/plain")
+    elif ("mac_address" not in g.target["attributes_to_update"]) and (
+        "fixed_ips" not in g.target["attributes_to_update"]
     ):
         LOG.info(
-            "No mac_address or fixed_ips in update targets for port "
-            f"{g.target['id']}, skip check."
+            "No mac_address or fixed_ips in update targets "
+            f"for port {g.target['id']}, skip check."
         )
         return Response("True", status=200, mimetype="text/plain")