# __init__.py (15.0 KB)
from datetime import datetime
from logging import error
from flasgger import swag_from
from app.util import BlueprintApi
from flask import Blueprint, render_template, redirect, request, session, jsonify, flash, make_response
from .models import *
from .oauth2 import authorization, generate_user_info, require_oauth
from authlib.oauth2 import OAuth2Error
from authlib.integrations.flask_oauth2 import current_token
from . import user_create, client_create, client_query, user_query, user_update, user_delete, auth_log_query
import configure
from app.decorators.auth_decorator import auth_decorator
import time
from app.models import SM3, AESHelper
from app.util.component.StructurePrint import StructurePrint
import traceback
from oauthlib import oauth2
import requests
from app.modules.auth.models import OAuth2Token, User, db, OAuthLog
from app.util.enum.AuthEnum import AuthEnum, OriginEnum, OperateEnum


def current_user():
    """Return the User whose id is stored in the Flask session.

    Returns None when the session carries no "id" key (nobody logged in);
    otherwise the result of ``User.query.get`` for that id.
    """
    if "id" not in session:
        return None
    return User.query.get(session["id"])


def remove_user():
    """Log the browser session out by removing the "id" key from the session.

    The key is dropped unconditionally. The previous version first loaded the
    user from the database and left the id in place when that user row no
    longer existed, so a stale id could linger in the session.
    """
    session.pop("id", None)


def split_by_crlf(s):
    """Split *s* on any line boundary (\\n, \\r\\n, ...) and drop empty lines."""
    # filter(None, ...) keeps only truthy strings, i.e. non-empty lines.
    return list(filter(None, s.splitlines()))


def getRedirectUrl(request):
    """Return the externally visible base URL ("scheme://host") of *request*.

    If both the X-Forwarded-Proto and X-Forwarded-Host headers are present and
    non-empty (presumably set by a reverse proxy -- confirm for the deployment),
    they are combined. Otherwise the function falls back to
    ``request.host_url`` without its trailing slash.
    """
    headers = request.headers
    forwarded_proto = headers.get("X-Forwarded-Proto")  # scheme
    forwarded_host = headers.get("X-Forwarded-Host")  # host[:port]
    # Truthiness rather than a None check: an empty header value would
    # otherwise produce a broken URL such as "://host".
    if forwarded_proto and forwarded_host:
        return forwarded_proto + "://" + forwarded_host
    return request.host_url.rstrip("/")


