#-*- coding: utf-8 -*-

# Copyright 2012 Calculate Ltd. http://www.calculate-linux.org
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from soaplib.wsgi import Application
import re
import logging, os, OpenSSL
logger = logging.getLogger(__name__)
from lxml import etree
import soaplib
from soaplib.serializers.exception import Fault
from soaplib.serializers.primitive import string_encoding
from soaplib.soap import apply_mtom
import datetime,pickle
from decorators import Dec

# for OpenSSLAdapter
from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter

# HTTP status lines handed to WSGI start_response().
HTTP_500 = '500 Internal server error'
HTTP_200 = '200 OK'
HTTP_405 = '405 Method Not Allowed'
HTTP_403 = '403 Forbidden'

# Method names whose allowed calls are not written to the access log
# (methods ending in "_view" are skipped separately by the request handler).
not_log_list = [
    'post_server_request',
    'post_client_request',
    'del_sid',
    'get_server_cert',
    'get_client_cert',
    'get_entire_frame',
    'get_crl',
    'get_server_host_name',
    'get_ca',
    'get_table',
    'post_cert',
    'post_sid',
    'active_client',
    'list_pid',
    'get_methods',
    'get_frame',
    'get_progress',
    'pid_info',
]

class ClApplication(Application):
    """soaplib WSGI application with client-certificate access control.

    On top of soaplib's request handling it creates the server's data
    directories on demand, checks the rights of the client certificate for
    every called method (answering 403 on denial) and writes allowed and
    forbidden calls to an access log.
    """

    def __init__(self, services, tns, name=None, _with_partnerlink=False,
                 log=None):
        '''
        @param services: ServiceBase subclass(es) defining the exposed
                         services; passed to Application.
        @param tns: target namespace; passed to Application.
        @param name: accepted for signature compatibility, not forwarded.
        @param _with_partnerlink: accepted for signature compatibility, not
                                  forwarded.
        @param log: logger for the access log (its ``debug`` method is used);
                    None disables access logging.
        '''
        Application.__init__(self, services, tns)
        # access-log object, see _log_access()
        self.log = log

    # ------------------------------------------------------------------
    # small filesystem helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _ensure_dir(path):
        """Create directory *path* (including parents) if it is missing.

        Tolerates the directory being created concurrently by another
        request thread between the existence check and makedirs().
        """
        import errno
        if os.path.exists(path):
            return
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def _touch(path):
        """Create *path* as an empty file if it does not exist yet."""
        if not os.path.exists(path):
            # append mode never truncates a file another thread just created
            with open(path, 'a'):
                pass

    def _read_conf_lines(self, path):
        """Return the lines of config file *path*, creating it empty first
        if it is missing."""
        self._touch(path)
        with open(path) as fd:
            return fd.read().splitlines()

    # ------------------------------------------------------------------
    # access control
    # ------------------------------------------------------------------
    # verification of compliance certificate and session (sid)
    def check_cert_sid(self, sid, server):
        """Check that session *sid* was opened with the current client
        certificate.

        The certificate is read from the current thread (``client_cert``)
        and looked up in ``server.certbase``.  ``server.sids_file`` holds a
        sequence of pickled records whose element 0 is a session id and
        element 1 a certificate id.

        @return: 1 if a record matches both the certificate and *sid*,
                 0 otherwise (also when the certificate is unknown).
        """
        import threading
        cert = threading.currentThread().client_cert
        from cert_cmd import find_cert_id
        cert_id = int(find_cert_id(cert, server.data_path, server.certbase))
        if cert_id == 0:
            return 0

        # make sure the session directory and file exist
        self._ensure_dir(server.sids)
        self._touch(server.sids_file)

        # NOTE: pickle is only safe because this file is written by the
        # server itself; never point sids_file at externally supplied data.
        with open(server.sids_file, 'rb') as fd:
            while True:
                try:
                    list_sid = pickle.load(fd)
                except Exception:
                    # end of file (or an unreadable record) ends the scan
                    break
                if cert_id == int(list_sid[1]) and \
                        int(sid) == int(list_sid[0]):
                    return 1
        return 0

    def _cert_right(self, rights, right_param, cert_id):
        """Look up *right_param* for certificate *cert_id* in file *rights*.

        Each line reads "<right> <id> <id> ...": a positive id grants the
        right, a negative id revokes it.  The first matching id in file
        order decides.  Blank lines are ignored.

        @return: 1 if granted, -1 if revoked, 0 if not mentioned.
        """
        for line in self._read_conf_lines(rights):
            words = line.split()
            if not words or words[0] != right_param:
                continue
            for word in words:
                try:
                    number = int(word)
                except ValueError:
                    continue
                if cert_id == number:
                    return 1
                if cert_id == -number:
                    return -1
        return 0

    def _group_has_right(self, group_rights, groups_list, right_param):
        """Check whether one of *groups_list* grants *right_param*.

        Each line of *group_rights* reads "<group> <right>,<right>,...".
        Lines without a right list grant nothing.
        """
        for line in self._read_conf_lines(group_rights):
            if not line:
                continue
            words = line.split(' ', 1)
            if words[0] not in groups_list or len(words) < 2:
                continue
            if right_param in [word.strip() for word in words[1].split(',')]:
                return True
        return False

    def _group_defined(self, group_rights, group):
        """Return True if file *group_rights* has a line for *group*."""
        for line in self._read_conf_lines(group_rights):
            words = line.split()
            if words and words[0] == group:
                return True
        return False

    # input parameters - certificate and name method
    def check_rights(self, method_name, req_env, params):
        """Check whether the current client may call *method_name*.

        The decision uses the client certificate stored on the current
        thread, the certificate database and the rights / group rights
        files configured on the thread's server object.

        @param method_name: SOAP method name; the namespace part matched by
                            the regex ``[{\\w]+[}]`` is removed first.
        @param req_env: WSGI environment of the request (not used here).
        @param params: decoded request parameters; when they carry ``sid``
                       and the method is rights-controlled, the session must
                       belong to the client certificate.
        @return: 1 if the call is allowed, 0 otherwise.
        """
        import threading

        # drop the namespace part, e.g. "{tns}name" -> "name"
        method_rep = re.findall(r'[{\w]+[}]', method_name)
        if method_rep:
            method_name = method_name.replace(method_rep[0], '')

        curThread = threading.currentThread()
        server = curThread.server
        cert = curThread.client_cert
        server_cert = server.ssl_certificate
        server_key = server.ssl_private_key
        certbase = server.certbase
        rights = server.rights
        group_rights = server.group_rights
        data_path = server.data_path

        # methods anybody may call, with or without a certificate
        permitted_methods = ['post_server_request', 'post_client_request',
                             'get_server_cert', 'get_client_cert',
                             'get_crl', 'get_server_host_name', 'get_ca']
        if method_name in permitted_methods:
            return 1
        # every other method needs a client certificate
        if cert is None:
            return 0

        if params and hasattr(params, 'sid') and (
                method_name in Dec.rightsMethods or
                (method_name.endswith('_view') and
                 method_name[:-5] in Dec.rightsMethods)):
            # the session must have been opened with this very certificate
            if not self.check_cert_sid(params.sid, server):
                return 0

        with open(server_cert, 'r') as fd:
            certobj = OpenSSL.crypto.load_certificate(
                OpenSSL.SSL.FILETYPE_PEM, fd.read())
        with open(server_key, 'r') as fd:
            # NOTE(review): hard-coded private key passphrase
            Pkey = OpenSSL.crypto.load_privatekey(
                OpenSSL.SSL.FILETYPE_PEM, fd.read(), 'qqqq')
        # NOTE(review): the client certificate is signed with the server key
        # and verified with the server certificate.  That only proves the
        # server key and certificate belong together; it does not show that
        # the client certificate was issued by this server.  Kept as is.
        signature = OpenSSL.crypto.sign(Pkey, cert, 'SHA1')
        try:
            OpenSSL.crypto.verify(certobj, signature, cert, 'SHA1')
        except Exception as e:
            print(e)
            return 0
        if method_name == 'cert_add':
            return 0

        # the client's groups are the second ':'-separated field of the
        # certificate's last extension, comma separated
        certobj_cl = OpenSSL.crypto.load_certificate(
            OpenSSL.SSL.FILETYPE_PEM, cert)
        try:
            com = certobj_cl.get_extension(
                certobj_cl.get_extension_count() - 1)
            groups = com.get_data().split(':')[1]
        except IndexError:
            groups = ""
        except Exception:
            return 0
        groups_list = groups.split(',')

        # make sure the certificate database exists before looking it up
        self._touch(certbase)
        from cert_cmd import find_cert_id
        cert_id = int(find_cert_id(cert, data_path, certbase))

        # True when the group rights file redefines the "all" group
        find_flag = False
        if cert_id > 0:
            # registered certificates may call methods that need no rights
            if method_name not in Dec.rightsMethods:
                return 1

            if 'all' in groups_list:
                find_flag = self._group_defined(group_rights, 'all')

            required = Dec.rightsMethods[method_name]
            count = 0
            for right_param in required:
                # any error while reading the rights files denies access
                try:
                    verdict = self._cert_right(rights, right_param, cert_id)
                    if verdict < 0:
                        # right explicitly revoked for this certificate
                        return 0
                    if verdict > 0 or self._group_has_right(
                            group_rights, groups_list, right_param):
                        count += 1
                except Exception:
                    return 0
            # every required right must be granted
            if count == len(required):
                return 1

        # NOTE(review): for certificates missing from certbase (cert_id <= 0)
        # find_flag stays False, so an "all" group in the certificate grants
        # access here - confirm this is intended.
        if not find_flag and 'all' in groups_list:
            return 1
        if method_name in ['post_cert', 'init_session']:
            return 1
        return 0

    def create_path(self):
        """Create the server's data directories and empty rights files if
        they are missing (safe to call from concurrent request threads)."""
        import threading
        server = threading.currentThread().server
        data_path = server.data_path
        conf_path = os.path.join(data_path, 'conf')

        for path in (data_path, server.sids, server.pids, conf_path,
                     os.path.join(data_path, 'client_certs'),
                     os.path.join(data_path, 'server_certs'),
                     server.cert_path):
            self._ensure_dir(path)
        for file_name in ('right.conf', 'group_right.conf'):
            self._touch(os.path.join(conf_path, file_name))

    def _log_access(self, verdict, method_name, ip):
        """Write "<time> <SHA1 fingerprint> <ip> <verdict> <method>" to the
        access log for the current thread's client certificate.

        Does nothing when there is no client certificate or no log.
        """
        import threading
        client_cert = getattr(threading.currentThread(), 'client_cert', None)
        if not client_cert or self.log is None:
            return
        certobj = OpenSSL.crypto.load_certificate(
            OpenSSL.SSL.FILETYPE_PEM, client_cert)
        finger = certobj.digest('SHA1')
        # NOTE(review): [5:] assumes a fixed 5-character namespace prefix on
        # method_name - confirm
        self.log.debug('%s %s %s %s %s' % (str(datetime.datetime.now()),
                                           finger, ip, verdict,
                                           method_name[5:]))

    def _Application__handle_soap_request(self, req_env, start_response, url):
        """Handle one SOAP-over-HTTP request.

        Overrides soaplib's private Application.__handle_soap_request.  The
        flow follows soaplib, with additions: server directories are created
        first, the caller's rights for the method are checked (403 on
        denial) and calls are written to the access log.  Any exception is
        turned into a SOAP fault.

        @param url: unused here.
        @return: WSGI response body (list of strings).
        """
        import threading
        curThread = threading.currentThread()
        curThread.REMOTE_ADDR = req_env.get('REMOTE_ADDR')
        curThread.REMOTE_PORT = req_env.get('REMOTE_PORT')
        ip = req_env.get('REMOTE_ADDR')

        http_resp_headers = {
            'Content-Type': 'text/xml',
            'Content-Length': '0',
        }
        method_name = None
        # bound before the try: the fault handlers below pass it on even
        # when the failure happens before the service is resolved
        service = None

        self.create_path()
        try:
            # implementation hook
            self.on_call(req_env)

            if req_env['REQUEST_METHOD'].lower() != 'post':
                http_resp_headers['Allow'] = 'POST'
                start_response(HTTP_405, http_resp_headers.items())
                return ['']

            wsgi_input = req_env.get('wsgi.input')
            length = req_env.get("CONTENT_LENGTH")
            body = wsgi_input.read(int(length))

            try:
                soap_req_header, soap_req_payload = \
                    self._Application__decode_soap_request(req_env, body)
                if soap_req_payload is not None:
                    self.validate_request(soap_req_payload)

                method_name = self._Application__get_method_name(
                    req_env, soap_req_payload)
                if method_name is None:
                    resp = "Could not extract method name from the request!"
                    http_resp_headers['Content-Length'] = str(len(resp))
                    start_response(HTTP_500, http_resp_headers.items())
                    return [resp]

                service_class = self.get_service_class(method_name)
                service = self.get_service(service_class, req_env)

            finally:
                # for performance reasons, we don't want the following to run
                # in production even though we won't see the results.
                if logger.level == logging.DEBUG:
                    try:
                        logger.debug(etree.tostring(etree.fromstring(body),
                                                    pretty_print=True))
                    except etree.XMLSyntaxError as e:
                        logger.debug(body)
                        raise Fault('Client.XMLSyntax',
                                    'Error at line: %d, col: %d' % e.position)

            # retrieve the method descriptor
            descriptor = service.get_method(method_name)
            func = getattr(service, descriptor.name)

            # decode header object
            if soap_req_header is not None and len(soap_req_header) > 0:
                in_header = descriptor.in_header
                service.soap_in_header = in_header.from_xml(soap_req_header)

            # decode method arguments
            if soap_req_payload is not None and len(soap_req_payload) > 0:
                params = descriptor.in_message.from_xml(soap_req_payload)
            else:
                params = [None] * len(descriptor.in_message._type_info)

            # ---- access control
            if hasattr(params, 'sid'):
                curThread.lang = service.get_lang(params.sid)
            # check_rights() reads this attribute, so make sure it exists
            if not hasattr(curThread, 'client_cert'):
                curThread.client_cert = None
            if not self.check_rights(method_name, req_env, params):
                self._log_access('forbidden', method_name, ip)
                resp = "Permission denied: " + method_name
                http_resp_headers['Content-Length'] = str(len(resp))
                start_response(HTTP_403, http_resp_headers.items())
                return [resp]

            # ---- access log (routine / read-only methods are skipped)
            short_name = method_name[5:]
            if short_name not in not_log_list and \
                    not short_name.endswith('_view'):
                self._log_access('allowed', method_name, ip)

            # implementation hook
            service.on_method_call(req_env, method_name, params,
                                   soap_req_payload)

            # call the method
            result_raw = service.call_wrapper(func, params)

            # construct the soap response, and serialize it
            envelope = etree.Element('{%s}Envelope' % soaplib.ns_soap_env,
                                     nsmap=soaplib.nsmap)

            #
            # header
            #
            soap_header_elt = etree.SubElement(
                envelope, '{%s}Header' % soaplib.ns_soap_env)

            if service.soap_out_header is not None:
                if descriptor.out_header is None:
                    logger.warning("Skipping soap response header as %r "
                                   "method is not published to have a soap "
                                   "response header" % method_name)
                else:
                    descriptor.out_header.to_xml(
                        service.soap_out_header,
                        self.get_tns(),
                        soap_header_elt,
                        descriptor.out_header.get_type_name()
                    )
            if len(soap_header_elt) > 0:
                envelope.append(soap_header_elt)

            #
            # body
            #
            soap_body = etree.SubElement(envelope,
                                         '{%s}Body' % soaplib.ns_soap_env)

            # instantiate the result message
            result_message = descriptor.out_message()

            # assign raw result to its wrapper, result_message
            out_type = descriptor.out_message._type_info
            attr_names = list(out_type.keys())
            if len(attr_names) == 1:
                setattr(result_message, attr_names[0], result_raw)
            else:
                # several outputs: result_raw is indexed in declaration order
                for i, attr_name in enumerate(attr_names):
                    setattr(result_message, attr_name, result_raw[i])

            # transform the results into an element
            descriptor.out_message.to_xml(result_message, self.get_tns(),
                                          soap_body)
            # implementation hook
            service.on_method_return(req_env, result_raw, soap_body,
                                     http_resp_headers)

            #
            # misc
            #
            results_str = etree.tostring(envelope, xml_declaration=True,
                                         encoding=string_encoding)

            if descriptor.mtom:
                http_resp_headers, results_str = apply_mtom(
                    http_resp_headers, results_str,
                    descriptor.out_message._type_info, [result_raw])

            # implementation hook
            self.on_return(req_env, http_resp_headers, results_str)

            # initiate the response
            http_resp_headers['Content-Length'] = str(len(results_str))
            start_response(HTTP_200, http_resp_headers.items())

            if logger.level == logging.DEBUG:
                logger.debug('\033[31m' + "Response" + '\033[0m')
                logger.debug(etree.tostring(envelope, xml_declaration=True,
                                            pretty_print=True))
            # return the serialized results
            return [results_str]

        # The user issued a Fault, so handle it just like an exception!
        except Fault as e:
            return self._Application__handle_fault(req_env, start_response,
                                                   http_resp_headers,
                                                   service, e)

        except Exception as e:
            fault = Fault('Server', str(e))

            return self._Application__handle_fault(req_env, start_response,
                                                   http_resp_headers,
                                                   service, fault)

