feat: add port update and delete check (#1)



Reviewed-by: Mohammed Naser <mnaser@vexxhost.com>
Reviewed-by: Rico Lin <ricolin@ricolky.com>
diff --git a/.gitignore b/.gitignore
index e69de29..44121a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -0,0 +1,17 @@
+*.pyc
+cover/
+.coverage*
+!.coveragerc
+.tox
+nosetests.xml
+.testrepository
+.stestr
+.venv
+
+# Packages
+*.egg*
+*.egg-info
+
+# pbr generates these
+AUTHORS
+ChangeLog
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..8324d4a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,20 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.3.0
+    hooks:
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+
+  - repo: https://github.com/psf/black
+    rev: 24.4.0
+    hooks:
+      - id: black
+
+  - repo: https://github.com/pycqa/flake8
+    rev: 7.0.0
+    hooks:
+      - id: flake8
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
diff --git a/.zuul.yaml b/.zuul.yaml
new file mode 100644
index 0000000..66b259d
--- /dev/null
+++ b/.zuul.yaml
@@ -0,0 +1,48 @@
+- secret:
+    name: neutron-policy-server-pypi
+    data:
+      api_token: !encrypted/pkcs1-oaep
+        - eRyk66+lyVIomDFkugHPJSSlTF/WIH1fadNm+DHpIVpz4j50ow2sNJoOivBHRCE68Pc28
+          w+HFbLa+pYuFCX8ErZU6KnlenruA8om8yprMh+gNoe+mFs/QkZF4sYbSTox1QmP23DhXq
+          FhcNUk3rZdb3m0YIMU5Ti5UdmOG2MraNTO99QrZ9Qw8nuvbqcKfJgvEsK2IwB+0ZIBZpG
+          5+mOM7IzVdXyuBQ9BG1Q8ezTB2zGGi3RfD6ImjRzL2iHlJ/aIeh5R4kmzt6e7LEWPAxem
+          qxHfkSCc7nwnPPju9Uk9aL+1wXAaxqkYKCdwlVubgRzCCC301nnr12eBnksZ7/RT3du87
+          MAs+RSpRvXV4vRcxvwBfCN651i5dFCCUG4Gk0HgyuSh+Ud4wyWTpWTG/bXzeM7blt3Vrr
+          sF8hYxJSFXQGYHrpaZANzzlQMbxgbVtijLWwLxMoraR83jSPeNqR4kiR6DzuQQhbnkyz1
+          QqDMPtrIhPUne0J3poPAzGIQlIy3Wz5yElJXLlSUNPTY+YelA4X98l5g+arplok6Jkl3D
+          4F/r0d+xISfYf03+I8xbZPgd8Q43TJPqu0dZLFH5p8IctQvuJ5Os3CQ3ehy8M+VGwmoLC
+          D5bHF2XF/E0LKLehT3T2v1B6weoKJY3C5rGTzekRjJ5UksdJXH3a8l7RVI1uLY=
+
+- job:
+    name: neutron-policy-server-build
+    parent: build-python-release
+    vars:
+      release_python: python3
+
+- job:
+    name: neutron-policy-server-release
+    parent: python-upload-pypi
+    vars:
+      release_python: python3
+    secrets:
+      - secret: neutron-policy-server-pypi
+        name: pypi_info
+        pass-to-parent: true
+
+- project:
+    merge-mode: squash-merge
+    check:
+      jobs:
+        - neutron-policy-server-build
+        - tox-linters
+        - tox-py310
+        - tox-py311
+    gate:
+      jobs:
+        - neutron-policy-server-build
+        - tox-linters
+        - tox-py310
+        - tox-py311
+    release:
+      jobs:
+        - neutron-policy-server-release
diff --git a/neutron_policy_server/tests/__init__.py b/neutron_policy_server/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/neutron_policy_server/tests/__init__.py
diff --git a/neutron_policy_server/tests/test.py b/neutron_policy_server/tests/test.py
new file mode 100644
index 0000000..73de2d1
--- /dev/null
+++ b/neutron_policy_server/tests/test.py
@@ -0,0 +1,13 @@
+# SPDX-License-Identifier: Apache-2.0
+
+"""Base classes for our unit tests."""
+from unittest import mock
+
+import testtools
+
+
+@mock.patch("neutron.common.config")
+class TestCase(testtools.TestCase):
+    """Test case base class for all unit tests."""
+
+    pass
diff --git a/neutron_policy_server/tests/test_neutron_address_pair.py b/neutron_policy_server/tests/test_neutron_address_pair.py
new file mode 100644
index 0000000..80c734c
--- /dev/null
+++ b/neutron_policy_server/tests/test_neutron_address_pair.py
@@ -0,0 +1,176 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from neutron.tests.unit.db import test_allowedaddresspairs_db as base_test
+from neutron_lib.api.definitions import allowedaddresspairs as addr_apidef
+from oslo_config import cfg
+from oslo_db import options as db_options
+
+from neutron_policy_server import wsgi
+from neutron_policy_server.tests import test
+
+
+class TestAddressPairCasesFlaskBase(
+    test.TestCase, base_test.AllowedAddressPairDBTestCase
+):
+
+    def setUp(self, plugin=None, ext_mgr=None):
+        super(TestAddressPairCasesFlaskBase, self).setUp(plugin)
+        address_pairs = [{"mac_address": "00:00:00:00:00:01", "ip_address": "10.0.0.1"}]
+        db_options.set_defaults(cfg.CONF, connection="sqlite://")
+
+        with self.network() as net:
+            with self.subnet(network=net, cidr="10.0.0.0/24") as subnet:
+                fixed_ips = [
+                    {"subnet_id": subnet["subnet"]["id"], "ip_address": "10.0.0.1"}
+                ]
+
+            self.port_resp = self._create_port(
+                self.fmt,
+                net["network"]["id"],
+                mac_address="00:00:00:00:00:01",
+                fixed_ips=fixed_ips,
+            )
+            self.port = self.deserialize(self.fmt, self.port_resp)
+
+            self.port_resp_dep = self._create_port(
+                self.fmt,
+                net["network"]["id"],
+                arg_list=(addr_apidef.ADDRESS_PAIRS,),
+                allowed_address_pairs=address_pairs,
+            )
+            self.port_dep = self.deserialize(self.fmt, self.port_resp_dep)
+
+        # delete
+        self.delete_port_json = {
+            "rule": "delete_port",
+            "target": self.port["port"],
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port["port"]["project_id"],
+            },
+        }
+        self.delete_port_dep_json = {
+            "rule": "delete_port",
+            "target": self.port_dep["port"],
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port_dep["port"]["project_id"],
+            },
+        }
+
+        # update
+        self.update_port_no_address = {
+            "rule": "update_port",
+            "target": self.port["port"].copy(),
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port_dep["port"]["project_id"],
+            },
+        }
+        self.update_port_no_address["target"]["attributes_to_update"] = ["name"]
+        self.update_port_no_address["target"]["name"] = "new_name"
+
+        self.update_port_not_exist = {
+            "rule": "update_port",
+            "target": self.port_dep["port"].copy(),
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port_dep["port"]["project_id"],
+            },
+        }
+        self.update_port_not_exist["target"]["attributes_to_update"] = ["mac_address"]
+        self.update_port_not_exist["target"][
+            "id"
+        ] = "52c5a95c-9310-4993-a731-89cfd5a41fd9"
+
+        self.update_port_dep = {
+            "rule": "update_port",
+            "target": self.port_dep["port"].copy(),
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port_dep["port"]["project_id"],
+            },
+        }
+        self.update_port_dep["target"]["attributes_to_update"] = ["mac_address"]
+        self.update_port_dep["target"]["mac_address"] = "52:54:00:41:a4:97"
+
+        self.update_port = {
+            "rule": "update_port",
+            "target": self.port["port"].copy(),
+            "credentials": {
+                "user_id": "fake_user",
+                "project_id": self.port_dep["port"]["project_id"],
+            },
+        }
+        self.update_port["target"]["attributes_to_update"] = ["mac_address"]
+        self.update_port["target"]["mac_address"] = "52:54:00:41:a4:97"
+
+    @pytest.fixture()
+    def app(
+        self,
+    ):
+        app = wsgi.create_app()
+        yield app
+
+    @pytest.fixture()
+    def client(self, app):
+        return app.test_client()
+
+    @pytest.fixture()
+    def runner(self, app):
+        return app.test_cli_runner()
+
+
+@pytest.mark.usefixtures("client_class")
+class TestAddressPairCasesFlask(TestAddressPairCasesFlaskBase):
+
+    def test_port_delete_success(self):
+        response = self.client.post(
+            "/port-delete", json=self.delete_port_dep_json
+        )  # pylint: disable=E1101
+        self.assertEqual(b"True", response.data)
+        self.assertEqual(200, response.status_code)
+
+    def test_port_delete_fail_with_dep(self):
+        response = self.client.post(
+            "/port-delete", json=self.delete_port_json
+        )  # pylint: disable=E1101
+        self.assertEqual(
+            b"Address pairs dependency found for this port.", response.data
+        )
+        self.assertEqual(403, response.status_code)
+
+    def test_port_update_success(self):
+        response = self.client.post(
+            "/port-update", json=self.update_port_dep
+        )  # pylint: disable=E1101
+        self.assertEqual(b"True", response.data)
+        self.assertEqual(200, response.status_code)
+
+    def test_port_update_success_no_address_change(self):
+        response = self.client.post(
+            "/port-update", json=self.update_port_no_address
+        )  # pylint: disable=E1101
+        self.assertEqual(b"True", response.data)
+        self.assertEqual(200, response.status_code)
+
+    def test_port_update_fail_no_match_port(self):
+        response = self.client.post(
+            "/port-update", json=self.update_port_not_exist
+        )  # pylint: disable=E1101
+        self.assertEqual(b"No match port found.", response.data)
+        self.assertEqual(403, response.status_code)
+
+    def test_port_update_fail(self):
+        response = self.client.post(
+            "/port-update", json=self.update_port
+        )  # pylint: disable=E1101
+        self.assertEqual(
+            b"Address pairs dependency found for this port.", response.data
+        )
+        self.assertEqual(403, response.status_code)
+
+    def test_health_check_success(self):
+        response = self.client.get("/health")  # pylint: disable=E1101
+        self.assertEqual(200, response.status_code)
diff --git a/neutron_policy_server/wsgi.py b/neutron_policy_server/wsgi.py
index 1b33bbb..4ac5256 100644
--- a/neutron_policy_server/wsgi.py
+++ b/neutron_policy_server/wsgi.py
@@ -1,11 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 import sys
 
