abelcastro.dev

Test inheritance with python (and Django)

2021-03-26

Tags: Testing · Pytest · Django

Let's imagine that we have these models:

models.py

class Athlete(models.Model):
    """Fields shared by every kind of athlete.

    Declared abstract (see ``Meta``), so it serves only as a base class for
    the concrete player models; each subclass inherits these fields.
    """

    name = models.CharField(max_length=255)
    # Unique slug; the detail endpoints look players up by it.
    slug = models.SlugField(max_length=300, unique=True)
    age = models.PositiveIntegerField()

    class Meta:
        # Abstract: Django creates no table for Athlete itself.
        abstract = True


class BasketballPlayer(Athlete):
    """Concrete athlete model holding basketball statistics."""

    points_scored = models.PositiveIntegerField()
    assists = models.PositiveIntegerField()
    rebounds = models.PositiveIntegerField()

    def __str__(self):
        """Return the player's name as the display string."""
        return self.name


class SoccerPlayer(Athlete):
    """Concrete athlete model holding soccer statistics."""

    goals_scored = models.PositiveIntegerField()
    assists = models.PositiveIntegerField()
    yellow_cards = models.PositiveIntegerField()

    def __str__(self):
        """Return the player's name as the display string."""
        return self.name

We also provide these 2 GET endpoints:

  • api/athletes/basketball-player/[slug]
  • api/athletes/soccer-player/[slug]

We want to test that each endpoint returns a 200 status code, and also make sure that the serialized data looks as we expect.

tests.py

class TestBasketballPlayerAPI(TestCase):
    """Detail-endpoint checks for a single basketball player."""

    def setUp(self) -> None:
        # Runs before every test method, so each test gets its own player.
        self.basketball_player = create_test_basketball_player()

    def test__response_ok(self):
        """GET by slug answers 200 and returns the serialized player."""
        player = self.basketball_player
        detail_url = reverse_lazy("basketball_player", kwargs={"slug": player.slug})

        response = self.client.get(detail_url)

        assert response.status_code == 200
        assert response.data == BasketballPlayerSerializer(player).data


class TestSoccerPlayerAPI(TestCase):
    """Detail-endpoint checks for a single soccer player."""

    def setUp(self) -> None:
        # Runs before every test method, so each test gets its own player.
        self.soccer_player = create_test_soccer_player()

    def test__response_ok(self):
        """GET by slug answers 200 and returns the serialized player."""
        player = self.soccer_player
        detail_url = reverse_lazy("soccer_player", kwargs={"slug": player.slug})

        response = self.client.get(detail_url)

        assert response.status_code == 200
        assert response.data == SoccerPlayerSerializer(player).data

The test classes TestBasketballPlayerAPI and TestSoccerPlayerAPI are nearly identical: both check that the endpoint returns a 200 and that the response contains the expected serialized data.

Let's refactor this. We can create a base class, AthleteTest, that holds the shared test, and let both test classes inherit from it.

class AthleteTest(TestCase):
    """Shared API test for the athlete detail endpoints.

    Subclasses configure:
      * ``view_name`` -- the URL name passed to ``reverse_lazy``;
      * ``serializer_class`` -- used to build the expected response payload;
      * ``self.test_athlete`` (assigned in ``setUp``) -- the object fetched
        by its slug.

    No pytest ``django_db`` mark is needed. ``TestCase`` here is assumed to be
    ``django.test.TestCase`` (it provides ``self.client``), which already
    gives every test method database access.

    NOTE: because this class subclasses ``TestCase``, pytest also collects
    and runs it directly. There ``test_athlete`` is None, so the test errors.
    """

    view_name = ""
    serializer_class = None

    def setUp(self) -> None:
        # Placeholder only; concrete subclasses replace it with a real athlete.
        self.test_athlete = None

    def test__response_ok(self):
        """GET the athlete by slug; expect 200 and the serialized athlete."""
        url = reverse_lazy(
            self.view_name, kwargs={"slug": self.test_athlete.slug}
        )
        response = self.client.get(url)

        assert response.status_code == 200
        assert response.data == self.serializer_class(self.test_athlete).data


class TestBasketballPlayerAPI(AthleteTest):
    """Run the shared AthleteTest check against the basketball endpoint."""

    serializer_class = BasketballPlayerSerializer
    view_name = "basketball_player"

    def setUp(self) -> None:
        # Swap the base class's None placeholder for a real player.
        player = create_test_basketball_player()
        self.test_athlete = player


class TestSoccerPlayerAPI(AthleteTest):
    """Run the shared AthleteTest check against the soccer endpoint."""

    serializer_class = SoccerPlayerSerializer
    view_name = "soccer_player"

    def setUp(self) -> None:
        # Swap the base class's None placeholder for a real player.
        player = create_test_soccer_player()
        self.test_athlete = player

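The code is now much shorter and easier to extend with new test cases. But if we run pytest, it fails. pytest collects every unittest.TestCase subclass regardless of its name, so it finds 3 test cases instead of 2. The extra one is AthleteTest itself, which is not a real test case and is only intended to be inherited from. In it, test_athlete is None, so its test errors out.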

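To avoid this problem we can use the __test__ attribute, which pytest checks when collecting tests. Setting __test__ = False on the base class excludes it from collection. Because subclasses inherit that value, each concrete test class must set __test__ = True again. After adding it, the tests look like this: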

class AthleteTest(TestCase):
    """Shared API test for the athlete detail endpoints.

    Subclasses configure:
      * ``view_name`` -- the URL name passed to ``reverse_lazy``;
      * ``serializer_class`` -- used to build the expected response payload;
      * ``self.test_athlete`` (assigned in ``setUp``) -- the object fetched
        by its slug;
    and set ``__test__ = True`` so pytest collects them.

    No pytest ``django_db`` mark is needed. ``TestCase`` here is assumed to be
    ``django.test.TestCase`` (it provides ``self.client``), which already
    gives every test method database access.
    """

    # pytest skips classes whose __test__ is falsy, so the unconfigured base
    # is not run on its own. unittest's loader (used by the Django test
    # runner) ignores this attribute.
    __test__ = False
    view_name = ""
    serializer_class = None

    def setUp(self) -> None:
        # Placeholder only; concrete subclasses replace it with a real athlete.
        self.test_athlete = None

    def test__response_ok(self):
        """GET the athlete by slug; expect 200 and the serialized athlete."""
        url = reverse_lazy(
            self.view_name, kwargs={"slug": self.test_athlete.slug}
        )
        response = self.client.get(url)
        assert response.status_code == 200
        assert response.data == self.serializer_class(self.test_athlete).data


class TestBasketballPlayerAPI(AthleteTest):
    """Run the shared AthleteTest check against the basketball endpoint."""

    # Re-enable pytest collection; AthleteTest sets __test__ to False.
    __test__ = True
    serializer_class = BasketballPlayerSerializer
    view_name = "basketball_player"

    def setUp(self) -> None:
        # Swap the base class's None placeholder for a real player.
        player = create_test_basketball_player()
        self.test_athlete = player


class TestSoccerPlayerAPI(AthleteTest):
    """Run the shared AthleteTest check against the soccer endpoint."""

    # Re-enable pytest collection; AthleteTest sets __test__ to False.
    __test__ = True
    serializer_class = SoccerPlayerSerializer
    view_name = "soccer_player"

    def setUp(self) -> None:
        # Swap the base class's None placeholder for a real player.
        player = create_test_soccer_player()
        self.test_athlete = player

Now pytest collects only 2 tests, and both pass as expected.

Check out the complete source code here.

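Note: this method only works with pytest and NOT with the Django test runner. The Django runner uses unittest's test loader, which ignores __test__, so it would still run AthleteTest and fail.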