# Copyright 2015 Mellanox Technologies, Ltd
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import functools


from neutron.objects.qos import policy as policy_object
from neutron_lib.api.definitions import extra_dhcp_opt as edo_ext
from neutron_lib.api.definitions import portbindings
from neutron_lib import constants as neutron_const
from neutron_lib.db import api as db_api
from neutron_lib.plugins.ml2 import api
from oslo_config import cfg
from oslo_log import log

from networking_mlnx.journal import cleanup
from networking_mlnx.journal import journal
from networking_mlnx.journal import maintenance
from networking_mlnx.plugins.ml2.drivers.sdn import config
from networking_mlnx.plugins.ml2.drivers.sdn import constants as sdn_const
from networking_mlnx.plugins.ml2.drivers.sdn import exceptions as sdn_excpt

# Key under which the network's QoS policy is added to network/port dicts
# before they are recorded in the journal.
NETWORK_QOS_POLICY = 'network_qos_policy'
# Numeric name of the DHCP client-identifier option (DHCP option 61); port
# payloads have it renamed to edo_ext.DHCP_OPT_CLIENT_ID before journaling.
DHCP_OPT_CLIENT_ID_NUM = '61'

LOG = log.getLogger(__name__)

# Register the SDN driver options under their config group at import time.
cfg.CONF.register_opts(config.sdn_opts, sdn_const.GROUP_OPT)


def context_validator(context_type=None):
    """Build a decorator that runs a driver hook only for supported segments.

    The decorated method is invoked only when the context carries segments
    and ``instance.check_segments(segments)`` accepts them; otherwise the
    wrapper returns None without calling it.

    :param context_type: ``sdn_const.PORT`` to read the segments of the
        port's network (``context.network.network_segments``),
        ``sdn_const.NETWORK`` to read ``context.network_segments``, or any
        other value (the default) to read ``context.segments_to_bind``.
    """
    def real_decorator(func):
        def _lookup_segments(ctx):
            # A port context wraps a network context, which holds the
            # segments.
            if context_type == sdn_const.PORT:
                return getattr(ctx.network, "network_segments", None)
            if context_type == sdn_const.NETWORK:
                return getattr(ctx, "network_segments", None)
            return getattr(ctx, "segments_to_bind", None)

        @functools.wraps(func)
        def wrapper(instance, context, *args, **kwargs):
            segments = _lookup_segments(context)
            if not segments:
                return None
            checker = getattr(instance, "check_segments", None)
            if not checker:
                return None
            if not checker(segments):
                return None
            return func(instance, context, *args, **kwargs)
        return wrapper
    return real_decorator


def error_handler(func):
    """Decorate a driver hook so that its exceptions are logged, not raised.

    On success the wrapped function's return value is passed through. Any
    exception raised by ``func`` is logged (with traceback) and swallowed,
    and the wrapper returns None.
    """
    @functools.wraps(func)
    def wrapper(instance, *args, **kwargs):
        try:
            return func(instance, *args, **kwargs)
        except Exception as e:
            # LOG.exception (instead of LOG.error) keeps the traceback, which
            # would otherwise be lost because the exception is swallowed.
            LOG.exception("%(function_name)s %(exception_desc)s",
                          {'function_name': func.__name__,
                           'exception_desc': str(e)})
            return None
    return wrapper