-from flask import Flask, Response, request
+from flask import Flask, Response, g, request
 from neutron.common import config
+from neutron.db.models import allowed_address_pair as models
 from neutron.objects import network as network_obj
 from neutron.objects import ports as port_obj
+from neutron.objects.port.extensions import allowedaddresspairs as aap_obj
 from neutron_lib import context
 from neutron_lib.db import api as db_api
 
@@ -16,39 +19,129 @@
 app = Flask(__name__)
 
 
-@app.route("/enforce", methods=["POST"])
-def enforce():
-    data = request.json
-    rule = data.get("rule")
-    target = data.get("target")
-    creds = data.get("creds")
+@app.before_request
+def fetch_context():
+    # Skip detail data fetch if we're running health check
+    if request.path == "/health":
+        g.ctx = context.Context()
+        return
+    content_type = request.headers.get(
+        "Content-Type", "application/x-www-form-urlencoded"
+    )
+    if content_type == "application/x-www-form-urlencoded":
+        data = request.form.to_dict()
+        g.target = json.loads(data.get("target"))
+        g.creds = json.loads(data.get("credentials"))
+        g.rule = json.loads(data.get("rule"))
+    elif content_type == "application/json":
+        data = request.json
+        g.target = data.get("target")
+        g.creds = data.get("credentials")
+        g.rule = data.get("rule")
+    g.ctx = context.Context(
+        user_id=g.creds["user_id"], project_id=g.creds["project_id"]
+    )
 
