Source code for drf_api_checker.collector

# -*- coding: utf-8 -*-
from itertools import chain

from django.db.models import ForeignKey, ManyToManyField


class ForeignKeysCollector(object):
    """Collect model instances reachable through relation fields.

    Starting from one or more instances, every ``ForeignKey`` and
    ``ManyToManyField`` found in ``_meta.fields`` and
    ``_meta.local_many_to_many`` is followed recursively, and each
    instance is recorded only once.

    Attributes set by :meth:`collect`:

    * ``data`` -- list of the collected instances; an instance always
      appears before the instances reached through its own fields.
    * ``models`` -- set of the classes of the instances in ``data``.

    Both attributes are empty until :meth:`collect` has been called.
    """

    def __init__(self, using):
        """Create an empty collector.

        :param using: accepted for interface compatibility; nothing in this
            class reads it.
        """
        self._visited = []
        # Initialise the results so they can be read before the first
        # collect() call without raising AttributeError.
        self.data = []
        self.models = set()
        super(ForeignKeysCollector, self).__init__()

    def _collect(self, objs):
        """Return ``objs`` plus everything reachable from them.

        :param objs: iterable of instances; falsy entries (for example an
            unset foreign key yielding ``None``) are skipped.
        :returns: list of instances not seen before during the current
            traversal, each followed by the instances reached through it.
        """
        collected = []
        for obj in objs:
            # Skip empty relations and anything already handled; the latter
            # is also what stops circular references from recursing forever.
            if not obj or obj in self._visited:
                continue

            # Reload the instance through its concrete model's manager.
            concrete_model = obj._meta.concrete_model
            obj = concrete_model.objects.get(pk=obj.pk)
            opts = obj._meta

            # Mark as visited before descending into its relations.
            self._visited.append(obj)
            collected.append(obj)

            for field in chain(opts.fields, opts.local_many_to_many):
                if isinstance(field, ManyToManyField):
                    related = getattr(obj, field.name).all()
                elif isinstance(field, ForeignKey):
                    related = [getattr(obj, field.name)]
                else:
                    # Not a relation this collector follows.
                    continue
                collected.extend(self._collect(related))
        return collected

    def collect(self, obj):
        """Traverse from ``obj`` and store the result on the collector.

        :param obj: a single instance, or any object exposing ``__iter__``
            that yields instances.

        Resets the visited state, then populates ``data`` and ``models``.
        Returns ``None``.
        """
        if not hasattr(obj, '__iter__'):
            obj = [obj]
        self._visited = []
        self.data = self._collect(obj)
        self.models = {o.__class__ for o in self.data}
# # def __str__(self): # return mark_safe(self.data)