class DataManager(BlueprintApi):
    """Auth blueprint mounted at ``/auth``.

    Provides the OAuth2 authorize/token/userinfo/logout endpoints, OA
    third-party login, and the user/client/log management APIs. Every view
    function is registered on ``bp`` through ``bp.route`` when the class body
    executes, and is kept on the class as a static method.
    """

    bp = Blueprint("Auth", __name__, url_prefix="/auth")

    @staticmethod
    @bp.route("/authorize", methods=("GET", "POST"))
    def authorize():
        # OAuth2 authorization endpoint.
        # GET : validate the request. Anonymous users get the login form; a
        #       user already logged in (session "id") is authorized directly.
        # POST: check the submitted credentials, then either issue the
        #       authorization response or re-render the form with an error.
        # OAuth2 validation failures are answered with the JSON error body.
        user = current_user()
        try:
            oauth_request = authorization.create_oauth2_request(request)
            grant = authorization.get_authorization_grant(request=oauth_request)
            # Remember the validated redirect target. oa_callback reads this
            # key when the login is completed through OA instead.
            session["redirect_uri"] = grant.validate_authorization_request()
        except OAuth2Error as err:
            return jsonify(dict(err.get_body()))

        if request.method == "GET":
            try:
                grant = authorization.validate_consent_request(end_user=user)
            except OAuth2Error as err:
                return jsonify(dict(err.get_body()))
            if not user:
                # Not logged in: show the login form.
                # NOTE: a captcha step was planned here but is not implemented.
                return render_template("auth/authorize.html",
                                       user=user,
                                       grant=grant)
            # Already logged in: fall through and authorize immediately.

        error_msg = ""
        if not user:
            username = request.form.get("username")
            crypt_pwd = request.form.get("password")
            if not username:
                error_msg = "用户名不可为空"
            elif not crypt_pwd:
                error_msg = "密码不可为空"
            else:
                # The password arrives AES-encrypted (decoded by AESHelper).
                # It is then compared in its SM3-encoded form.
                password = SM3.encode(AESHelper.decode(crypt_pwd))
                # Only accounts of dmap origin may log in with a password.
                user = User.query.filter_by(
                    username=username, password=password,
                    origin=OriginEnum.Dmap.name.lower()).first()
                if not user:
                    error_msg = "账号或密码不正确"

        # Flash each real message once. Previously a failed login was flashed
        # twice, and every request flashed an empty string.
        if error_msg:
            flash(error_msg)

        if user:
            session["id"] = user.id

            # Audit log entry for the successful login.
            log = OAuthLog(user_id=user.id, username=user.username,
                           auth_type=AuthEnum.Other.name.lower(),
                           message="认证成功", create_time=datetime.now(),
                           operate_type=OperateEnum.Login,
                           displayname=user.displayname, ip=request.remote_addr
                           )
            db.session.add(log)
            db.session.commit()

            return authorization.create_authorization_response(request=request, grant_user=user)

        # Login failed: re-render the form with the error. The grant is passed
        # too, as on the GET path (it used to be computed and then discarded).
        try:
            grant = authorization.validate_consent_request(end_user=user)
        except OAuth2Error as err:
            return jsonify(dict(err.get_body()))
        return render_template("auth/authorize.html",
                               grant_user=None, error=error_msg, grant=grant)

    @staticmethod
    @bp.route("/token", methods=["POST"])
    def issue_token():
        # OAuth2 token endpoint; the response is built by the authlib server.
        return authorization.create_token_response()

    @staticmethod
    @bp.route("/userinfo")
    @require_oauth("profile")
    def api_me():
        # Return the user info of the bearer token's owner (needs "profile").
        try:
            return jsonify(generate_user_info(current_token.user, current_token.scope))
        except OAuth2Error as err:
            # The old clause `except error` named logging.error, a function.
            # That turned every exception into a TypeError.
            return jsonify(dict(err.get_body()))

    @staticmethod
    @bp.route("/logout", methods=["GET"])
    def logout():
        # Revoke the token passed as ?accesstoken=..., clear the session user
        # and redirect to the client's validated redirect_uri.
        # If the OAuth2 request cannot be validated, fall back to the default
        # redirect target. Previously `redirect_uri` stayed unbound in that
        # case, and the final redirect raised UnboundLocalError.
        redirect_uri = configure.auth_default_redirect_uri
        try:
            oauth_request = authorization.create_oauth2_request(request)
            grant = authorization.get_authorization_grant(request=oauth_request)
            redirect_uri = grant.validate_authorization_request()
            access_token = request.args.get("accesstoken")

            if access_token is not None:
                token = OAuth2Token.query.filter_by(
                    access_token=access_token).one_or_none()
                if token is not None:
                    token.revoked = True
                    db.session.commit()
                    if current_user() is not None:
                        remove_user()

                    user = User.query.get(token.user_id)
                    # Audit log entry for the logout.
                    if user is not None:
                        log = OAuthLog(user_id=user.id, username=user.username,
                                       auth_type=AuthEnum.Other.name.lower(),
                                       message="注销成功", create_time=datetime.now(),
                                       operate_type=OperateEnum.Logout, token=access_token,
                                       displayname=user.displayname, ip=request.remote_addr
                                       )
                        db.session.add(log)
                        db.session.commit()

        except OAuth2Error as err:
            StructurePrint().print(err.__str__() + ":" + traceback.format_exc(), "error")
        return redirect(redirect_uri)

    # ---- Management APIs --------------------------------------------------
    # Docstrings below are kept verbatim: flasgger may read them into the
    # generated Swagger spec.

    @staticmethod
    @bp.route("/users", methods=["GET"])
    @swag_from(user_query.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_query():
        """
        获取用户列表
        """
        # List users.
        return user_query.Api().result

    @staticmethod
    @bp.route("/users", methods=["POST"])
    @swag_from(user_create.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_create():
        """
        创建用户
        """
        # Create a user.
        return user_create.Api().result

    @staticmethod
    @bp.route("/userEdit", methods=["POST"])
    @swag_from(user_update.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_update():
        """
        更新用户信息
        """
        # Update a user's information.
        return user_update.Api().result

    @staticmethod
    @bp.route("/userDelete", methods=["POST"])
    @swag_from(user_delete.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_delete():
        """
        删除用户
        """
        # Delete a user.
        return user_delete.Api().result

    @staticmethod
    @bp.route("/client", methods=["POST"])
    @swag_from(client_create.Api.api_doc)
    def client_create():
        """
        创建client
        """
        # Create an OAuth2 client.
        # NOTE(review): unlike the user APIs this has no auth_decorator, so
        # anyone can register a client -- confirm this is intended.
        return client_create.Api().result

    @staticmethod
    @bp.route("/client", methods=["GET"])
    @swag_from(client_query.Api.api_doc)
    def client_query():
        """
        获取client列表
        """
        # List OAuth2 clients.
        # NOTE(review): no auth_decorator here either -- confirm intended.
        return client_query.Api().result

    @staticmethod
    @bp.route("/init", methods=["GET"])
    def init():
        # Create the built-in "admin" account if it does not exist yet.
        # NOTE(review): this is an unauthenticated GET that creates an admin
        # with a hard-coded default password; consider restricting it.
        try:
            username = 'admin'
            if User.query.filter_by(username=username).one_or_none():
                return "默认用户已存在"
            now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            user = User(username=username, password=SM3.encode('DMap@123'), role='admin',
                        phone='', company='', position='', email='',
                        create_time=now, update_time=now)
            db.session.add(user)
            db.session.commit()
            return "创建默认用户成功"
        except Exception as e:
            db.session.rollback()
            StructurePrint().print(e.__str__() + ":" + traceback.format_exc(), "error")
            # A view must return a response. Returning None here made Flask
            # raise a TypeError on top of the original error.
            return "创建默认用户失败", 500

    # ---- Third-party login: OA --------------------------------------------

    @staticmethod
    @bp.route("/oa", methods=["GET"])
    def oa_authorization():
        # Start OA login: redirect the browser to the OA authorization
        # endpoint, with a fresh state value remembered in the session.
        client = oauth2.WebApplicationClient(configure.OA["client_id"])
        state = client.state_generator()
        # NOTE(review): logs every request header (including Cookie) at info.
        StructurePrint().print(request.headers, "info")
        auth_uri = client.prepare_request_uri(
            configure.OA["authorization_endpoint"],
            getRedirectUrl(request) + configure.OA["redirect_uri"],
            configure.OA["scope"],
            state)
        session["oauth_state"] = state
        return redirect(auth_uri)

    @staticmethod
    @bp.route("/oa/callback", methods=["GET"])
    def oa_callback():
        # OA login callback. It exchanges the code for a token, fetches the
        # OA profile, links or creates the local user, stores the token, and
        # redirects with token cookies. Any failure clears the session login
        # state and redirects to the default target.
        auth_default_redirect_uri = configure.auth_default_redirect_uri
        # Seconds to wait on each call to the OA server. requests has no
        # default timeout and would otherwise block the worker indefinitely.
        oa_http_timeout = 30
        try:
            client = oauth2.WebApplicationClient(configure.OA["client_id"])

            # Extract the code. oauthlib checks the returned state against the
            # one stored by /oa (a missing session key raises KeyError).
            code = client.parse_request_uri_response(
                request.url, session["oauth_state"]).get("code")
            if code is None:
                return "登录失败"

            # Exchange the code for a token.
            body = client.prepare_request_body(
                code,
                redirect_uri=getRedirectUrl(request) + configure.OA["redirect_uri"],
                client_secret=configure.OA["client_secret"])
            token_response = requests.post(
                configure.OA["token_endpoint"], body,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=oa_http_timeout)
            tokeninfo = token_response.json()
            access_token = tokeninfo.get("access_token")
            id_token = tokeninfo.get("id_token")
            if not access_token:
                raise Exception("缺少access_token")

            origin_type = "dci_oa"  # origin marker for users created via OA

            # Fetch the OA user profile.
            userinfo = requests.get(
                configure.OA["userinfo_endpoint"],
                headers={"Authorization": "Bearer %s" % access_token},
                timeout=oa_http_timeout).json()
            user_name = userinfo.get("user_name")
            if not user_name:
                # Without a name we would match or create a NULL-username user.
                raise Exception("缺少user_name")
            display_name = userinfo.get("displayname")
            # Keep the last whitespace-separated part of the display name. Fall
            # back to the user name when it is missing or blank; a blank value
            # used to make split()[-1] raise IndexError.
            name_parts = display_name.split() if display_name else []
            display_name = name_parts[-1] if name_parts else user_name

            # Link to the existing OA-origin local user, or create one. A query
            # failure propagates to the handler below and aborts the login.
            user = User.query.filter_by(
                username=user_name, origin=origin_type).first()
            if not user:
                now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                # The placeholder password cannot be used at /authorize, which
                # only accepts dmap-origin accounts.
                user = User(username=user_name, password=SM3.encode('DMap@123'), role='dataman',
                            phone='', company='', position='', email='', displayname=display_name,
                            origin=origin_type,
                            create_time=now, update_time=now)
                db.session.add(user)
                db.session.commit()

            session["id"] = user.id

            # Persist the OA token as a local OAuth2 token for this user.
            token = OAuth2Token(
                client_id=configure.OA["client_id"],
                token_type=tokeninfo.get("token_type"),
                access_token=access_token,
                scope=tokeninfo.get("scope"),
                expires_in=tokeninfo.get("expires_in"),
                user_id=user.id
            )
            db.session.add(token)
            db.session.commit()

            redirect_uri = session.get("redirect_uri", auth_default_redirect_uri)

            response = make_response(redirect(redirect_uri))
            response.set_cookie('accessToken', access_token,
                                max_age=configure.expiretime)
            response.set_cookie('id_token', id_token,
                                max_age=configure.expiretime)

            # Audit log entry for the third-party login.
            log = OAuthLog(user_id=user.id, username=user_name,
                           auth_type=AuthEnum.Other.name.lower(),
                           message="三方认证成功", create_time=datetime.now(),
                           operate_type=OperateEnum.Login, token=access_token,
                           displayname=display_name, ip=request.remote_addr)
            db.session.add(log)
            db.session.commit()

            return response

        except Exception as e:
            # Discard any half-finished DB work before logging and resetting.
            db.session.rollback()
            StructurePrint().print(e.__str__() + ":" + traceback.format_exc(), "error")
            for key in ("id", "redirect_uri"):
                session.pop(key, None)
        return redirect(auth_default_redirect_uri)

    @staticmethod
    @bp.route("/logs", methods=["GET"])
    @swag_from(auth_log_query.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def authLog():
        '''
        登录日志
        '''
        # Query the login/logout audit log (docstring kept for flasgger).
        return auth_log_query.Api().result