Call details

Sometimes you need to check which arguments the function under test passed to another function it calls. This page shows how to do that with unittest.mock.

import torch

import unittest
from unittest.mock import patch

assert_called_with

If you need to check which arguments were passed to the mocked function, call the assert_called_with(<expected arguments>) method of the mock object that patch creates. It raises an AssertionError if the most recent call was made with different arguments.


As an example, consider a simple function that just wraps the built-in sum function.

def sum_wrapper(numbers):
    return sum(numbers)

The tests below mock the sum function. Both tests pass the list [1,2,3] to sum_wrapper, but the second one expects [1,2,5] in assert_called_with.

class TestCalledWith(unittest.TestCase):
    def test_ok(self):
        with patch("__main__.sum") as mocked_sum:
            sum_wrapper([1,2,3])
            mocked_sum.assert_called_with([1,2,3])

    def test_fail(self):
        with patch("__main__.sum") as mocked_sum:
            sum_wrapper([1,2,3])
            mocked_sum.assert_called_with([1,2,5])      

ans = unittest.main(argv=[''], verbosity=2, exit=False)
del TestCalledWith
test_fail (__main__.TestCalledWith) ... FAIL
test_ok (__main__.TestCalledWith) ... ok

======================================================================
FAIL: test_fail (__main__.TestCalledWith)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_196007/486826917.py", line 10, in test_fail
    mocked_sum.assert_called_with([1,2,5])
  File "/usr/lib/python3.10/unittest/mock.py", line 929, in assert_called_with
    raise AssertionError(_error_message()) from cause
AssertionError: expected call not found.
Expected: sum([1, 2, 5])
Actual: sum([1, 2, 3])

----------------------------------------------------------------------
Ran 2 tests in 0.011s

FAILED (failures=1)

So test_ok passes because the argument expected in assert_called_with matches the one that sum_wrapper passed to sum, while test_fail fails because [1,2,5] does not match [1,2,3].
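Keep in mind that assert_called_with only looks at the most recent call. The sketch below uses a plain Mock instead of a patched sum (the mock name m and the argument lists are illustrative) to contrast it with two related Mock methods: assert_any_call, which accepts a match with any recorded call, and assert_called_once_with, which additionally requires exactly one call.

```python
from unittest.mock import Mock

m = Mock()
m([1, 2, 3])
m([4, 5, 6])

failed = []  # names of the checks that raised AssertionError

# assert_called_with compares against the most recent call only
m.assert_called_with([4, 5, 6])           # passes
try:
    m.assert_called_with([1, 2, 3])       # fails: this was the first call, not the last
except AssertionError:
    failed.append("assert_called_with")

# assert_any_call passes if any recorded call matches
m.assert_any_call([1, 2, 3])              # passes

# assert_called_once_with also requires that the mock was called exactly once
try:
    m.assert_called_once_with([4, 5, 6])  # fails: the mock was called twice
except AssertionError:
    failed.append("assert_called_once_with")

print(failed)  # ['assert_called_with', 'assert_called_once_with']
```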

Get arguments

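Sometimes you need to extract the arguments a mocked object was called with instead of only comparing them. For example, you may want to check them with a comparison tool provided by a specific package, such as torch.testing.assert_close. Plain assert_called_with relies on ==, which for multi-element tensors returns a tensor rather than a single True or False.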

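You can achieve this through the call_args attribute of the mock. It stores the arguments of the most recent call as a call object: call_args[0] is the tuple of positional arguments and call_args[1] is the dictionary of keyword arguments.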


The following cell shows how to get the torch.tensor that was passed to the patched function.

def sum_wrapper(numbers):
    return sum(numbers)

with patch("__main__.sum_wrapper") as sw:
    sum_wrapper(torch.tensor([1,2,3,4]))
    call_args = sw.call_args

call_args[0]
(tensor([1, 2, 3, 4]),)
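Since Python 3.8, the same parts of call_args are also available as .args and .kwargs. The stdlib-only sketch below applies the idea of this section with math.isclose standing in for torch.testing.assert_close, so it runs without torch installed. The function total_and_send, its send parameter and the unit keyword are hypothetical; the point is that an exact comparison of the call rejects a float that a tolerance-based check on the extracted argument accepts.

```python
import math
from unittest.mock import Mock, call

def total_and_send(values, send):
    # hypothetical unit under test: sums the values and passes the result on
    send(sum(values), unit="m")

send = Mock()
total_and_send([0.1, 0.2], send)

# call_args holds the most recent call; .args and .kwargs (Python 3.8+)
# are the same data as call_args[0] and call_args[1]
(sent_value,) = send.call_args.args
sent_kwargs = send.call_args.kwargs

# An exact comparison fails: 0.1 + 0.2 is 0.30000000000000004, not 0.3
exact_match = send.call_args == call(0.3, unit="m")

# A tolerance-based helper applied to the extracted argument succeeds
close_match = math.isclose(sent_value, 0.3)

print(exact_match, close_match, sent_kwargs)  # False True {'unit': 'm'}
```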

Several calls

Now consider the case where the unit under test calls the mocked function several times. For this case you can use the mock_calls attribute of the mock; by default, patch creates a unittest.mock.MagicMock object.


The following cell prints what mock_calls will contain if sum is called twice in sum_wrapper.

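def sum_wrapper(numbers):
    sum(numbers + [3,3])
    # The real built-in sum takes its iterable positionally only, so this
    # keyword call would raise TypeError; it works here because sum is mocked,
    # and it shows how keyword arguments are recorded.
    sum(iterable=(numbers + [2,2]))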

with patch("__main__.sum") as mocked_sum:
    sum_wrapper([1,2,3])
    print(mocked_sum.mock_calls)
    print(mocked_sum.mock_calls[0].args)
    print(mocked_sum.mock_calls[1].kwargs)
[call([1, 2, 3, 3, 3]), call(iterable=[1, 2, 3, 2, 2])]
([1, 2, 3, 3, 3],)
{'iterable': [1, 2, 3, 2, 2]}

So mock_calls is a list with one call object per call, in call order. You can index it to get a single call and read its args attribute for positional arguments and its kwargs attribute for keyword arguments.
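A related attribute is call_args_list. The difference matters once the mock's attributes are used too: mock_calls also records calls to child mocks (methods and attributes of the mock), while call_args_list and call_count cover only calls to the mock itself. The method name load in this sketch is made up for illustration.

```python
from unittest.mock import MagicMock, call

m = MagicMock()
m(1)
m.load("data.csv")  # a call on a child mock, created on first attribute access
m(2, flag=True)

# only direct calls to m
print(m.call_args_list)  # [call(1), call(2, flag=True)]

# direct calls plus calls to children, in the order they happened
print(m.mock_calls)      # [call(1), call.load('data.csv'), call(2, flag=True)]
```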

A complete test built on this could look like the following cell:

class TestCalledWith(unittest.TestCase):
    def test_example(self):
        with patch("__main__.sum") as mocked_sum:
            sum_wrapper([1,2,3])
            self.assertEqual(
                mocked_sum.mock_calls[0].args[0],
                [1,2,3,3,3]
            )
            self.assertEqual(
                mocked_sum.mock_calls[1].kwargs,
                {'iterable': [1, 2, 3, 2, 2]}
            )

ans = unittest.main(argv=[''], verbosity=2, exit=False)
del TestCalledWith
test_example (__main__.TestCalledWith) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
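Instead of indexing into mock_calls, you can compare the whole list against call objects in one assertion, or use assert_has_calls, which passes if the given calls appear consecutively and in order in mock_calls. The sketch below redefines the sum_wrapper from above so it is self-contained, and patches builtins.sum instead of __main__.sum so it does not depend on the module it runs in; the test class name is arbitrary.

```python
import unittest
from unittest.mock import patch, call

def sum_wrapper(numbers):
    sum(numbers + [3, 3])
    sum(iterable=(numbers + [2, 2]))  # only valid because sum is mocked

class TestWholeCallList(unittest.TestCase):
    def test_mock_calls_equal(self):
        with patch("builtins.sum") as mocked_sum:
            sum_wrapper([1, 2, 3])
        # compare the complete, ordered list of calls at once
        self.assertEqual(
            mocked_sum.mock_calls,
            [call([1, 2, 3, 3, 3]), call(iterable=[1, 2, 3, 2, 2])],
        )

    def test_has_calls(self):
        with patch("builtins.sum") as mocked_sum:
            sum_wrapper([1, 2, 3])
        # passes as long as these calls appear consecutively, in this order
        mocked_sum.assert_has_calls([call(iterable=[1, 2, 3, 2, 2])])

# run just this class, without relying on unittest.main's module discovery
suite = unittest.TestLoader().loadTestsFromTestCase(TestWholeCallList)
result = unittest.TextTestRunner(verbosity=2).run(suite)
```

Comparing the full list also catches unexpected extra calls, whereas assert_has_calls tolerates calls before or after the listed ones.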

Call count

You can get the number of times a mock was called from its call_count attribute.


The following cell creates a mock and calls it 10 times in a loop. As a result, mock.call_count is 10.

mock = unittest.mock.Mock()

for i in range(10):
    mock()

mock.call_count
10
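Besides reading call_count directly, Mock offers a few related helpers: the called flag, the assertions assert_not_called and assert_called_once, and reset_mock, which sets call_count back to 0 together with the other recorded call data. A short sketch:

```python
from unittest.mock import Mock

mock = Mock()
mock.assert_not_called()             # passes: no calls yet
print(mock.called, mock.call_count)  # False 0

for i in range(3):
    mock(i)

print(mock.called, mock.call_count)  # True 3

# reset_mock clears the recorded calls, so counting starts over
mock.reset_mock()
print(mock.call_count, mock.call_args_list)  # 0 []

mock("once")
mock.assert_called_once()            # passes: exactly one call since the reset
```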