# -*- coding: utf-8 -*-
import os
from collections import OrderedDict
from django.urls import resolve
import drf_api_checker
from drf_api_checker.exceptions import (
FieldAddedError, FieldMissedError, FieldValueError, HeaderError, StatusCodeError,
DictKeyMissed, DictKeyAdded)
from drf_api_checker.fs import clean_url, get_filename
from drf_api_checker.utils import _write, load_response, serialize_response
# Response headers compared by the HEADERS check unless a Recorder is
# given its own ``headers_to_check``.
HEADERS_TO_CHECK = ['Content-Type', 'Content-Length', 'Allow']

# Base data-directory name; exported for callers, not used by Recorder itself.
BASE_DATADIR = '_api_checker'

# Identifiers of the checks a Recorder can run (keys of ``Recorder.check_map``).
STATUS_CODE, FIELDS, HEADERS = 1, 2, 3

# Checks run, in this order, when none are requested explicitly.
DEFAULT_CHECKS = [STATUS_CODE, FIELDS, HEADERS]
class Recorder:
    """Record DRF API responses and check later responses against them.

    The first time a request is made (or whenever the ``API_CHECKER_RESET``
    environment variable is set to a non-empty value) the serialized response
    is written to a file whose name is derived from ``data_dir`` and the
    request; the stored response is then loaded and every check listed in
    ``checks`` (``STATUS_CODE``, ``FIELDS``, ``HEADERS``) is run against the
    live response.

    Per-field comparison can be customised by defining a method named
    ``assert_<path>_<field>`` (path components joined with ``_``) or
    ``assert_<field>`` on the recorder or on ``owner``. It is called with the
    response dict containing the field, the matching stored dict and the
    path list, and replaces the built-in comparison for that field.
    """

    # Class-level defaults. Each assert* call may override ``expect_errors``,
    # ``allow_empty`` and ``checks`` for a single request.
    expect_errors = False
    allow_empty = False
    headers = HEADERS_TO_CHECK
    checks = list(DEFAULT_CHECKS)

    def __init__(self, data_dir, owner=None, headers_to_check=None, fixture_file=None, as_user=None) -> None:
        """
        :param data_dir: directory used to build recorded-response file names
        :param owner: optional object (e.g. the running test case) whose
                      ``client`` is used and which is searched for custom
                      ``assert_*`` field asserters
        :param headers_to_check: header names compared by the HEADERS check;
                                 defaults to the class attribute ``headers``
        :param fixture_file: file name reported in ``FieldValueError``;
                             defaults to ``data_dir``
        :param as_user: if given, the client is force-authenticated as this user
        :raises DeprecationWarning: if the class still defines the removed
                                    ``check_headers`` / ``check_status`` attributes
        """
        self.data_dir = data_dir
        self.fixture_file = fixture_file or self.data_dir
        self.owner = owner
        self.user = as_user
        self.headers_to_check = headers_to_check or self.headers
        # Maps the public check identifiers to the methods implementing them.
        self.check_map = {FIELDS: self._assert_fields,
                          STATUS_CODE: self._assert_status,
                          HEADERS: self._assert_headers,
                          }
        if hasattr(self, 'check_headers'):
            raise DeprecationWarning("'check_headers' has been deprecated. Use 'checks' instead.")
        if hasattr(self, 'check_status'):
            raise DeprecationWarning("'check_status' has been deprecated. Use 'checks' instead.")

    @property
    def client(self):
        """Return the client used to issue requests.

        Uses ``owner.client`` when an owner was given, otherwise a new DRF
        ``APIClient`` on every access. When ``as_user`` was given, the client
        is force-authenticated as that user before being returned.
        """
        if self.owner:
            client = self.owner.client
        else:
            from rest_framework.test import APIClient
            client = APIClient()
        if self.user:
            client.force_authenticate(self.user)
        return client

    def get_response_filename(self, method, url, data):
        """Return the path of the recorded-response file for this request."""
        return get_filename(self.data_dir,
                            clean_url(method, url, data) + '.response.json')

    def _get_custom_asserter(self, path, field_name):
        """Return a user-defined asserter for ``field_name``, or ``None``.

        Names are tried in order: ``assert_<path>_<field_name>`` (only when
        ``path`` is non-empty), then ``assert_<field_name>``; each name is
        looked up on the recorder first and then on ``owner``.
        """
        names = []
        if path:
            # Join the components; formatting the list itself would yield
            # names such as ``assert_['a']_b`` that can never match.
            prefix = '_'.join(map(str, path))
            names.append(f'assert_{prefix}_{field_name}')
        names.append(f'assert_{field_name}')
        for attr in names:
            for target in (self, self.owner):
                if target is not None and hasattr(target, attr):
                    return getattr(target, attr)
        return None

    @staticmethod
    def _dotted(path):
        """Render a list of path components as ``a.b[0].c``."""
        return ".".join(map(str, path))

    def _compare_dict(self, response, stored, path=None, view='unknown', filename='unknown'):
        """Recursively compare one response dict with its stored counterpart.

        :param response: dict taken from the live response
        :param stored: dict taken from the recorded response
        :param path: keys leading to ``response``; used in error messages and
                     for custom-asserter lookup. Never mutated.
        :param view: view reported in raised errors
        :param filename: file name reported in ``FieldAddedError``
        :raises FieldMissedError: ``stored`` has keys that ``response`` lacks
        :raises FieldAddedError: ``response`` has keys that ``stored`` lacks
        :raises FieldValueError: a value, a list length or a value type differs
        """
        try:
            self.check_dict_keys(response, stored)
        except DictKeyMissed as e:
            raise FieldMissedError(view, e.keys)
        except DictKeyAdded as e:
            raise FieldAddedError(view, e.keys, filename)
        # Private copy: appending to a shared list would leak keys into the
        # paths reported for sibling fields.
        path = list(path or [])
        for field_name, field_value in response.items():
            # Safe: check_dict_keys guarantees both dicts have the same keys.
            stored_value = stored[field_name]
            field_path = path + [field_name]
            if isinstance(field_value, dict):
                if not isinstance(stored_value, dict):
                    raise FieldValueError(view=view,
                                          expected=stored_value,
                                          received=field_value,
                                          field_name=self._dotted(field_path),
                                          filename=self.fixture_file)
                self._compare_dict(field_value, stored_value, field_path,
                                   view=view, filename=filename)
                continue
            asserter = self._get_custom_asserter(path, field_name)
            if asserter:
                asserter(response, stored, list(path))
            elif isinstance(field_value, (set, list, tuple)):
                self._compare_sequence(field_value, stored_value, path, field_name, view)
            elif field_value != stored_value:
                raise FieldValueError(view=view,
                                      expected=stored_value,
                                      received=field_value,
                                      field_name=self._dotted(field_path),
                                      filename=self.fixture_file)

    def _compare_sequence(self, items, stored_items, path, field_name, view):
        """Compare a list-like field: lengths must match, dict entries recurse.

        Non-dict entries are not compared individually; only the length of
        the sequence is checked for them.
        """
        received = list(items)
        full_name = self._dotted(path + [field_name])
        if not isinstance(stored_items, (set, list, tuple)):
            raise FieldValueError(view=view,
                                  expected=stored_items,
                                  received=received,
                                  field_name=full_name,
                                  filename=self.fixture_file)
        stored_list = list(stored_items)
        if len(received) != len(stored_list):
            raise FieldValueError(view=view,
                                  message="Field len `{0.field_name}` does not match.",
                                  expected=stored_items,
                                  received=received,
                                  field_name=full_name,
                                  filename=self.fixture_file)
        for index, entry in enumerate(received):
            if not isinstance(entry, dict):
                continue
            entry_path = path + [f'{field_name}[{index}]']
            stored_entry = stored_list[index]
            if not isinstance(stored_entry, dict):
                raise FieldValueError(view=view,
                                      expected=stored_entry,
                                      received=entry,
                                      field_name=self._dotted(entry_path),
                                      filename=self.fixture_file)
            self._compare_dict(dict(entry), stored_entry, entry_path,
                               view=view, filename=self.fixture_file)

    def get_single_record(self, response, expected):
        """Return the first record of both payloads when they are lists/tuples.

        Non-sequence payloads are returned unchanged.
        """
        if isinstance(response, (list, tuple)):
            response = response[0]
            expected = expected[0]
        return response, expected

    def check_dict_keys(self, response, expected):
        """Check that two dicts have the same keys.

        Key names in the raised exceptions are sorted so that messages are
        stable across runs.

        :raises DictKeyMissed: keys present only in ``expected``
        :raises DictKeyAdded: keys present only in ``response``
        """
        received_keys = set(response.keys())
        expected_keys = set(expected.keys())
        missed = expected_keys - received_keys
        added = received_keys - expected_keys
        if missed:
            raise DictKeyMissed(", ".join(sorted(map(str, missed))))
        if added:
            raise DictKeyAdded(", ".join(sorted(map(str, added))))

    def compare(self, response, expected, filename='unknown', ignore_fields=None, view='unknown'):
        """Compare a live payload with the stored one.

        For list/tuple payloads only the first record is compared. Empty
        payloads must be equal to the stored payload.

        :param ignore_fields: accepted for backward compatibility; unused
        :returns: ``True`` when the payloads match
        :raises AssertionError: an empty payload differs from the stored one
        :raises FieldMissedError, FieldAddedError, FieldValueError: see
                ``_compare_dict``
        """
        if not response:
            # Explicit raise so the check is not stripped under ``python -O``.
            if response != expected:
                raise AssertionError(f"{response!r} != {expected!r}")
            return True
        response, expected = self.get_single_record(response, expected)
        # _compare_dict performs the key check itself.
        self._compare_dict(response, expected, view=view, filename=filename)
        return True

    @staticmethod
    def _reject_unknown_kwargs(kwargs):
        """Reject removed or unknown keyword arguments of the assert* helpers.

        :raises DeprecationWarning: ``check_headers`` or ``check_status`` was passed
        :raises AttributeError: any other unexpected keyword was passed
        """
        for legacy in ('check_headers', 'check_status'):
            if legacy in kwargs:
                raise DeprecationWarning(f"'{legacy}' has been deprecated. Use 'checks' instead.")
        if kwargs:
            raise AttributeError("Unknown arguments %s" % kwargs.keys())

    @staticmethod
    def _select_checks(checks, targets):
        """Return ``checks``, falling back to its legacy alias ``targets``."""
        return targets if checks is None else checks

    def assertGET(self, url, *, allow_empty=None, targets=None,
                  expect_errors=None, name=None, data=None, checks=None, **kwargs):
        """GET ``url`` and compare it with the recorded response.

        ``targets`` is accepted as an alias of ``checks``. See ``assertCALL``
        for the other parameters and the return value.
        """
        self._reject_unknown_kwargs(kwargs)
        return self.assertCALL(url, allow_empty=allow_empty,
                               checks=self._select_checks(checks, targets),
                               expect_errors=expect_errors, name=name, data=data)

    def assertPUT(self, url, data, *, allow_empty=None,
                  expect_errors=None, name=None, targets=None, checks=None, **kwargs):
        """PUT ``data`` to ``url`` and compare it with the recorded response.

        ``targets`` is accepted as an alias of ``checks``. See ``assertCALL``
        for the other parameters and the return value.
        """
        self._reject_unknown_kwargs(kwargs)
        return self.assertCALL(url, data=data, method='put', allow_empty=allow_empty,
                               checks=self._select_checks(checks, targets),
                               expect_errors=expect_errors, name=name)

    def assertPOST(self, url, data, *, allow_empty=None, check_headers=None, check_status=None,
                   expect_errors=None, name=None, checks=None, targets=None, **kwargs):
        """POST ``data`` to ``url`` and compare it with the recorded response.

        ``check_headers`` / ``check_status`` remain in the signature only to
        report their removal: passing a non-None value raises
        ``DeprecationWarning``. ``targets`` is accepted as an alias of
        ``checks``. See ``assertCALL`` for the other parameters.
        """
        if check_headers is not None:
            kwargs['check_headers'] = check_headers
        if check_status is not None:
            kwargs['check_status'] = check_status
        self._reject_unknown_kwargs(kwargs)
        return self.assertCALL(url, data=data, method='post', allow_empty=allow_empty,
                               checks=self._select_checks(checks, targets),
                               expect_errors=expect_errors, name=name)

    def assertDELETE(self, url, *, allow_empty=None,
                     expect_errors=None, name=None, data=None, checks=None, targets=None, **kwargs):
        """DELETE ``url`` and compare it with the recorded response.

        ``targets`` is accepted as an alias of ``checks``. See ``assertCALL``
        for the other parameters and the return value.
        """
        self._reject_unknown_kwargs(kwargs)
        return self.assertCALL(url, method='delete', allow_empty=allow_empty,
                               checks=self._select_checks(checks, targets),
                               expect_errors=expect_errors, name=name, data=data)

    def assertCALL(self, url, *, allow_empty=None,
                   expect_errors=None, name=None, method='get', data=None,
                   checks=None, **kwargs):
        """
        Call ``url`` and check the response against the recorded one.

        :param url: url to check
        :param allow_empty: if True, ignore empty payloads and 404 responses;
                            defaults to the class attribute ``allow_empty``
        :param expect_errors: if True, accept status codes above 299;
                              defaults to the class attribute ``expect_errors``
        :param name: used instead of ``url`` to build the recording file name
        :param method: client method to call, e.g. ``'get'`` or ``'post'``
        :param data: request payload; also part of the recording file name
        :param checks: checks to run, in order, e.g.
                       ``checks=[STATUS_CODE, FIELDS, HEADERS]``;
                       defaults to the class attribute ``checks``
        :returns: tuple ``(response, stored)``
        :raises DeprecationWarning: ``check_headers`` / ``check_status`` passed
        :raises AttributeError: unknown keyword arguments passed
        :raises ValueError: empty payload, unexpected error status or 404
        :raises: the ``drf_api_checker.exceptions`` errors raised by the checks
        """
        self._reject_unknown_kwargs(kwargs)
        expect_errors = self.expect_errors if expect_errors is None else expect_errors
        allow_empty = self.allow_empty if allow_empty is None else allow_empty
        # view and filename are kept on the instance for the _assert_* checks.
        self.view = resolve(url).func.cls
        m = getattr(self.client, method.lower())
        self.filename = self.get_response_filename(method, name or url, data)
        response = m(url, data=data)
        # Sanity check: the response must carry DRF's ``accepted_renderer``.
        assert response.accepted_renderer
        payload = response.data
        if not allow_empty and not payload:
            raise ValueError(f"View {self.view} returned an empty json. Check your test")
        if response.status_code > 299 and not expect_errors:
            raise ValueError(f"View {self.view} unexpected response. {response.status_code} - {response.content}")
        if not allow_empty and response.status_code == 404:
            raise ValueError(f"View {self.view} returned 404 status code. Check your test")
        # Record on first run, or re-record when API_CHECKER_RESET is non-empty.
        if not os.path.exists(self.filename) or os.environ.get('API_CHECKER_RESET', False):
            _write(self.filename, serialize_response(response))
        stored = load_response(self.filename)
        if checks is None:
            checks = self.checks
        for check_id in checks:
            self.check_map[check_id](response, stored)
        return response, stored

    def _assert_fields(self, response, stored):
        """FIELDS check: compare the payload of the live and stored responses."""
        self.compare(response.data, stored.data, self.filename, view=self.view)

    def _assert_status(self, response, stored):
        """STATUS_CODE check: the status codes must be equal."""
        if response.status_code != stored.status_code:
            raise StatusCodeError(self.view, response.status_code, stored.status_code)

    def _assert_headers(self, response, stored):
        """HEADERS check: every header in ``headers_to_check`` must be equal."""
        for header in self.headers_to_check:
            expected_value = stored.get(header)
            received_value = response.get(header)
            if expected_value != received_value:
                raise HeaderError(self.view, header, expected_value,
                                  received_value,
                                  self.filename,
                                  f"{stored.content}/{response.content}")