# __init__.py (15.0 KB)
from datetime import datetime
from logging import error
from flasgger import swag_from
from app.util import BlueprintApi
from flask import Blueprint, render_template, redirect, request, session, jsonify, flash, make_response
from .models import *
from .oauth2 import authorization, generate_user_info, require_oauth
from authlib.oauth2 import OAuth2Error
from authlib.integrations.flask_oauth2 import current_token
from . import user_create, client_create, client_query, user_query, user_update, user_delete, auth_log_query
import configure
from app.decorators.auth_decorator import auth_decorator
import time
from app.models import SM3, AESHelper
from app.util.component.StructurePrint import StructurePrint
import traceback
# from oauthlib import oauth2
import requests
from app.modules.auth.models import OAuth2Token, User, db, OAuthLog
from app.util.enum.AuthEnum import AuthEnum, OriginEnum, OperateEnum


def current_user():
    """Return the ``User`` whose id is stored in the Flask session.

    Returns ``None`` when the session carries no ``"id"`` key; otherwise the
    result of ``User.query.get`` for that id.
    """
    if "id" not in session:
        return None
    return User.query.get(session["id"])


def remove_user():
    """Log the current user out by removing ``"id"`` from the Flask session.

    The key is dropped whenever it is present. The previous implementation
    first loaded the ``User`` row (an extra DB query) and left a stale id in
    the session when that row no longer existed.
    """
    session.pop("id", None)


def split_by_crlf(s):
    """Split *s* on line boundaries (``str.splitlines``) and drop empty lines."""
    return list(filter(None, s.splitlines()))


def getRedirectUrl(request):  # build the externally visible base URL for redirects
    """Return the externally visible base URL (``scheme://host``) of *request*.

    When both the ``X-Forwarded-Proto`` and ``X-Forwarded-Host`` headers are
    present and non-empty (typically set by a reverse proxy), they are used.
    Otherwise ``request.host_url`` is returned without its trailing slash.
    Previously an empty header value still counted as "present" and produced
    a malformed URL such as ``"://host"``.

    NOTE(review): these headers can be supplied by the client unless the
    proxy overwrites them; the result feeds OAuth redirect URIs.
    """
    forwarded_proto = request.headers.get("X-Forwarded-Proto")  # scheme
    forwarded_host = request.headers.get("X-Forwarded-Host")  # host
    if forwarded_proto and forwarded_host:
        return forwarded_proto + "://" + forwarded_host
    return request.host_url.rstrip("/")


