Add /address-pair for address pair outside of owner network. (#3)
Add /address-pair method to allow cross-owner network address pair binding.
Reviewed-by: Rico Lin <ricolin@ricolky.com>
diff --git a/neutron_policy_server/tests/test_neutron_address_pair.py b/neutron_policy_server/tests/test_neutron_address_pair.py
index 99f49c0..f9d2e08 100644
--- a/neutron_policy_server/tests/test_neutron_address_pair.py
+++ b/neutron_policy_server/tests/test_neutron_address_pair.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
+from copy import deepcopy
+
import pytest
from neutron.tests.unit.db import test_allowedaddresspairs_db as base_test
from neutron_lib.api.definitions import allowedaddresspairs as addr_apidef
@@ -106,6 +108,47 @@
self.update_port["target"]["attributes_to_update"] = ["mac_address"]
self.update_port["target"]["mac_address"] = "52:54:00:41:a4:97"
+ self.allowed_address_pairs = {
+ "rule": "allowed_address_pairs",
+ "target": self.port_dep["port"].copy(),
+ "credentials": {
+ "user_id": "fake_user",
+ "project_id": self.port_dep["port"]["project_id"],
+ },
+ }
+ self.allowed_address_pairs["target"]["attributes_to_update"] = [
+ "allowed_address_pairs"
+ ]
+ self.allowed_address_pairs["target"]["allowed_address_pairs"] = [
+ {"mac_address": "00:00:00:00:00:01", "ip_address": "10.0.0.1"}
+ ]
+ self.allowed_address_pairs_not_found = deepcopy(self.allowed_address_pairs)
+ self.allowed_address_pairs_not_found["target"]["allowed_address_pairs"] = [
+ {"ip_address": "10.96.250.203", "mac_address": "fa:16:3e:da:ed:0b"}
+ ]
+ self.allowed_address_pairs_address_not_found = deepcopy(
+ self.allowed_address_pairs
+ )
+ self.allowed_address_pairs_address_not_found["target"]["allowed_address_pairs"][
+ 0
+ ]["ip_address"] = "10.96.250.203"
+ self.allowed_address_pairs_no_attribute = deepcopy(self.allowed_address_pairs)
+ self.allowed_address_pairs_no_attribute["target"]["attributes_to_update"] = []
+ self.allowed_address_pairs_not_in_attribute = deepcopy(
+ self.allowed_address_pairs
+ )
+ self.allowed_address_pairs_not_in_attribute["target"][
+ "attributes_to_update"
+ ] = ["mac_address"]
+
+ self.allowed_address_pairs_empty = deepcopy(self.allowed_address_pairs)
+ self.allowed_address_pairs_empty["target"]["allowed_address_pairs"] = []
+
+ self.allowed_address_pairs_target_not_found = deepcopy(
+ self.allowed_address_pairs
+ )
+ self.allowed_address_pairs_target_not_found["target"]["id"] = "foo"
+
@pytest.fixture()
def app(
self,
@@ -126,14 +169,14 @@
class TestAddressPairCasesFlask(TestAddressPairCasesFlaskBase):
def test_port_delete_success(self):
- response = self.client.post(
+ response = self.client.post( # pylint: disable=E1101
"/port-delete", json=self.delete_port_dep_json
) # pylint: disable=E1101
self.assertEqual(b"True", response.data)
self.assertEqual(200, response.status_code)
def test_port_delete_fail_with_dep(self):
- response = self.client.post(
+ response = self.client.post( # pylint: disable=E1101
"/port-delete", json=self.delete_port_json
) # pylint: disable=E1101
self.assertEqual(
@@ -149,30 +192,30 @@
self.assertEqual(403, response.status_code)
def test_port_update_success(self):
- response = self.client.post(
+ response = self.client.post( # pylint: disable=E1101
"/port-update", json=self.update_port_dep
) # pylint: disable=E1101
self.assertEqual(b"True", response.data)
self.assertEqual(200, response.status_code)
def test_port_update_success_no_address_change(self):
- response = self.client.post(
+ response = self.client.post( # pylint: disable=E1101
"/port-update", json=self.update_port_no_address
) # pylint: disable=E1101
self.assertEqual(b"True", response.data)
self.assertEqual(200, response.status_code)
def test_port_update_fail_no_match_port(self):
- response = self.client.post(
+ response = self.client.post( # pylint: disable=E1101
"/port-update", json=self.update_port_not_exist
) # pylint: disable=E1101
self.assertEqual(b"True", response.data)
self.assertEqual(200, response.status_code)
def test_port_update_fail(self):
- response = self.client.post(
+ response = self.client.post( # pylint: disable=E1101
"/port-update", json=self.update_port
- ) # pylint: disable=E1101
+ )
self.assertEqual(
bytes(
(
@@ -188,3 +231,78 @@
def test_health_check_success(self):
response = self.client.get("/health") # pylint: disable=E1101
self.assertEqual(200, response.status_code)
+
+ def test_address_pair_success(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs
+ )
+ self.assertEqual(
+ b"True",
+ response.data,
+ )
+ self.assertEqual(200, response.status_code)
+
+ def test_address_pair_success_no_attributes_to_update(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs_no_attribute
+ )
+ self.assertEqual(
+ b"True",
+ response.data,
+ )
+ self.assertEqual(200, response.status_code)
+
+ def test_address_pair_success_empty(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs_empty
+ )
+ self.assertEqual(
+ b"True",
+ response.data,
+ )
+ self.assertEqual(200, response.status_code)
+
+ def test_address_pair_success_not_in_attributes(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs_not_in_attribute
+ )
+ self.assertEqual(
+ b"True",
+ response.data,
+ )
+ self.assertEqual(200, response.status_code)
+
+ def test_address_pair_fail_target_not_found(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs_target_not_found
+ )
+ portname = self.allowed_address_pairs_target_not_found["target"]["id"]
+ self.assertEqual(
+ f"Can't fetch port {portname} with current context, skip this check.".encode(
+ "utf-8"
+ ),
+ response.data,
+ )
+ self.assertEqual(403, response.status_code)
+
+ def test_address_pair_fail_mac_not_found(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs_not_found
+ )
+ self.assertEqual(
+ b"Zero or Multiple match port found with MAC address fa:16:3e:da:ed:0b.",
+ response.data,
+ )
+ self.assertEqual(403, response.status_code)
+
+ def test_address_pair_fail_address_not_found(self):
+ response = self.client.post( # pylint: disable=E1101
+ "/address-pair", json=self.allowed_address_pairs_address_not_found
+ )
+ self.assertEqual(
+ f"IP address not exists in network from project {self.port['port']['project_id']}.".encode(
+ "utf-8"
+ ),
+ response.data,
+ )
+ self.assertEqual(403, response.status_code)
diff --git a/neutron_policy_server/wsgi.py b/neutron_policy_server/wsgi.py
index 0e2561c..dc27baa 100644
--- a/neutron_policy_server/wsgi.py
+++ b/neutron_policy_server/wsgi.py
@@ -6,7 +6,6 @@
from flask import Flask, Response, g, request
from neutron.common import config
from neutron.db.models import allowed_address_pair as models
-from neutron.objects import network as network_obj
from neutron.objects import ports as port_obj
from neutron.objects.port.extensions import allowedaddresspairs as aap_obj
from neutron_lib import context
@@ -46,46 +45,106 @@
)
-# TODO(rlin): Only enable this after neutron bug/2069071 fixed.
-# @app.route("/address-pair", methods=["POST"])
+@app.route("/address-pair", methods=["POST"])
def enforce_address_pair():
- # TODO(mnaser): Validate this logic, ideally we should limit this policy
- # check only if its a provider network
- with db_api.CONTEXT_READER.using(g.ctx):
- network = network_obj.Network.get_object(g.ctx, id=g.target["network_id"])
- if network["shared"] is False:
- return Response("Not shared network", status=403, mimetype="text/plain")
+ if "attributes_to_update" not in g.target:
+ LOG.info("No attributes_to_update found, skip check.")
+ return Response("True", status=200, mimetype="text/plain")
+ elif "allowed_address_pairs" not in g.target["attributes_to_update"]:
+ LOG.info(
+ "No allowed_address_pairs in update targets "
+ f"for port {g.target['id']}, skip check."
+ )
+ return Response("True", status=200, mimetype="text/plain")
+ if g.target.get("allowed_address_pairs", []) == []:
+ LOG.info("Empty address pair to check on, skip check.")
+ return Response("True", status=200, mimetype="text/plain")
- for allowed_address_pair in g.target.get("allowed_address_pairs", []):
- with db_api.CONTEXT_READER.using(g.ctx):
- ports = port_obj.Port.get_objects(
- g.ctx,
- network_id=g.target["network_id"],
- project_id=g.target["project_id"],
- mac_address=allowed_address_pair["mac_address"],
- )
- if len(ports) != 1:
+ # TODO(rlin): Ideally we should limit this policy check only if its a provider network
+
+ ports = port_obj.Port.get_objects(g.ctx, id=[g.target["id"]])
+ if len(ports) == 0:
+ # Note(ricolin): This happens with ports that are not well defined
+ # and missing context factors like project_id.
+ # Which port usually created by services and design for internal
+ # uses. We can skip this check and avoid blocking services.
+ msg = (
+ f"Can't fetch port {g.target['id']} with current "
+ "context, skip this check."
+ )
+ LOG.info(msg)
+ return Response(msg, status=403, mimetype="text/plain")
+
+ verify_address_pairs = []
+ target_port = ports[0]
+ db_pairs = (
+ target_port.allowed_address_pairs if target_port.allowed_address_pairs else []
+ )
+ target_pairs = g.target.get("allowed_address_pairs", [])
+ db_pairs_dict = {str(p.ip_address): str(p.mac_address) for p in db_pairs}
+ for pair in target_pairs:
+ if pair.get("ip_address") not in db_pairs_dict:
+ verify_address_pairs.append(pair)
+ elif pair.get("mac_address") and db_pairs_dict[
+ pair.get("ip_address")
+ ] != pair.get("mac_address"):
+ verify_address_pairs.append(pair)
+
+ for allowed_address_pair in verify_address_pairs:
+ if "mac_address" in allowed_address_pair:
+ with db_api.CONTEXT_READER.using(g.ctx):
+ ports = port_obj.Port.get_objects(
+ g.ctx,
+ network_id=g.target["network_id"],
+ project_id=g.target["project_id"],
+ mac_address=allowed_address_pair["mac_address"],
+ )
+ if len(ports) != 1:
+ msg = (
+ "Zero or Multiple match port found with "
+ f"MAC address {allowed_address_pair['mac_address']}."
+ )
+ LOG.info(f"{msg} Fail check.")
+ return Response(msg, status=403, mimetype="text/plain")
+ else:
+ with db_api.CONTEXT_READER.using(g.ctx):
+ ports = port_obj.Port.get_objects(
+ g.ctx,
+ network_id=g.target["network_id"],
+ project_id=g.target["project_id"],
+ )
+ if "ip_address" in allowed_address_pair:
+ found_match = False
+ for port in ports:
+ fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in port.fixed_ips]
+ if allowed_address_pair["ip_address"] in fixed_ips:
+ found_match = True
+ break
+ if found_match:
+ LOG.debug("Valid address pair.")
+ continue
+ msg = f"IP address not exists in network from project {g.target['project_id']}."
+ LOG.info(f"{msg} Fail check.")
return Response(
- "Zero or Multiple match port found.", status=403, mimetype="text/plain"
+ msg,
+ status=403,
+ mimetype="text/plain",
)
- fixed_ips = [str(fixed_ip["ip_address"]) for fixed_ip in ports[0].fixed_ips]
- if allowed_address_pair["ip_address"] not in fixed_ips:
- return Response(
- "IP address not exists in ports.", status=403, mimetype="text/plain"
- )
+ LOG.info("Valid port for address pairs, passed check.")
return Response("True", status=200, mimetype="text/plain")
@app.route("/port-update", methods=["POST"])
def enforce_port_update():
- if (
- "attributes_to_update" in g.target
- and ("mac_address" not in g.target["attributes_to_update"])
- and ("fixed_ips" not in g.target["attributes_to_update"])
+ if "attributes_to_update" not in g.target:
+ LOG.info("No attributes_to_update found, skip check.")
+ return Response("True", status=200, mimetype="text/plain")
+ elif ("mac_address" not in g.target["attributes_to_update"]) and (
+ "fixed_ips" not in g.target["attributes_to_update"]
):
LOG.info(
- "No mac_address or fixed_ips in update targets for port "
- f"{g.target['id']}, skip check."
+ "No mac_address or fixed_ips in update targets "
+ f"for port {g.target['id']}, skip check."
)
return Response("True", status=200, mimetype="text/plain")