class OpenSSLAdapter(pyOpenSSLAdapter):
    """pyOpenSSL adapter that asks clients for a certificate and stores the
    peer certificate on the current thread (``client_cert``), where
    ClApplication reads it for its access checks."""

    def verify_func(self, connection, x509, errnum, errdepth, ok):
        """OpenSSL verification callback.

        For the peer's own certificate (chain depth 0) its PEM dump is stored
        in ``threading.currentThread().client_cert``; for any other depth the
        attribute is set to None.  OpenSSL's verdict *ok* is passed through
        unchanged.

        NOTE(review): the attribute lives on the thread, and the callback is
        not invoked when a connection verifies no certificate, so a value
        left by an earlier connection on the same thread may survive -
        confirm the thread/connection lifecycle.
        """
        import threading
        if errdepth == 0:
            pem = OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM,
                                                  x509)
        else:
            pem = None
        threading.currentThread().client_cert = pem
        return ok

    def get_context(self):
        """Return an SSL.Context from self attributes."""
        # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473
        ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        # NOTE(review): no passphrase callback is installed for the key
        ctx.use_privatekey_file(self.private_key)
        # request a client certificate; without VERIFY_FAIL_IF_NO_PEER_CERT
        # clients that send none are still accepted
        ctx.set_verify(OpenSSL.SSL.VERIFY_PEER, self.verify_func)

        chain = self.certificate_chain
        if chain:
            ctx.load_verify_locations(chain)

        ctx.use_certificate_file(self.certificate)
        return ctx