class SDNMechanismDriver(api.MechanismDriver):

    """Mechanism Driver for SDN.

    This driver sends notifications about port and network changes to an
    SDN provider. When ``[sdn] sync_enabled`` is set, the changes are
    recorded in a journal during the precommit phase and the journal thread
    is signalled during the postcommit phase.
    """

    # Non-compute device owners whose ports are reported on bind/unbind (see
    # _is_send_bind_port) and which may be bound as VNIC_NORMAL ports (see
    # _is_port_set_binding_supported).
    supported_device_owners = [neutron_const.DEVICE_OWNER_DHCP,
                               neutron_const.DEVICE_OWNER_ROUTER_INTF,
                               neutron_const.DEVICE_OWNER_ROUTER_GW,
                               neutron_const.DEVICE_OWNER_FLOATINGIP]

    def initialize(self):
        """Validate the configuration and initialise the driver state.

        :raises SDNDriverConfError: if the physnet related options of the
            ``[sdn]`` group are inconsistent.
        """
        # Validate first so that an invalid configuration fails before the
        # journal thread is created and the maintenance thread is started.
        SDNMechanismDriver._check_physnet_confs()
        if self._is_sdn_sync_enabled():
            self.journal = journal.SdnJournalThread()
            self._start_maintenance_thread()
        self.supported_vnic_types = [portbindings.VNIC_BAREMETAL]
        self.supported_network_types = (
            [neutron_const.TYPE_VLAN, neutron_const.TYPE_FLAT])
        self.vif_type = portbindings.VIF_TYPE_OTHER
        self.vif_details = {}
        self.allowed_physical_networks = cfg.CONF.sdn.physical_networks
        self.bind_normal_ports = cfg.CONF.sdn.bind_normal_ports
        self.bind_normal_ports_physnets = (
            cfg.CONF.sdn.bind_normal_ports_physnets)

    @staticmethod
    def _check_physnet_confs():
        """Check physical network related ML2 driver configuration options.

        When ``bind_normal_ports`` is enabled and ``physical_networks`` does
        not contain ``sdn_const.ANY``, every entry of
        ``bind_normal_ports_physnets`` must also be in ``physical_networks``.

        :raises SDNDriverConfError: if that constraint is violated.
        """
        LOG.debug("physnet Config opts: physical_networks=%s, "
                  "bind_normal_ports=%s, bind_normal_ports_physnets=%s",
                  cfg.CONF.sdn.physical_networks,
                  cfg.CONF.sdn.bind_normal_ports,
                  cfg.CONF.sdn.bind_normal_ports_physnets)

        allowed_physnets = cfg.CONF.sdn.physical_networks
        # Note(adrianc): if `bind_normal_ports` is set then
        # `bind_normal_ports_physnets` must be a subset of `physical_networks`
        # The error is raised when it is NOT a subset; all() is True for an
        # empty list, so an empty `bind_normal_ports_physnets` is accepted.
        if (cfg.CONF.sdn.bind_normal_ports and
                sdn_const.ANY not in allowed_physnets and
                not all(physnet in allowed_physnets
                        for physnet in
                        cfg.CONF.sdn.bind_normal_ports_physnets)):
            raise sdn_excpt.SDNDriverConfError(
                msg="'bind_normal_ports_physnets' configuration option is "
                    "expected to be a subset of 'physical_networks'.")

    @staticmethod
    def _is_sdn_sync_enabled():
        """Whether to synchronise events to an SDN controller."""
        return cfg.CONF.sdn.sync_enabled

    def _is_allowed_physical_network(self, physical_network):
        """Return True if ``physical_network`` is handled by this driver.

        Every physnet is allowed when ``sdn_const.ANY`` is configured.
        """
        return (sdn_const.ANY in self.allowed_physical_networks or
                physical_network in self.allowed_physical_networks)

    def _is_allowed_physical_networks(self, network_context):
        """Return True if all segments of the network are on allowed physnets.

        A network context without segments is considered allowed.
        """
        return all(
            self._is_allowed_physical_network(
                segment.get('physical_network'))
            for segment in network_context.network_segments)

    def _start_maintenance_thread(self):
        """Register the journal maintenance operations and start the thread.

        Registered operations:
          (1) JournalCleanup - delete completed rows from the journal.
          (2) CleanupProcessing - mark orphaned processing rows as pending.
        """
        cleanup_obj = cleanup.JournalCleanup()
        self._maintenance_thread = maintenance.MaintenanceThread()
        self._maintenance_thread.register_operation(
            cleanup_obj.delete_completed_rows)
        self._maintenance_thread.register_operation(
            cleanup_obj.cleanup_processing_rows)
        self._maintenance_thread.start()

    @staticmethod
    def _record_in_journal(context, object_type, operation, data=None):
        """Record an operation in the SDN journal when sync is enabled.

        :param context: ML2 network/port context. ``context.current['id']``
            is used as the journaled object id and the plugin context's
            session is used for the write.
        :param object_type: ``sdn_const.NETWORK`` or ``sdn_const.PORT`` (as
            used in this driver).
        :param operation: journal operation (``sdn_const.POST``/``PUT``/
            ``DELETE`` in this driver).
        :param data: payload to record; defaults to ``context.current``.
            For ports, the first DHCP option named ``'61'`` is renamed in
            place to ``edo_ext.DHCP_OPT_CLIENT_ID`` before recording.
        """
        if not SDNMechanismDriver._is_sdn_sync_enabled():
            return
        if data is None:
            data = context.current
        if object_type == sdn_const.PORT:
            SDNMechanismDriver._replace_port_dhcp_opt_name(
                data, DHCP_OPT_CLIENT_ID_NUM, edo_ext.DHCP_OPT_CLIENT_ID)
        journal.record(context._plugin_context.session, object_type,
                       context.current['id'], operation, data)

    @context_validator(sdn_const.NETWORK)
    @error_handler
    def create_network_precommit(self, context):
        """Journal a POST for a new network on allowed physnets.

        Only networks with a truthy ``provider:segmentation_id`` are
        recorded; the network's QoS policy is attached to the payload.
        """
        network_dic = context.current
        if (self._is_allowed_physical_networks(context) and
                network_dic.get('provider:segmentation_id')):
            network_dic[NETWORK_QOS_POLICY] = (
                self._get_network_qos_policy(context, network_dic['id']))
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.NETWORK, sdn_const.POST, network_dic)

    @context_validator()
    @error_handler
    def bind_port(self, context):
        """Report the port to the SDN journal and bind it when supported.

        Ports qualifying under ``_is_send_bind_port`` are recorded as POST
        (with the network QoS policy attached). ``set_binding`` (status
        ACTIVE) is then called for each network segment accepted by
        ``_is_port_set_binding_supported``; non-flat segments are skipped
        unless SDN sync is enabled.
        """
        if not self._is_allowed_physical_networks(context.network):
            return
        port_dic = context.current
        if self._is_send_bind_port(port_dic):
            port_dic[NETWORK_QOS_POLICY] = (
                self._get_network_qos_policy(context, port_dic['network_id']))
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.PORT, sdn_const.POST, port_dic)

        for segment in context.network.network_segments:
            if (segment[api.NETWORK_TYPE] != neutron_const.TYPE_FLAT and
                    not self._is_sdn_sync_enabled()):
                # Don't bind to non-flat networks if not syncing to an SDN
                # controller.
                continue

            # set port to active if supported
            if self._is_port_set_binding_supported(port_dic, segment):
                context.set_binding(segment[api.ID],
                                    self.vif_type,
                                    self.vif_details,
                                    neutron_const.PORT_STATUS_ACTIVE)

    @context_validator(sdn_const.NETWORK)
    @error_handler
    def update_network_precommit(self, context):
        """Journal a PUT for an updated network on allowed physnets."""
        network_dic = context.current
        if self._is_allowed_physical_networks(context):
            network_dic[NETWORK_QOS_POLICY] = (
                self._get_network_qos_policy(context, network_dic['id']))
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.NETWORK, sdn_const.PUT, network_dic)

    def _get_client_id_from_port(self, port):
        """Return the DHCP client-id option value of ``port``, or None.

        Both the named (``edo_ext.DHCP_OPT_CLIENT_ID``) and the numeric
        (``'61'``) option names are recognised.
        """
        # ``or []`` also covers a present-but-None 'extra_dhcp_opts'.
        for dhcp_opt in port.get('extra_dhcp_opts') or []:
            if (isinstance(dhcp_opt, dict) and
                    dhcp_opt.get('opt_name') in (edo_ext.DHCP_OPT_CLIENT_ID,
                                                 DHCP_OPT_CLIENT_ID_NUM)):
                return dhcp_opt.get('opt_value')
        return None

    @staticmethod
    def _replace_port_dhcp_opt_name(port, old_opt_name, new_opt_name):
        """Rename, in place, the first DHCP option named ``old_opt_name``."""
        # ``or []`` also covers a present-but-None 'extra_dhcp_opts'.
        for dhcp_opt in port.get('extra_dhcp_opts') or []:
            if (isinstance(dhcp_opt, dict) and
                    dhcp_opt.get('opt_name') == old_opt_name):
                dhcp_opt['opt_name'] = new_opt_name
                return

    def _get_local_link_information(self, port):
        """Return ``local_link_information`` from the binding profile.

        Returns None when the port has no (or an empty) binding profile.
        """
        binding_profile = port.get('binding:profile')
        if binding_profile:
            return binding_profile.get('local_link_information')
        return None

    # Decorated like delete_port_precommit so that ports on unsupported
    # segment types are not journaled on create/update either, and errors
    # are handled the same way as in the other hooks.
    @context_validator(sdn_const.PORT)
    @error_handler
    def create_port_precommit(self, context):
        """Journal a POST for a new baremetal port with client id/link info.

        The network QoS policy is attached to the port dict in all cases on
        allowed physnets; only baremetal ports that carry a DHCP client id
        or ``local_link_information`` are recorded.
        """
        if not self._is_allowed_physical_networks(context.network):
            return
        port_dic = context.current
        port_dic[NETWORK_QOS_POLICY] = (
            self._get_network_qos_policy(context, port_dic['network_id']))

        vnic_type = port_dic[portbindings.VNIC_TYPE]
        if (vnic_type == portbindings.VNIC_BAREMETAL and
                (self._get_client_id_from_port(port_dic) or
                 self._get_local_link_information(port_dic))):
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.PORT, sdn_const.POST, port_dic)

    @context_validator(sdn_const.PORT)
    @error_handler
    def update_port_precommit(self, context):
        """Journal the operation implied by a port update.

        Baremetal ports:
          * Ethernet: POST when ``local_link_information`` changed to a
            non-empty value; DELETE (of the original port) when it had link
            information and its host binding was cleared.
          * InfiniBand: otherwise, POST when the DHCP client id changed;
            DELETE when it had a client id and its host binding was cleared.
          * Nothing is recorded when none of the above applies.
        Other ports: DELETE (of the original port) when it was bound to a
        host, the host is now cleared or different, and the original port
        qualifies under ``_is_send_bind_port``; otherwise PUT.
        """
        if not self._is_allowed_physical_networks(context.network):
            return
        port_dic = context.current
        orig_port_dict = context.original
        port_dic[NETWORK_QOS_POLICY] = (
            self._get_network_qos_policy(context, port_dic['network_id']))

        vnic_type = port_dic[portbindings.VNIC_TYPE]
        # Check if we get a client id after binding the bare metal port,
        # and report the port to neo
        if vnic_type == portbindings.VNIC_BAREMETAL:
            # Ethernet Case
            link_info = self._get_local_link_information(port_dic)
            orig_link_info = self._get_local_link_information(orig_port_dict)
            if link_info != orig_link_info and link_info:
                SDNMechanismDriver._record_in_journal(
                    context, sdn_const.PORT, sdn_const.POST, port_dic)
                return
            elif (orig_link_info and orig_port_dict[portbindings.HOST_ID] and
                    not port_dic[portbindings.HOST_ID]):
                SDNMechanismDriver._record_in_journal(
                    context, sdn_const.PORT, sdn_const.DELETE, orig_port_dict)
                return
            # InfiniBand Case
            current_client_id = self._get_client_id_from_port(port_dic)
            orig_client_id = self._get_client_id_from_port(orig_port_dict)
            if current_client_id != orig_client_id:
                SDNMechanismDriver._record_in_journal(
                    context, sdn_const.PORT, sdn_const.POST, port_dic)
                return
            elif (orig_client_id and orig_port_dict[portbindings.HOST_ID] and
                    not port_dic[portbindings.HOST_ID]):
                SDNMechanismDriver._record_in_journal(
                    context, sdn_const.PORT, sdn_const.DELETE, orig_port_dict)
                return
        # delete the port in case instance is deleted
        # and port is created separately
        elif (orig_port_dict[portbindings.HOST_ID] and
              not port_dic[portbindings.HOST_ID] and
              self._is_send_bind_port(orig_port_dict)):
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.PORT, sdn_const.DELETE, orig_port_dict)
        # delete the port in case instance is migrated to another hypervisor
        elif (orig_port_dict[portbindings.HOST_ID] and
              port_dic[portbindings.HOST_ID] !=
              orig_port_dict[portbindings.HOST_ID] and
              self._is_send_bind_port(orig_port_dict)):
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.PORT, sdn_const.DELETE, orig_port_dict)
        else:
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.PORT, sdn_const.PUT, port_dic)

    @context_validator(sdn_const.NETWORK)
    @error_handler
    def delete_network_precommit(self, context):
        """Journal a DELETE for a network on allowed physnets."""
        if not self._is_allowed_physical_networks(context):
            return
        network_dic = context.current
        network_dic[NETWORK_QOS_POLICY] = (
            self._get_network_qos_policy(context, network_dic['id']))
        SDNMechanismDriver._record_in_journal(
            context, sdn_const.NETWORK, sdn_const.DELETE, data=network_dic)

    @context_validator(sdn_const.PORT)
    @error_handler
    def delete_port_precommit(self, context):
        """Journal a DELETE for a host-bound port.

        Only ports with a host binding that are baremetal or qualify under
        ``_is_send_bind_port`` are recorded.
        """
        if not self._is_allowed_physical_networks(context.network):
            return
        port_dic = context.current
        # delete the port only if attached to a host
        vnic_type = port_dic[portbindings.VNIC_TYPE]
        if (port_dic[portbindings.HOST_ID] and
                (vnic_type == portbindings.VNIC_BAREMETAL or
                 self._is_send_bind_port(port_dic))):
            port_dic[NETWORK_QOS_POLICY] = (
                self._get_network_qos_policy(context,
                                             port_dic['network_id']))
            SDNMechanismDriver._record_in_journal(
                context, sdn_const.PORT, sdn_const.DELETE, port_dic)

    @journal.call_thread_on_end
    def sync_from_callback(self, operation, res_type, res_id, resource_dict):
        """Journal a resource change received through a callback.

        :param operation: journal operation (e.g. ``sdn_const.POST``).
        :param res_type: resource type; its ``singular`` name is the journal
            object type and the key into ``resource_dict``.
        :param res_id: resource id, used unless a POST payload provides one.
        :param resource_dict: ``{<singular>: {...}}`` payload, or None.
        """
        object_type = res_type.singular
        if resource_dict is not None:
            resource_dict = resource_dict[object_type]
        # For POST the id comes from the payload; fall back to res_id when
        # no payload was given instead of failing on ``None[...]``.
        if operation == sdn_const.POST and resource_dict is not None:
            object_uuid = resource_dict['id']
        else:
            object_uuid = res_id
        journal.record(db_api.get_session(), object_type, object_uuid,
                       operation, resource_dict)

    def _postcommit(self, context):
        """Signal the journal thread after a commit when sync is enabled."""
        if not self._is_sdn_sync_enabled():
            return
        self.journal.set_sync_event()

    create_network_postcommit = _postcommit
    update_network_postcommit = _postcommit
    create_port_postcommit = _postcommit
    update_port_postcommit = _postcommit
    delete_network_postcommit = _postcommit
    delete_port_postcommit = _postcommit

    def _is_send_bind_port(self, port_context):
        """Return whether the port's bind/unbind should be reported.

        True-ish when the device owner starts with the compute prefix
        (case-insensitive) or is one of ``supported_device_owners``. An
        empty/None device owner is returned unchanged (falsy).
        """
        device_owner = port_context['device_owner']
        return (device_owner and
                (device_owner.lower().startswith(
                    neutron_const.DEVICE_OWNER_COMPUTE_PREFIX) or
                 device_owner in self.supported_device_owners))

    def _is_port_set_binding_supported(self, port, segment):
        """Check if driver is able to bind the port

        Port binding is supported if:
          a. Port VNIC type in supported_vnic_types (currently VNIC_BAREMETAL).
        Or
          b. Port is of VNIC type normal and:
            1. bind_normal_ports cfg opt is set.
            2. The segment's physnet is in bind_normal_ports_physnets cfg opt.
            3. The device owner is in supported_device_owners.

        :param port: port object
        :param segment: Segment dictionary representing the network segment
                        to bind on.
        :return: True if port binding is supported by the driver else False.
        """
        vnic_type = port[portbindings.VNIC_TYPE]

        if vnic_type in self.supported_vnic_types:
            return True

        return bool(vnic_type == portbindings.VNIC_NORMAL and
                    self.bind_normal_ports and
                    port['device_owner'] in self.supported_device_owners and
                    segment.get('physical_network') in
                    self.bind_normal_ports_physnets)

    def check_segment(self, segment):
        """Return True if the segment's network type is supported."""
        network_type = segment[api.NETWORK_TYPE]
        return network_type in self.supported_network_types

    def check_segments(self, segments):
        """Return True if at least one of ``segments`` is supported.

        None or an empty list yields False.
        """
        return any(self.check_segment(segment) for segment in segments or [])

    def _get_network_qos_policy(self, context, net_id):
        """Return the QoS policy attached to network ``net_id``."""
        return policy_object.QosPolicy.get_network_policy(
            context._plugin_context, net_id)
