Flask-WTF CSRF 保護の実装を読む

Flask を使った前後端分離型の個人サイトを構築するにあたり、Flask-WTF を使って CSRF 対策を確認する必要がありました。

Flask-WTF は CSRF 保護を提供する公式拡張の一つです。以下に主なソースコードを示します。

def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    """Check if the given data is a valid CSRF token. This compares the given
    signed token to the one stored in the session.

    :param data: The signed CSRF token to be checked.
    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param time_limit: Number of seconds that the token is valid. Default is
        ``WTF_CSRF_TIME_LIMIT`` or 3600 seconds (60 minutes).
    :param token_key: Key where token is stored in session for comparision.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.

    :raises ValidationError: Contains the reason that validation failed.

    .. versionchanged:: 0.14
        Raises ``ValidationError`` with a specific error message rather than
        returning ``True`` or ``False``.
    """

    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )
    time_limit = _get_config(
        time_limit, 'WTF_CSRF_TIME_LIMIT', 3600, required=False
    )

    if not data:
        raise ValidationError('The CSRF token is missing.')

    if field_name not in session:
        raise ValidationError('The CSRF session token is missing.')

    s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

    try:
        token = s.loads(data, max_age=time_limit)
    except SignatureExpired:
        raise ValidationError('The CSRF token has expired.')
    except BadData:
        raise ValidationError('The CSRF token is invalid.')

    if not safe_str_cmp(session[field_name], token):
        raise ValidationError('The CSRF tokens do not match.')


class CSRFProtect(object):
    """Enable CSRF protection globally for a Flask app.

    ::
        app = Flask(__name__)
        csrf = CsrfProtect(app)

    Checks the ``csrf_token`` field sent with forms, or the ``X-CSRFToken``
    header sent with JavaScript requests. Render the token in templates using
    ``{{ csrf_token() }}``.
    """

    def __init__(self, app=None):
        self._exempt_views = set()
        self._exempt_blueprints = set()

        if app:
            self.init_app(app)

    def init_app(self, app):
        app.extensions['csrf'] = self

        app.config.setdefault('WTF_CSRF_ENABLED', True)
        app.config.setdefault('WTF_CSRF_CHECK_DEFAULT', True)
        app.config['WTF_CSRF_METHODS'] = set(app.config.get(
            'WTF_CSRF_METHODS', ['POST', 'PUT', 'PATCH', 'DELETE']
        ))
        app.config.setdefault('WTF_CSRF_FIELD_NAME', 'csrf_token')
        app.config.setdefault(
            'WTF_CSRF_HEADERS', ['X-CSRFToken', 'X-CSRF-Token']
        )
        app.config.setdefault('WTF_CSRF_TIME_LIMIT', 3600)
        app.config.setdefault('WTF_CSRF_SSL_STRICT', True)

        app.jinja_env.globals['csrf_token'] = generate_csrf
        app.context_processor(lambda: {'csrf_token': generate_csrf})

        @app.before_request
        def csrf_protect():
            if not app.config['WTF_CSRF_ENABLED']:
                return

            if not app.config['WTF_CSRF_CHECK_DEFAULT']:
                return

            if request.method not in app.config['WTF_CSRF_METHODS']:
                return

            if not request.endpoint:
                return

            view = app.view_functions.get(request.endpoint)

            if not view:
                return

            if request.blueprint in self._exempt_blueprints:
                return

            dest = '%s.%s' % (view.__module__, view.__name__)

            if dest in self._exempt_views:
                return

            self.protect()

    def _get_csrf_token(self):
        field_name = current_app.config['WTF_CSRF_FIELD_NAME']

        for key in request.form:
            if key.endswith(field_name):
                csrf_token = request.form[key]

                if csrf_token:
                    return csrf_token

        for header_name in current_app.config['WTF_CSRF_HEADERS']:
            csrf_token = request.headers.get(header_name)

            if csrf_token:
                return csrf_token

        return None

    def protect(self):
        if request.method not in current_app.config['WTF_CSRF_METHODS']:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as e:
            logger.info(e.args[0])
            self._error_response(e.args[0])

        if request.is_secure and current_app.config['WTF_CSRF_SSL_STRICT']:
            if not request.referrer:
                self._error_response('The referrer header is missing.')

            good_referrer = 'https://{0}/'.format(request.host)

            if not same_origin(request.referrer, good_referrer):
                self._error_response('The referrer does not match the host.')

        g.csrf_valid = True

Flask-WTF はデフォルトで Jinja2 テンプレートに csrf_token 関数を登録します。これは generate_csrf 関数を介して CSRF トークンを生成します。このため、前後端分離のアプリケーションでもこの関数を使ってトークンを取得することが可能です。

generate_csrf 関数の内部では、セッションにトークンを保存し、署名済みトークンを返します。この処理により、関数が呼ばれるたびにセッション内の CSRF トークンが変化します。トークンは有効期間内(デフォルトで1時間)であれば再利用可能ですが、セキュリティ上は毎回取得するようにすることが推奨されます。

リクエストが JSON 形式の場合は、_get_csrf_token メソッドはリクエストヘッダの X-CSRFToken からトークンを取得します。

Flask-WTF はリクエスト開始時のフック(before_request)を使って CSRF 検証を行います。以下がその処理の核心です。

@app.before_request
def csrf_protect():

このフックは Flask の before_request_funcs に追加されるため、リクエストごとに実行されます。以下のコードがフック関数の動作を制御します。

def preprocess_request(self):
    ...
    for func in funcs:
        rv = func()
        if rv is not None:
            return rv

csrf_protect 関数では、CSRF 検証が必要ない条件(例:GET メソッドや除外されたエンドポイント)をチェックし、必要に応じて検証をスキップします。検証が必要な場合は protect メソッドを呼び出し、最終的に validate_csrf 関数でトークンの検証を行います。

def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
    ...
    if not data:
        raise ValidationError('The CSRF token is missing.')

    if field_name not in session:
        raise ValidationError('The CSRF session token is missing.')

    s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

    try:
        token = s.loads(data, max_age=time_limit)
    except SignatureExpired:
        raise ValidationError('The CSRF token has expired.')
    except BadData:
        raise ValidationError('The CSRF token is invalid.')

    if not safe_str_cmp(session[field_name], token):
        raise ValidationError('The CSRF tokens do not match.')

この関数はセッションに保存されたトークンとリクエストで送信されたトークンを比較します。一致しなければ検証に失敗します。

前後端分離のアプリケーションでは、CSRF トークンを API 経由または Cookie 経由でフロントエンドに渡す必要があります。フロントエンドはこのトークンを保存し、POST リクエスト時に X-CSRFToken ヘッダに設定して送信します。

デフォルトでは、フォームフィールド名は csrf_token、リクエストヘッダ名は X-CSRFToken となっています。

タグ: flask csrf Flask-WTF Web Security Python

7月2日 16:28 投稿