1 package org.openkilda.test;
3 import java.lang.reflect.Parameter;
4 import static org.mockito.Mockito.mock;
5 import org.junit.jupiter.api.extension.ExtensionContext;
6 import org.junit.jupiter.api.extension.ParameterContext;
7 import org.junit.jupiter.api.extension.ParameterResolutionException;
8 import org.junit.jupiter.api.extension.ParameterResolver;
9 import org.junit.jupiter.api.extension.TestInstancePostProcessor;
10 import org.junit.jupiter.api.extension.ExtensionContext.Namespace;
11 import org.junit.jupiter.api.extension.ExtensionContext.Store;
12 import org.mockito.Mock;
13 import org.mockito.MockitoAnnotations;
15 public class MockitoExtension implements TestInstancePostProcessor, ParameterResolver {
19 MockitoAnnotations.initMocks(testInstance);
24 public Object
resolve(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {
25 return getMock(parameterContext.getParameter(), extensionContext);
30 public boolean supports(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {
31 return parameterContext.getParameter().isAnnotationPresent(Mock.class);
34 private Object getMock(Parameter parameter, ExtensionContext extensionContext) {
35 Class<?> mockType = parameter.getType();
36 Store mocks = extensionContext.getStore(Namespace.create(
MockitoExtension.class, mockType));
37 String mockName = getMockName(parameter);
39 if (mockName != null) {
40 return mocks.getOrComputeIfAbsent(mockName, key -> mock(mockType, mockName));
42 return mocks.getOrComputeIfAbsent(mockType.getCanonicalName(), key -> mock(mockType));
45 private String getMockName(Parameter parameter) {
46 String explicitMockName = parameter.getAnnotation(Mock.class).name().trim();
47 if (!explicitMockName.isEmpty()) {
48 return explicitMockName;
49 }
else if (parameter.isNamePresent()) {
50 return parameter.getName();
void postProcessTestInstance(Object testInstance, ExtensionContext context)
boolean supports(ParameterContext parameterContext, ExtensionContext extensionContext)
Object resolve(ParameterContext parameterContext, ExtensionContext extensionContext)