135 lines
5.0 KiB
Python
135 lines
5.0 KiB
Python
|
import sys
|
||
|
import unittest
|
||
|
from contextlib import contextmanager
|
||
|
|
||
|
from django.test import LiveServerTestCase, tag
|
||
|
from django.utils.functional import classproperty
|
||
|
from django.utils.module_loading import import_string
|
||
|
from django.utils.text import capfirst
|
||
|
|
||
|
|
||
|
class SeleniumTestCaseBase(type(LiveServerTestCase)):
|
||
|
# List of browsers to dynamically create test classes for.
|
||
|
browsers = []
|
||
|
# A selenium hub URL to test against.
|
||
|
selenium_hub = None
|
||
|
# The external host Selenium Hub can reach.
|
||
|
external_host = None
|
||
|
# Sentinel value to differentiate browser-specific instances.
|
||
|
browser = None
|
||
|
# Run browsers in headless mode.
|
||
|
headless = False
|
||
|
|
||
|
def __new__(cls, name, bases, attrs):
|
||
|
"""
|
||
|
Dynamically create new classes and add them to the test module when
|
||
|
multiple browsers specs are provided (e.g. --selenium=firefox,chrome).
|
||
|
"""
|
||
|
test_class = super().__new__(cls, name, bases, attrs)
|
||
|
# If the test class is either browser-specific or a test base, return it.
|
||
|
if test_class.browser or not any(
|
||
|
name.startswith("test") and callable(value) for name, value in attrs.items()
|
||
|
):
|
||
|
return test_class
|
||
|
elif test_class.browsers:
|
||
|
# Reuse the created test class to make it browser-specific.
|
||
|
# We can't rename it to include the browser name or create a
|
||
|
# subclass like we do with the remaining browsers as it would
|
||
|
# either duplicate tests or prevent pickling of its instances.
|
||
|
first_browser = test_class.browsers[0]
|
||
|
test_class.browser = first_browser
|
||
|
# Listen on an external interface if using a selenium hub.
|
||
|
host = test_class.host if not test_class.selenium_hub else "0.0.0.0"
|
||
|
test_class.host = host
|
||
|
test_class.external_host = cls.external_host
|
||
|
# Create subclasses for each of the remaining browsers and expose
|
||
|
# them through the test's module namespace.
|
||
|
module = sys.modules[test_class.__module__]
|
||
|
for browser in test_class.browsers[1:]:
|
||
|
browser_test_class = cls.__new__(
|
||
|
cls,
|
||
|
"%s%s" % (capfirst(browser), name),
|
||
|
(test_class,),
|
||
|
{
|
||
|
"browser": browser,
|
||
|
"host": host,
|
||
|
"external_host": cls.external_host,
|
||
|
"__module__": test_class.__module__,
|
||
|
},
|
||
|
)
|
||
|
setattr(module, browser_test_class.__name__, browser_test_class)
|
||
|
return test_class
|
||
|
# If no browsers were specified, skip this class (it'll still be discovered).
|
||
|
return unittest.skip("No browsers specified.")(test_class)
|
||
|
|
||
|
@classmethod
|
||
|
def import_webdriver(cls, browser):
|
||
|
return import_string("selenium.webdriver.%s.webdriver.WebDriver" % browser)
|
||
|
|
||
|
@classmethod
|
||
|
def import_options(cls, browser):
|
||
|
return import_string("selenium.webdriver.%s.options.Options" % browser)
|
||
|
|
||
|
@classmethod
|
||
|
def get_capability(cls, browser):
|
||
|
from selenium.webdriver.common.desired_capabilities import DesiredCapabilities
|
||
|
|
||
|
return getattr(DesiredCapabilities, browser.upper())
|
||
|
|
||
|
def create_options(self):
|
||
|
options = self.import_options(self.browser)()
|
||
|
if self.headless:
|
||
|
try:
|
||
|
options.headless = True
|
||
|
except AttributeError:
|
||
|
pass # Only Chrome and Firefox support the headless mode.
|
||
|
return options
|
||
|
|
||
|
def create_webdriver(self):
|
||
|
if self.selenium_hub:
|
||
|
from selenium import webdriver
|
||
|
|
||
|
return webdriver.Remote(
|
||
|
command_executor=self.selenium_hub,
|
||
|
desired_capabilities=self.get_capability(self.browser),
|
||
|
)
|
||
|
return self.import_webdriver(self.browser)(options=self.create_options())
|
||
|
|
||
|
|
||
|
@tag("selenium")
|
||
|
class SeleniumTestCase(LiveServerTestCase, metaclass=SeleniumTestCaseBase):
|
||
|
implicit_wait = 10
|
||
|
external_host = None
|
||
|
|
||
|
@classproperty
|
||
|
def live_server_url(cls):
|
||
|
return "http://%s:%s" % (cls.external_host or cls.host, cls.server_thread.port)
|
||
|
|
||
|
@classproperty
|
||
|
def allowed_host(cls):
|
||
|
return cls.external_host or cls.host
|
||
|
|
||
|
@classmethod
|
||
|
def setUpClass(cls):
|
||
|
cls.selenium = cls.create_webdriver()
|
||
|
cls.selenium.implicitly_wait(cls.implicit_wait)
|
||
|
super().setUpClass()
|
||
|
|
||
|
@classmethod
|
||
|
def _tearDownClassInternal(cls):
|
||
|
# quit() the WebDriver before attempting to terminate and join the
|
||
|
# single-threaded LiveServerThread to avoid a dead lock if the browser
|
||
|
# kept a connection alive.
|
||
|
if hasattr(cls, "selenium"):
|
||
|
cls.selenium.quit()
|
||
|
super()._tearDownClassInternal()
|
||
|
|
||
|
@contextmanager
|
||
|
def disable_implicit_wait(self):
|
||
|
"""Disable the default implicit wait."""
|
||
|
self.selenium.implicitly_wait(0)
|
||
|
try:
|
||
|
yield
|
||
|
finally:
|
||
|
self.selenium.implicitly_wait(self.implicit_wait)
|