Call details
Sometimes a test needs to check which arguments the code under test passed to a function it calls. This page shows several ways to inspect such calls with `unittest.mock`.
```python
import torch
import unittest
from unittest.mock import patch
```
assert_called_with
To check which arguments were passed to a mocked function, use the `assert_called_with(<expected arguments>)` method of the mock object created by `patch`. It compares only the most recent call and raises `AssertionError` if the arguments differ.
As an example, consider a simple function that just wraps the built-in `sum` function.
```python
def sum_wrapper(numbers):
    return sum(numbers)
```
The tests below mock the `sum` function. In both tests `sum_wrapper` is called with the list `[1,2,3]`, but the second test passes `[1,2,5]` to `assert_called_with`, so it is expected to fail.
```python
class TestCalledWith(unittest.TestCase):
    def test_ok(self):
        with patch("__main__.sum") as mocked_sum:
            sum_wrapper([1,2,3])
            mocked_sum.assert_called_with([1,2,3])

    def test_fail(self):
        with patch("__main__.sum") as mocked_sum:
            sum_wrapper([1,2,3])
            mocked_sum.assert_called_with([1,2,5])

ans = unittest.main(argv=[''], verbosity=2, exit=False)
del TestCalledWith
```
```
test_fail (__main__.TestCalledWith) ... FAIL
test_ok (__main__.TestCalledWith) ... ok

======================================================================
FAIL: test_fail (__main__.TestCalledWith)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_196007/486826917.py", line 10, in test_fail
    mocked_sum.assert_called_with([1,2,5])
  File "/usr/lib/python3.10/unittest/mock.py", line 929, in assert_called_with
    raise AssertionError(_error_message()) from cause
AssertionError: expected call not found.
Expected: sum([1, 2, 5])
Actual: sum([1, 2, 3])

----------------------------------------------------------------------
Ran 2 tests in 0.011s

FAILED (failures=1)
```
So `test_ok` passes because the expected argument given to `assert_called_with` matches the argument passed to `sum_wrapper`, while `test_fail` fails because it does not. The failure message shows both the expected and the actual call.
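Because `assert_called_with` looks only at the most recent call, it can miss earlier calls. A minimal sketch with a plain `Mock` (the `greet` mock is a hypothetical stand-in for any patched function) contrasts it with `assert_any_call`, which accepts any earlier call, and `assert_called_once_with`, which additionally requires exactly one call:

```python
from unittest.mock import Mock

# Hypothetical mock standing in for any patched function
greet = Mock()
greet("Alice")
greet("Bob")

# Passes: "Bob" is the argument of the most recent call
greet.assert_called_with("Bob")

# Passes: "Alice" matches an earlier call
greet.assert_any_call("Alice")

failures = []

# Fails: assert_called_with ignores every call except the last one
try:
    greet.assert_called_with("Alice")
except AssertionError:
    failures.append("assert_called_with")

# Fails: assert_called_once_with also requires exactly one call
try:
    greet.assert_called_once_with("Bob")
except AssertionError:
    failures.append("assert_called_once_with")

print(failures)  # ['assert_called_with', 'assert_called_once_with']
```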
Get arguments
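Sometimes you need to extract the arguments a mock was called with instead of comparing them with `assert_called_with`. For example, you may want to check them with a testing helper from a specific package, such as `torch.testing.assert_close`, when plain `==` comparison is not suitable (comparing tensors with `==` returns a tensor, not a single boolean).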
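You can do this through the `call_args` attribute of the `Mock`. It stores the arguments of the most recent call as an `(args, kwargs)` pair: `call_args[0]` is the tuple of positional arguments and `call_args[1]` is the dict of keyword arguments (also available as `call_args.args` and `call_args.kwargs`).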
The following cell shows how to get the `torch.tensor` that was passed to the patched function.
```python
def sum_wrapper(numbers):
    return sum(numbers)

with patch("__main__.sum_wrapper") as sw:
    sum_wrapper(torch.tensor([1,2,3,4]))

call_args = sw.call_args
call_args[0]
```
```
(tensor([1, 2, 3, 4]),)
```
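Since `call_args` behaves like an `(args, kwargs)` tuple, it can be unpacked and each value checked with whatever comparison fits. The sketch below uses `math.isclose` as a stand-in for a package-specific helper such as `torch.testing.assert_close`, so it runs without torch; the `scale` mock and its `factor` keyword are hypothetical.

```python
import math
from unittest.mock import Mock

# Hypothetical mocked function, called with one positional and one keyword argument
scale = Mock()
scale(0.1 + 0.2, factor=2)

# call_args of the most recent call unpacks into positional and keyword arguments
args, kwargs = scale.call_args
print(args, kwargs)

# assert_called_with(0.3, factor=2) would fail because 0.1 + 0.2 != 0.3 in
# floating point, so the extracted value is compared with a tolerance instead
assert math.isclose(args[0], 0.3)
assert kwargs == {"factor": 2}

# The same values are available through the named attributes (Python 3.8+)
assert scale.call_args.args == args
assert scale.call_args.kwargs == kwargs
```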
Several calls
Consider the case where the unit under test calls the mocked function several times. `call_args` holds only the most recent call, so use the `mock_calls` attribute of the `unittest.mock.MagicMock` object instead: it records every call in order.
The following cell prints the contents of `mock_calls` when `sum` is called twice inside `sum_wrapper`.
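```python
def sum_wrapper(numbers):
    sum(numbers + [3,3])
    # The real built-in sum takes its iterable as a positional-only argument;
    # this keyword call only works here because sum is mocked
    sum(iterable=(numbers + [2,2]))

with patch("__main__.sum") as mocked_sum:
    sum_wrapper([1,2,3])

print(mocked_sum.mock_calls)
print(mocked_sum.mock_calls[0].args)
print(mocked_sum.mock_calls[1].kwargs)
```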
```
[call([1, 2, 3, 3, 3]), call(iterable=[1, 2, 3, 2, 2])]
([1, 2, 3, 3, 3],)
{'iterable': [1, 2, 3, 2, 2]}
```
So `mock_calls` holds one `call` object per call. For each of them, the `args` attribute contains the positional arguments and `kwargs` contains the keyword (named) arguments.
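Note that `mock_calls` also records calls made to the mock's methods and attributes, not only calls to the mock itself. When only direct calls matter, `call_args_list` may be the simpler choice. A minimal sketch with a hypothetical `client` mock:

```python
from unittest.mock import MagicMock, call

# Hypothetical mock: called directly once, and one of its methods called once
client = MagicMock()
client("connect")
client.send("hello")

# mock_calls records both the direct call and the method call
print(client.mock_calls)
assert client.mock_calls == [call("connect"), call.send("hello")]

# call_args_list records only calls made to the mock itself
print(client.call_args_list)
assert client.call_args_list == [call("connect")]
```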
A complete test built on this could look like the following cell:
```python
class TestCalledWith(unittest.TestCase):
    def test_example(self):
        with patch("__main__.sum") as mocked_sum:
            sum_wrapper([1,2,3])
            self.assertEqual(
                mocked_sum.mock_calls[0].args[0],
                [1,2,3,3,3]
            )
            self.assertEqual(
                mocked_sum.mock_calls[1].kwargs,
                {'iterable': [1, 2, 3, 2, 2]}
            )

ans = unittest.main(argv=[''], verbosity=2, exit=False)
del TestCalledWith
```
```
test_example (__main__.TestCalledWith) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
```
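Indexing `mock_calls` works, but `unittest.mock` also provides `assert_has_calls`, which checks that a list of `call` objects appears in `mock_calls` consecutively and in the given order (or in any order with `any_order=True`). A sketch with a hypothetical `log` mock, kept independent of `sum_wrapper` so it runs on its own:

```python
from unittest.mock import Mock, call

# Hypothetical mock recording three calls
log = Mock()
log("start")
log("step", n=1)
log("stop")

# Passes: these two calls appear consecutively and in this order in mock_calls
log.assert_has_calls([call("step", n=1), call("stop")])

# Passes: with any_order=True only the presence of each call is checked
log.assert_has_calls([call("stop"), call("start")], any_order=True)

# Fails: both calls exist, but not in this order
try:
    log.assert_has_calls([call("stop"), call("start")])
    wrong_order_failed = False
except AssertionError:
    wrong_order_failed = True

print(wrong_order_failed)  # True
```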
Calls count
The number of times a mock has been called is stored in its `call_count` attribute.
The following cell creates a mock and calls it 10 times in a loop, so `call_count` ends up equal to 10.
```python
mock = unittest.mock.Mock()
for i in range(10):
    mock()
mock.call_count
```
```
10
```
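`call_count` has a few related helpers: `assert_not_called` and `assert_called_once` check the count directly, and `reset_mock` clears the recorded calls so one mock can be reused across checks. A minimal sketch with a hypothetical `counter` mock:

```python
from unittest.mock import Mock

counter = Mock()

# Passes: the mock has not been called yet
counter.assert_not_called()

for i in range(3):
    counter(i)
print(counter.call_count)  # 3

# reset_mock clears call_count, call_args, call_args_list and mock_calls
counter.reset_mock()
print(counter.call_count)  # 0
print(counter.call_args)   # None

counter("once")

# Passes: exactly one call since the reset
counter.assert_called_once()
```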