-    ctx = context.Context(user_id=creds["user_id"], project_id=creds["project_id"])
 
-    if rule == "create_port:allowed_address_pairs":
-        # TODO(mnaser): Validate this logic, ideally we should limit this policy
-        #               check only if its a provider network
-        with db_api.CONTEXT_READER.using(ctx):
-            network = network_obj.Network.get_object(ctx, id=target["network_id"])
-        if network["shared"] is False:
-            return Response(status=403)
+# TODO(rlin): Only enable this after neutron bug/2069071 fixed.
+# @app.route("/address-pair", methods=["POST"])
+def enforce_address_pair():
+    # TODO(mnaser): Validate this logic, ideally we should limit this policy
+    #               check only if its a provider network
+    with db_api.CONTEXT_READER.using(g.ctx):
+        network = network_obj.Network.get_object(g.ctx, id=g.target["network_id"])
+    if network["shared"] is False:
+        return Response("Not shared network", status=403, mimetype="text/plain")
 
-        for allowed_address_pair in target.get("allowed_address_pairs", []):
-            with db_api.CONTEXT_READER.using(ctx):
-                ports = port_obj.Port.get_objects(
-                    ctx,
-                    network_id=target["network_id"],
-                    project_id=target["project_id"],
-                    mac_address=allowed_address_pair["mac_address"],
+    for allowed_address_pair in g.target.get("allowed_address_pairs", []):
+        with db_api.CONTEXT_READER.using(g.ctx):
+            ports = port_obj.Port.get_objects(
+                g.ctx,
+                network_id=g.target["network_id"],
+                project_id=g.target["project_id"],
+                mac_address=allowed_address_pair["mac_address"],
+            )
+        if len(ports) != 1:
+            return Response(
+                "Zero or Multiple match port found.", status=403, mimetype="text/plain"
+            )
+        fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in ports[0].fixed_ips]
+        if allowed_address_pair["ip_address"] not in fixed_ips:
+            return Response(
+                "IP address not exists in ports.", status=403, mimetype="text/plain"
+            )
+    return Response("True", status=200, mimetype="text/plain")
+
+
+@app.route("/port-update", methods=["POST"])
+def enforce_port_update():
+    if (
+        "attributes_to_update" in g.target
+        and ("mac_address" not in g.target["attributes_to_update"])
+        and ("fixed_ips" not in g.target["attributes_to_update"])
+    ):
+        return Response("True", status=200, mimetype="text/plain")
+
+    with db_api.CONTEXT_READER.using(g.ctx):
+        ports = port_obj.Port.get_objects(g.ctx, id=g.target["id"])
+        if len(ports) == 0:
+            return Response("No match port found.", status=403, mimetype="text/plain")
+
+        fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in ports[0].fixed_ips]
+
+        query = (
+            g.ctx.session.query(models.AllowedAddressPair)
+            .filter(
+                models.AllowedAddressPair.mac_address.in_([str(ports[0].mac_address)])
+            )
+            .filter(models.AllowedAddressPair.ip_address.in_(fixed_ips))
+        )
+        pairs = query.all()
+    pairs = [
+        aap_obj.AllowedAddressPair._load_object(context, db_obj)
+        for db_obj in query.all()
+    ]
+    if len(pairs) > 0:
+        return Response(
+            "Address pairs dependency found for this port.",
+            status=403,
+            mimetype="text/plain",
+        )
+    return Response("True", status=200, mimetype="text/plain")
+
+
+@app.route("/port-delete", methods=["POST"])
+def enforce_port_delete():
+    fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in g.target["fixed_ips"]]
+    with db_api.CONTEXT_READER.using(g.ctx):
+        query = (
+            g.ctx.session.query(models.AllowedAddressPair)
+            .filter(
+                models.AllowedAddressPair.mac_address.in_(
+                    [str(g.target["mac_address"])]
                 )
+            )
+            .filter(models.AllowedAddressPair.ip_address.in_(fixed_ips))
+        )
 