class DataManager(BlueprintApi):
    """Flask blueprint holding the ``/auth`` endpoints.

    Covers the OAuth2 authorization-server endpoints (authorize, token,
    userinfo, logout), user/client management APIs, a bootstrap endpoint for
    the default admin account, the OA third-party login flow and the login
    log query. Each view is registered on ``bp`` via ``bp.route`` and wrapped
    in ``staticmethod`` so it can live in the class body.
    """
    bp = Blueprint("Auth", __name__, url_prefix="/auth")

    @staticmethod
    @bp.route("/authorize", methods=("GET", "POST"))
    def authorize():
        """OAuth2 authorization endpoint.

        The authorization request is validated first and its redirect URI is
        stored in ``session["redirect_uri"]``. Then:

        * GET without a logged-in user: render the login page.
        * POST without a logged-in user: check the submitted username and
          AES-encrypted password against users of the dmap origin.
        * When a user is available (already in the session or just
          authenticated): store the id in the session, write a login
          ``OAuthLog`` entry and return the authorization response.
        * Otherwise re-render the login page with the error message.

        OAuth2 validation errors are returned as their JSON error body.
        """
        user = current_user()
        try:
            oauth2_request = authorization.create_oauth2_request(request)
            grant = authorization.get_authorization_grant(request=oauth2_request)
            # Previously unguarded: an invalid request (unknown client, bad
            # redirect_uri, ...) escaped as an unhandled OAuth2Error (HTTP 500).
            redirect_uri = grant.validate_authorization_request()
        except OAuth2Error as exc:
            return jsonify(dict(exc.get_body()))
        # Remember where to send the user afterwards (read again by oa_callback).
        session["redirect_uri"] = redirect_uri

        if request.method == "GET":
            try:
                consent_grant = authorization.validate_consent_request(end_user=user)
            except OAuth2Error as exc:
                return jsonify(dict(exc.get_body()))
            if not user:
                # Not logged in yet: show the login form.
                return render_template("auth/authorize.html",
                                       user=user,
                                       grant=consent_grant)

        error_msg = ""
        if not user:
            username = request.form.get("username")
            crypt_pwd = request.form.get("password")
            if not username:
                error_msg = "用户名不可为空"
            elif not crypt_pwd:
                error_msg = "密码不可为空"
            else:
                try:
                    # The client sends the password AES-encrypted; it is stored as an SM3 hash.
                    password = SM3.encode(AESHelper.decode(crypt_pwd))
                except Exception as exc:
                    # A malformed ciphertext used to surface as a 500; treat it
                    # as invalid credentials instead (but keep a log of it).
                    StructurePrint().print(exc.__str__() + ":" + traceback.format_exc(), "error")
                else:
                    # Only users that belong to the dmap platform itself may log in here.
                    origin_type = OriginEnum.Dmap.name.lower()
                    user = User.query.filter_by(
                        username=username, password=password, origin=origin_type).first()
                if not user:
                    error_msg = "账号或密码不正确"

        if error_msg:
            # Flash exactly once and never an empty message (the old code
            # flashed a failed login twice and flashed "" on success).
            flash(error_msg)

        if user:
            session["id"] = user.id

            # Audit log entry for the successful login.
            log = OAuthLog(user_id=user.id, username=user.username,
                           auth_type=AuthEnum.Other.name.lower(),
                           message="认证成功", create_time=datetime.now(),
                           operate_type=OperateEnum.Login,
                           displayname=user.displayname, ip=request.remote_addr
                           )
            db.session.add(log)
            db.session.commit()

            return authorization.create_authorization_response(request=request, grant_user=user)

        try:
            authorization.validate_consent_request(end_user=user)
        except OAuth2Error as exc:
            return jsonify(dict(exc.get_body()))
        return render_template("auth/authorize.html",
                               grant_user=None, error=error_msg)

    @staticmethod
    @bp.route("/token", methods=["POST"])
    def issue_token():
        """OAuth2 token endpoint; delegates entirely to the authorization server."""
        return authorization.create_token_response()

    @staticmethod
    @bp.route("/userinfo")
    @require_oauth("profile")
    def api_me():
        """Return user info for the bearer token's user (requires the ``profile`` scope)."""
        try:
            return jsonify(generate_user_info(current_token.user, current_token.scope))
        except OAuth2Error as exc:
            # Was `except error`, where `error` is logging.error (a function):
            # any exception here turned into a TypeError instead of being handled.
            return jsonify(dict(exc.get_body()))

    @staticmethod
    @bp.route("/logout", methods=["GET"])
    def logout():
        """Revoke the given access token, clear the session and redirect back.

        The redirect target comes from the validated OAuth2 request. If the
        ``accesstoken`` query parameter matches a stored ``OAuth2Token``, that
        token is marked revoked, the session user is removed and a logout
        ``OAuthLog`` entry is written. OAuth2 errors are logged and the
        browser is sent to ``configure.auth_default_redirect_uri``.
        """
        # Fallback target: previously `redirect_uri` was unbound (UnboundLocalError,
        # HTTP 500) whenever request validation raised an OAuth2Error.
        redirect_uri = configure.auth_default_redirect_uri
        try:
            oauth2_request = authorization.create_oauth2_request(request)
            grant = authorization.get_authorization_grant(request=oauth2_request)
            redirect_uri = grant.validate_authorization_request()
            access_token = request.args.get("accesstoken")

            if access_token is not None:
                token_row = OAuth2Token.query.filter_by(
                    access_token=access_token).one_or_none()
                if token_row is not None:
                    token_row.revoked = True
                    db.session.commit()
                    remove_user()

                    user = User.query.get(token_row.user_id)
                    # Audit log entry for the logout.
                    if user is not None:
                        log = OAuthLog(user_id=user.id, username=user.username,
                                       auth_type=AuthEnum.Other.name.lower(),
                                       message="注销成功", create_time=datetime.now(),
                                       operate_type=OperateEnum.Logout, token=access_token,
                                       displayname=user.displayname, ip=request.remote_addr
                                       )
                        db.session.add(log)
                        db.session.commit()

        except OAuth2Error as exc:
            StructurePrint().print(exc.__str__() + ":" + traceback.format_exc(), "error")
        return redirect(redirect_uri)

    # --- Management APIs ---
    # The docstrings of the swagger-documented endpoints below are kept
    # verbatim because flasgger may use them in the generated API docs.
    # Inside each method body, names such as `user_query` refer to the
    # modules imported at the top of the file: class-body names are not
    # visible from method bodies.

    # GET /auth/users: list users (guarded by auth_decorator(configure.UserPermission)).
    @staticmethod
    @bp.route("/users", methods=["GET"])
    @swag_from(user_query.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_query():
        """
        获取用户列表
        """
        return user_query.Api().result

    # POST /auth/users: create a user (guarded by auth_decorator(configure.UserPermission)).
    @staticmethod
    @bp.route("/users", methods=["POST"])
    @swag_from(user_create.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_create():
        """
        创建用户
        """
        return user_create.Api().result

    # POST /auth/userEdit: update user information (guarded by auth_decorator).
    @staticmethod
    @bp.route("/userEdit", methods=["POST"])
    @swag_from(user_update.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_update():
        """
        更新用户信息
        """
        return user_update.Api().result

    # POST /auth/userDelete: delete a user (guarded by auth_decorator).
    @staticmethod
    @bp.route("/userDelete", methods=["POST"])
    @swag_from(user_delete.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def user_delete():
        """
        删除用户
        """
        return user_delete.Api().result

    # POST /auth/client: create an OAuth2 client.
    # NOTE(review): unlike the user endpoints this has no auth_decorator, so
    # any caller can register a client; confirm that is intended.
    @staticmethod
    @bp.route("/client", methods=["POST"])
    @swag_from(client_create.Api.api_doc)
    def client_create():
        """
        创建client
        """
        return client_create.Api().result

    # GET /auth/client: list OAuth2 clients (also not behind auth_decorator; see note above).
    @staticmethod
    @bp.route("/client", methods=["GET"])
    @swag_from(client_query.Api.api_doc)
    def client_query():
        """
        获取client列表
        """
        return client_query.Api().result

    @staticmethod
    @bp.route("/init", methods=["GET"])
    def init():
        """Create the default ``admin`` user if no user of that name exists.

        Returns a short status message. On failure the DB session is rolled
        back, the error is logged and an error message with HTTP 500 is
        returned. Previously the function fell through and returned ``None``,
        which Flask rejects as an invalid response.

        NOTE(review): this endpoint is unauthenticated and seeds a well-known
        password; consider restricting or removing it after deployment.
        """
        try:
            username = 'admin'
            if User.query.filter_by(username=username).one_or_none():
                return "默认用户已存在"
            # One timestamp for both fields so create_time == update_time.
            now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            user = User(username=username, password=SM3.encode('DMap@123'), role='admin',
                        phone='', company='', position='', email='',
                        create_time=now, update_time=now)
            db.session.add(user)
            db.session.commit()
            return "创建默认用户成功"
        except Exception as exc:
            db.session.rollback()
            StructurePrint().print(exc.__str__() + ":" + traceback.format_exc(), "error")
            return "创建默认用户失败", 500

    # --- Third-party login: OA ---

    @staticmethod
    @bp.route("/oa", methods=["GET"])
    def oa_authorization():
        """Start OA login by redirecting the browser to OA's authorization endpoint.

        A freshly generated ``state`` is stored in ``session["oauth_state"]``
        so that ``oa_callback`` can verify the provider's response.
        """
        # NOTE(review): `from oauthlib import oauth2` is commented out at the top
        # of the file. So `oauth2` here presumably resolves to this package's own
        # `.oauth2` submodule, which the `from .oauth2 import ...` line binds as a
        # package attribute. Confirm that module exposes `WebApplicationClient`;
        # otherwise this raises AttributeError.
        client = oauth2.WebApplicationClient(configure.OA["client_id"])
        state = client.state_generator()
        StructurePrint().print(request.headers, "info")
        auth_uri = client.prepare_request_uri(
            configure.OA["authorization_endpoint"],
            getRedirectUrl(request) + configure.OA["redirect_uri"],
            configure.OA["scope"],
            state)
        session["oauth_state"] = state
        return redirect(auth_uri)

    @staticmethod
    @bp.route("/oa/callback", methods=["GET"])
    def oa_callback():
        """OA login callback.

        Steps:

        * Parse the authorization code, checking ``session["oauth_state"]``.
        * Exchange the code for tokens at the OA token endpoint.
        * Fetch the OA user info.
        * Find or create the matching local ``User`` (origin ``dci_oa``).
        * Store the OA access token as an ``OAuth2Token``.
        * Set the ``accessToken`` (and, if provided, ``id_token``) cookies.
        * Write a login ``OAuthLog``.
        * Redirect to ``session["redirect_uri"]``, or the configured default.

        Returns "登录失败" when no code is present. On any other failure the
        error is logged, the DB session is rolled back, the ``id`` and
        ``redirect_uri`` session keys are cleared, and the browser is sent to
        ``configure.auth_default_redirect_uri``.
        """
        auth_default_redirect_uri = configure.auth_default_redirect_uri
        # Upper bound (seconds) for each HTTP call to the OA server; without it
        # a hung OA endpoint would block this worker indefinitely.
        http_timeout = 30
        origin_type = "dci_oa"  # marks users that come from OA login
        try:
            # NOTE(review): see oa_authorization about what `oauth2` resolves to.
            client = oauth2.WebApplicationClient(configure.OA["client_id"])

            # Authorization code from the callback URL (state is checked against the session).
            code = client.parse_request_uri_response(
                request.url, session["oauth_state"]).get("code")
            if code is None:
                return "登录失败"

            # Exchange the code for tokens.
            body = client.prepare_request_body(
                code,
                redirect_uri=getRedirectUrl(request) + configure.OA["redirect_uri"],
                client_secret=configure.OA["client_secret"])
            token_response = requests.post(
                configure.OA["token_endpoint"], body,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=http_timeout)
            tokeninfo = token_response.json()
            access_token = tokeninfo.get("access_token")
            id_token = tokeninfo.get("id_token")
            if not access_token:
                raise Exception("缺少access_token")

            # Fetch the OA user profile.
            userinfo = requests.get(
                configure.OA["userinfo_endpoint"],
                headers={"Authorization": "Bearer %s" % access_token},
                timeout=http_timeout).json()
            user_name = userinfo.get("user_name")
            if not user_name:
                # Never create or match a local user without a username.
                raise Exception("缺少user_name")
            # Keep the last whitespace-separated part of the display name; fall
            # back to the login name when it is missing or blank (a blank name
            # used to raise IndexError).
            name_parts = (userinfo.get("displayname") or "").split()
            display_name = name_parts[-1] if name_parts else user_name

            # Link to the local dmap user for this OA account, creating it on first login.
            # (The old inner `except error` referenced logging.error and could
            # never catch anything; query failures go to the handler below.)
            user = User.query.filter_by(
                username=user_name, origin=origin_type).first()
            if not user:
                now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                user = User(username=user_name, password=SM3.encode('DMap@123'), role='dataman',
                            phone='', company='', position='', email='', displayname=display_name,
                            origin=origin_type,
                            create_time=now, update_time=now)
                db.session.add(user)
                db.session.commit()

            session["id"] = user.id

            # Persist the OA access token as a dmap token for this user.
            token = OAuth2Token(
                client_id=configure.OA["client_id"],
                token_type=tokeninfo.get("token_type"),
                access_token=access_token,
                scope=tokeninfo.get("scope"),
                expires_in=tokeninfo.get("expires_in"),
                user_id=user.id
            )
            db.session.add(token)
            db.session.commit()

            redirect_uri = session.get("redirect_uri", auth_default_redirect_uri)

            response = make_response(redirect(redirect_uri))
            response.set_cookie('accessToken', access_token,
                                max_age=configure.expiretime)
            if id_token:
                # None is not a valid cookie value; a provider that returned no
                # id_token previously made the whole login fail.
                response.set_cookie('id_token', id_token,
                                    max_age=configure.expiretime)

            log = OAuthLog(user_id=user.id, username=user_name,
                           auth_type=AuthEnum.Other.name.lower(),
                           message="三方认证成功", create_time=datetime.now(),
                           operate_type=OperateEnum.Login, token=access_token,
                           displayname=display_name, ip=request.remote_addr)
            db.session.add(log)
            db.session.commit()

            return response

        except Exception as exc:
            StructurePrint().print(exc.__str__() + ":" + traceback.format_exc(), "error")
            # Discard any half-finished DB work from this request.
            db.session.rollback()
            for key in ("id", "redirect_uri"):
                session.pop(key, None)
        return redirect(auth_default_redirect_uri)

    # GET /auth/logs: login log query (guarded by auth_decorator).
    @staticmethod
    @bp.route("/logs", methods=["GET"])
    @swag_from(auth_log_query.Api.api_doc)
    @auth_decorator(configure.UserPermission)
    def authLog():
        '''
        登录日志
        '''
        return auth_log_query.Api().result