-            if len(ports) != 1:
-                return Response(status=403)
+    pairs = query.all()
+    pairs = [
+        aap_obj.AllowedAddressPair._load_object(context, db_obj)
+        for db_obj in query.all()
+    ]
+    if len(pairs) > 0:
+        return Response(
+            "Address pairs dependency found for this port.",
+            status=403,
+            mimetype="text/plain",
+        )
+    return Response("True", status=200, mimetype="text/plain")
 
-            fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in ports[0].fixed_ips]
-            if allowed_address_pair["ip_address"] not in fixed_ips:
-                return Response(status=403)
 
+@app.route("/health", methods=["GET"])
+def health_check():
+    with db_api.CONTEXT_READER.using(g.ctx):
+        port_obj.Port.get_objects(g.ctx, id="neutron_policy_server_health_check")
         return Response(status=200)
 
 
@@ -57,4 +150,4 @@
 
 
 if __name__ == "__main__":
-    create_app().run(host="0.0.0.0", port=8080)
+    create_app().run(host="0.0.0.0", port=9697)
diff --git a/setup.cfg b/setup.cfg
index a44cc3b..d912812 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,7 +10,7 @@
     Bug Tracker = https://github.com/vexxhost/neutron-policy-server/issues
     Documentation = https://vexxhost.github.io/neutron-policy-server/
     Source Code = https://github.com/vexxhost/neutron-policy-server
-python_requires = >=3.10
+python_requires = >=3.8
 classifiers =
     Development Status :: 5 - Production/Stable
     Environment :: OpenStack
diff --git a/test-requirements.txt b/test-requirements.txt
index 1c83fde..a5c7288 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1 +1,6 @@
-PyMySQL
+fixtures>=3.0.0 # Apache-2.0/BSD
+WebTest>=2.0.27 # MIT
+oslotest>=3.2.0 # Apache-2.0
+pytest
+pytest-mock
+pytest-flask
diff --git a/tox.ini b/tox.ini
index b67293f..d6c6e57 100644
--- a/tox.ini
+++ b/tox.ini
@@ -3,6 +3,9 @@
 
 [testenv]
 usedevelop = True
+setenv =
+  VIRTUAL_ENV={envdir}
+
 deps =
   -r{toxinidir}/test-requirements.txt
 
@@ -11,3 +14,17 @@
   {[testenv]deps}
 commands =
   {posargs}
+
+[testenv:linters]
+skipsdist = True
+deps = pre-commit
+commands =
+  pre-commit run --all-files --show-diff-on-failure
+
+[testenv:py{3,310,311,312}]
+changedir = neutron_policy_server/tests
+# change pytest tempdir and add posargs from command line
+commands = pytest {posargs}
+
+[pytest]
+pythonpath = neutron_policy_server