Source code
Revision control
Copy as Markdown
Other Tools
from __future__ import absolute_import, print_function, division, unicode_literals
import _io
import inspect
import json as json_module
import logging
import re
import six
from collections import namedtuple
from functools import update_wrapper
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from requests.sessions import REDIRECT_STATI
from requests.utils import cookiejar_from_dict
try:
from collections.abc import Sequence, Sized
except ImportError:
from collections import Sequence, Sized
try:
from requests.packages.urllib3.response import HTTPResponse
except ImportError:
from urllib3.response import HTTPResponse
if six.PY2:
from urlparse import urlparse, parse_qsl, urlsplit, urlunsplit
from urllib import quote
else:
from urllib.parse import urlparse, parse_qsl, urlsplit, urlunsplit, quote
if six.PY2:
try:
from six import cStringIO as BufferIO
except ImportError:
from six import StringIO as BufferIO
else:
from io import BytesIO as BufferIO
try:
from unittest import mock as std_mock
except ImportError:
import mock as std_mock
# The type of compiled regular expressions, used to tell regex URLs apart
# from plain string URLs.
if hasattr(re, "_pattern_type"):
    # Older Pythons only expose the compiled-pattern type privately.
    Pattern = re._pattern_type
else:
    # Python 3.7+
    Pattern = re.Pattern
# Sentinel meaning "argument not supplied", distinct from any explicit value
# (including None).
UNSET = object()

# One recorded (request, response) pair, as stored by ``CallList``.
Call = namedtuple("Call", "request response")

# The genuine ``HTTPAdapter.send``, captured when this module is imported so
# the passthru path can still call it while ``send`` is patched.
_real_send = HTTPAdapter.send

logger = logging.getLogger("responses")
def _is_string(s):
    """Return whether *s* is an instance of ``six.string_types``."""
    native_string_classes = six.string_types
    return isinstance(s, native_string_classes)
def _has_unicode(s):
    """Return True if *s* contains any character outside ASCII.

    Callers use this to decide whether ``_clean_unicode`` has to run on a URL.
    ASCII covers code points 0-127, so the test is ``> 127``; U+0080 is the
    first non-ASCII character and must be detected too.
    """
    return any(ord(char) > 127 for char in s)
def _clean_unicode(url):
    """Return *url* with every non-ASCII character made ASCII-safe.

    Host labels that contain non-ASCII characters are punycode-encoded with an
    ``xn--`` prefix. Any other non-ASCII character left in the URL is
    percent-encoded with ``quote``. On Python 2 this is done byte by byte,
    after UTF-8 encoding.
    """
    # Clean up domain names, which use punycode to handle unicode chars.
    parts = list(urlsplit(url))
    netloc = parts[1]
    if _has_unicode(netloc):
        labels = netloc.split(".")
        for i, label in enumerate(labels):
            if _has_unicode(label):
                labels[i] = "xn--" + label.encode("punycode").decode("ascii")
        parts[1] = ".".join(labels)
        url = urlunsplit(parts)

    # Clean up path/query/params, which use url-encoding to handle unicode chars.
    if six.PY2:
        # Quote the UTF-8 bytes rather than whole code points.
        url = url.encode("utf8")
    # ASCII ends at 127, so U+0080 and above must all be quoted.
    return "".join(quote(char) if ord(char) > 127 else char for char in url)
def _is_redirect(response):
    """Return whether *response* is a redirect.

    Uses ``response.is_redirect`` when the installed requests provides it.
    Otherwise it mirrors the ``requests.sessions`` check: a status code in
    ``REDIRECT_STATI`` plus a ``location`` header.
    """
    try:
        redirect_flag = response.is_redirect
    except AttributeError:
        # requests versions without ``Response.is_redirect``.
        has_redirect_status = response.status_code in REDIRECT_STATI
        return has_redirect_status and "location" in response.headers
    return redirect_flag
def _cookies_from_headers(headers):
    """Parse the ``set-cookie`` header of *headers* into a cookie jar.

    Uses the standard library's ``http.cookies`` when importable and falls back
    to the third-party ``cookies`` package otherwise. A missing ``set-cookie``
    key raises ``KeyError``, which the caller treats as "no cookies".
    """
    try:
        from http import cookies as http_cookies
    except ImportError:
        from cookies import Cookies

        parsed = Cookies.from_request(headers["set-cookie"])
        jar_values = dict(
            (cookie.name, cookie.value) for _name, cookie in parsed.items()
        )
    else:
        simple = http_cookies.SimpleCookie()
        simple.load(headers["set-cookie"])
        jar_values = dict((key, morsel.value) for key, morsel in simple.items())
    return cookiejar_from_dict(jar_values)
# Source of the wrapper generated by ``get_wrapped``. It is exec'd so that the
# wrapper carries the wrapped function's exact parameter list. The wrapped
# function and the mock are reached through prefixed names, so a parameter
# called ``func`` or ``responses`` cannot shadow them inside the wrapper.
_wrapper_template = """\
def wrapper%(wrapper_args)s:
    with _responses_mock:
        return _responses_func%(func_args)s
"""


def _format_call_args(parameters):
    """Render *parameters* as the argument list of a call that forwards them.

    ``str(inspect.Signature)`` is only valid as a *definition*. It may contain
    the bare ``*`` / ``/`` markers, and keyword-only parameters must be passed
    by name rather than by position. Each parameter is therefore rendered
    according to its kind: ``name``, ``*name``, ``name=name`` or ``**name``.

    :param parameters: iterable of ``inspect.Parameter`` objects, in signature
        order.
    :returns: a parenthesized string such as ``"(a, *args, b=b, **kw)"``.
    """
    rendered = []
    for param in parameters:
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            rendered.append("*" + param.name)
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            rendered.append("**" + param.name)
        elif param.kind == inspect.Parameter.KEYWORD_ONLY:
            rendered.append("{0}={0}".format(param.name))
        else:
            # POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD: forward positionally.
            rendered.append(param.name)
    return "(" + ", ".join(rendered) + ")"


def get_wrapped(func, responses):
    """Return a wrapper that calls *func* inside ``with responses:``.

    The wrapper is generated from ``_wrapper_template`` with the same
    parameter list as *func*. This lets tools that introspect signatures (for
    example pytest's fixture injection) keep working.

    :param func: the callable to wrap.
    :param responses: the context manager entered around every call,
        typically a ``RequestsMock``.
    :returns: the generated wrapper, with *func*'s metadata copied onto it by
        ``functools.update_wrapper``.
    """
    if six.PY2:
        args, a, kw, defaults = inspect.getargspec(func)
        wrapper_args = inspect.formatargspec(args, a, kw, defaults)

        # Preserve the argspec for the wrapped function so that testing
        # tools such as pytest can continue to use their fixture injection.
        if hasattr(func, "__self__"):
            args = args[1:]  # Omit 'self'
        func_args = inspect.formatargspec(args, a, kw, None)
    else:
        signature = inspect.signature(func)
        signature = signature.replace(return_annotation=inspect.Signature.empty)
        # If the function is wrapped, switch to *args, **kwargs for the
        # parameters, as we can't rely on the signature to give us the
        # arguments the function will be called with. For example,
        # unittest.mock.patch uses required args that are not actually passed
        # to the function when invoked.
        if hasattr(func, "__wrapped__"):
            wrapper_params = [
                inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
                inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD),
            ]
        else:
            # Annotations are dropped from the generated source, presumably
            # because their expressions may not resolve in the exec namespace.
            wrapper_params = [
                param.replace(annotation=inspect.Parameter.empty)
                for param in signature.parameters.values()
            ]
        signature = signature.replace(parameters=wrapper_params)
        # Defaults are rendered with repr(), so they must round-trip as source.
        wrapper_args = str(signature)
        func_args = _format_call_args(signature.parameters.values())

    evaldict = {"_responses_func": func, "_responses_mock": responses}
    six.exec_(
        _wrapper_template % {"wrapper_args": wrapper_args, "func_args": func_args},
        evaldict,
    )
    wrapper = evaldict["wrapper"]
    update_wrapper(wrapper, func)
    return wrapper
class CallList(Sequence, Sized):
    """Sequence of recorded ``Call(request, response)`` pairs.

    The recorded response may also be an exception instance. Entries are
    appended through ``add`` and cleared through ``reset``.
    """

    def __init__(self):
        self._calls = []

    def __iter__(self):
        return iter(self._calls)

    def __len__(self):
        return len(self._calls)

    def __getitem__(self, idx):
        return self._calls[idx]

    def add(self, request, response):
        """Record *request* together with the response (or exception) it got."""
        recorded = Call(request, response)
        self._calls.append(recorded)

    def reset(self):
        """Drop every recorded call by rebinding to a fresh, empty list."""
        self._calls = []
def _ensure_url_default_path(url):
    """Return *url* with an empty path replaced by ``/``.

    Values that are not strings (for example compiled regex patterns) and
    string URLs that already have a path are returned unchanged.
    """
    if not _is_string(url):
        return url
    scheme, netloc, path, query, fragment = urlsplit(url)
    if path:
        return url
    return urlunsplit((scheme, netloc, "/", query, fragment))
def _handle_body(body):
    """Return a file-like object that yields *body* when read.

    An ``_io.BufferedReader`` is handed back as-is. Text is UTF-8 encoded, and
    the result is wrapped in a ``BufferIO`` buffer.
    """
    if isinstance(body, _io.BufferedReader):
        return body
    if isinstance(body, six.text_type):
        body = body.encode("utf-8")
    return BufferIO(body)
# Sentinel default for ``match_querystring``: "the caller did not choose".
_unspecified = object()


class BaseResponse(object):
    """Base class for a registered mock response.

    Holds the HTTP method and URL (string or compiled regex) to match, and
    counts how often it was used. Subclasses implement ``get_response``.
    """

    content_type = None
    headers = None

    stream = False

    def __init__(self, method, url, match_querystring=_unspecified):
        self.method = method
        # ensure the url has a default path set if the url is a string
        self.url = _ensure_url_default_path(url)
        self.match_querystring = self._should_match_querystring(match_querystring)
        self.call_count = 0

    def __eq__(self, other):
        """Two responses are equal when method and URL (or regex source) match."""
        if not isinstance(other, BaseResponse):
            return False

        if self.method != other.method:
            return False

        # Can't simply do an equality check on the objects directly here since
        # __eq__ isn't implemented for regex. It might seem to work as regex is
        # using a cache to return the same regex instances, but it doesn't in
        # all cases.
        self_url = self.url.pattern if isinstance(self.url, Pattern) else self.url
        other_url = other.url.pattern if isinstance(other.url, Pattern) else other.url

        return self_url == other_url

    def __ne__(self, other):
        return not self.__eq__(other)

    def _url_matches_strict(self, url, other):
        """Return True if *url* and *other* match including their query strings.

        Scheme, netloc and path must be identical. The query strings must hold
        the same key/value pairs, in any order. ``parse_qsl`` is called with
        its defaults, so blank values are dropped before comparing.
        """
        url_parsed = urlparse(url)
        other_parsed = urlparse(other)

        if url_parsed[:3] != other_parsed[:3]:
            return False

        url_qsl = sorted(parse_qsl(url_parsed.query))
        other_qsl = sorted(parse_qsl(other_parsed.query))
        # Sorted lists are equal exactly when they hold the same pairs.
        return url_qsl == other_qsl

    def _should_match_querystring(self, match_querystring_argument):
        """Resolve the effective ``match_querystring`` setting.

        An explicit argument wins. Otherwise regex URLs do not match query
        strings, and string URLs do only if they contain a query string.
        """
        if match_querystring_argument is not _unspecified:
            return match_querystring_argument

        if isinstance(self.url, Pattern):
            # the old default from <= 0.9.0
            return False

        return bool(urlparse(self.url).query)

    def _url_matches(self, url, other, match_querystring=False):
        """Return whether the registered *url* matches the request URL *other*.

        String URLs are compared exactly, either with the query string (strict
        mode) or without it. Regex URLs must ``match`` from the start of
        *other*.
        """
        if _is_string(url):
            if _has_unicode(url):
                url = _clean_unicode(url)
                if not isinstance(other, six.text_type):
                    other = other.encode("ascii").decode("utf8")
            if match_querystring:
                return self._url_matches_strict(url, other)

            else:
                url_without_qs = url.split("?", 1)[0]
                other_without_qs = other.split("?", 1)[0]
                return url_without_qs == other_without_qs

        elif isinstance(url, Pattern) and url.match(other):
            return True

        else:
            return False

    def get_headers(self):
        """Return the response headers: ``Content-Type`` plus ``self.headers``."""
        headers = {}
        if self.content_type is not None:
            headers["Content-Type"] = self.content_type
        if self.headers:
            headers.update(self.headers)
        return headers

    def get_response(self, request):
        """Build the ``HTTPResponse`` for *request*; implemented by subclasses."""
        raise NotImplementedError

    def matches(self, request):
        """Return whether *request* has this method and a matching URL."""
        if request.method != self.method:
            return False

        if not self._url_matches(self.url, request.url, self.match_querystring):
            return False

        return True
class Response(BaseResponse):
    """A mock response with a fixed status, headers and body."""

    def __init__(
        self,
        method,
        url,
        body="",
        json=None,
        status=200,
        headers=None,
        stream=False,
        content_type=UNSET,
        **kwargs
    ):
        # A ``json`` payload replaces the body. Unless the caller picked a
        # content type explicitly, it also implies ``application/json``.
        if json is not None:
            assert not body
            body = json_module.dumps(json)
            default_content_type = "application/json"
        else:
            default_content_type = "text/plain"
        if content_type is UNSET:
            content_type = default_content_type

        # Text bodies are stored as UTF-8 bytes; anything else is kept as given.
        self.body = body.encode("utf-8") if isinstance(body, six.text_type) else body
        self.status = status
        self.headers = headers
        self.stream = stream
        self.content_type = content_type
        super(Response, self).__init__(method, url, **kwargs)

    def get_response(self, request):
        """Build the ``HTTPResponse`` for *request*.

        If the stored body is an exception instance, it is raised instead.
        """
        stored_body = self.body
        if stored_body and isinstance(stored_body, Exception):
            raise stored_body

        response_headers = self.get_headers()
        status_code = self.status
        payload = _handle_body(stored_body)
        return HTTPResponse(
            status=status_code,
            reason=six.moves.http_client.responses.get(status_code),
            body=payload,
            headers=response_headers,
            preload_content=False,
        )
class CallbackResponse(BaseResponse):
    """A mock response whose status, headers and body come from a callback."""

    def __init__(
        self, method, url, callback, stream=False, content_type="text/plain", **kwargs
    ):
        self.callback = callback
        self.stream = stream
        self.content_type = content_type
        super(CallbackResponse, self).__init__(method, url, **kwargs)

    def get_response(self, request):
        """Build an ``HTTPResponse`` from ``self.callback(request)``.

        The callback returns ``(status, headers, body)``. If it returns an
        exception, or an exception as the body, that exception is raised.
        Headers from the callback override the defaults from ``get_headers``.
        """
        merged_headers = self.get_headers()

        outcome = self.callback(request)
        if isinstance(outcome, Exception):
            raise outcome

        status_code, callback_headers, raw_body = outcome
        if isinstance(raw_body, Exception):
            raise raw_body

        payload = _handle_body(raw_body)
        merged_headers.update(callback_headers)

        return HTTPResponse(
            status=status_code,
            reason=six.moves.http_client.responses.get(status_code),
            body=payload,
            headers=merged_headers,
            preload_content=False,
        )
class RequestsMock(object):
    """Registry of mock responses that patches ``requests`` while active.

    While started (via ``start()`` or a ``with`` block), ``self.target`` (by
    default ``requests.adapters.HTTPAdapter.send``) is patched. Each outgoing
    request is then answered by a matching registered response. A request that
    matches nothing is sent for real if its URL starts with one of
    ``passthru_prefixes``; otherwise a ``ConnectionError`` is raised.
    """

    DELETE = "DELETE"
    GET = "GET"
    HEAD = "HEAD"
    OPTIONS = "OPTIONS"
    PATCH = "PATCH"
    POST = "POST"
    PUT = "PUT"
    response_callback = None

    def __init__(
        self,
        assert_all_requests_are_fired=True,
        response_callback=None,
        passthru_prefixes=(),
        target="requests.adapters.HTTPAdapter.send",
    ):
        self._calls = CallList()
        self.reset()
        self.assert_all_requests_are_fired = assert_all_requests_are_fired
        self.response_callback = response_callback
        self.passthru_prefixes = tuple(passthru_prefixes)
        self.target = target

    def reset(self):
        """Forget all registered responses and recorded calls."""
        self._matches = []
        self._calls.reset()

    def add(
        self,
        method=None,  # method or ``Response``
        url=None,
        body="",
        adding_headers=None,
        *args,
        **kwargs
    ):
        """
        A basic request:

        >>> responses.add(responses.GET, 'http://example.com')

        You can also directly pass an object which implements the
        ``BaseResponse`` interface:

        >>> responses.add(Response(...))

        A JSON payload:

        >>> responses.add(
        ...     method='GET',
        ...     url='http://example.com',
        ...     json={'foo': 'bar'},
        ... )

        Custom headers:

        >>> responses.add(
        ...     method='GET',
        ...     url='http://example.com',
        ...     headers={'X-Header': 'foo'},
        ... )

        Strict query string matching:

        >>> responses.add(
        ...     method='GET',
        ...     url='http://example.com?foo=bar',
        ...     match_querystring=True
        ... )

        ``adding_headers`` is used as ``headers`` when no ``headers`` keyword
        is given. Extra positional arguments are accepted and ignored.
        """
        if isinstance(method, BaseResponse):
            self._matches.append(method)
            return

        if adding_headers is not None:
            kwargs.setdefault("headers", adding_headers)

        self._matches.append(Response(method=method, url=url, body=body, **kwargs))

    def add_passthru(self, prefix):
        """
        Register a URL prefix to passthru any non-matching mock requests to.

        For example, to allow any request to 'https://example.com', but require
        mocks for the remainder, you would add the prefix as so:

        >>> responses.add_passthru('https://example.com')
        """
        if _has_unicode(prefix):
            prefix = _clean_unicode(prefix)
        self.passthru_prefixes += (prefix,)

    def remove(self, method_or_response=None, url=None):
        """
        Removes a response previously added using ``add()``, identified
        either by a response object inheriting ``BaseResponse`` or
        ``method`` and ``url``. Removes all matching responses.
        """
        if isinstance(method_or_response, BaseResponse):
            response = method_or_response
        else:
            response = BaseResponse(method=method_or_response, url=url)

        while response in self._matches:
            self._matches.remove(response)

    def replace(self, method_or_response=None, url=None, body="", *args, **kwargs):
        """
        Replaces a response previously added using ``add()``. The signature
        is identical to ``add()``. The response is identified using ``method``
        and ``url``, and the first matching response is replaced.

        Raises ``ValueError`` if no registered response matches.
        """
        if isinstance(method_or_response, BaseResponse):
            response = method_or_response
        else:
            response = Response(method=method_or_response, url=url, body=body, **kwargs)

        try:
            index = self._matches.index(response)
        except ValueError:
            raise ValueError(
                "Response is not registered for URL {0}".format(response.url)
            )
        self._matches[index] = response

    def add_callback(
        self, method, url, callback, match_querystring=False, content_type="text/plain"
    ):
        """Register a ``CallbackResponse`` that builds responses with *callback*.

        *callback* is called with the request. It returns a
        ``(status, headers, body)`` tuple, or an exception to raise.
        """
        # The URL's default path is normalized by BaseResponse.__init__.
        self._matches.append(
            CallbackResponse(
                url=url,
                method=method,
                callback=callback,
                content_type=content_type,
                match_querystring=match_querystring,
            )
        )

    @property
    def calls(self):
        """The ``CallList`` of requests seen while this mock was active."""
        return self._calls

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        success = type is None
        try:
            self.stop(allow_assert=success)
        finally:
            # Clear registered responses and calls even when ``stop`` raises
            # (e.g. the unfired-requests assertion), so state cannot leak into
            # the next use of this mock.
            self.reset()
        return success

    def activate(self, func):
        """Decorator: run *func* with this mock active (see ``get_wrapped``)."""
        return get_wrapped(func, self)

    def _find_match(self, request):
        """Return the registered response matching *request*, or None.

        When more than one registered response matches, the first one is
        removed from the registry before being returned. Repeated identical
        requests therefore consume the responses in registration order.
        """
        found = None
        found_match = None
        for i, match in enumerate(self._matches):
            if match.matches(request):
                if found is None:
                    found = i
                    found_match = match
                else:
                    # Multiple matches found. Remove & return the first match.
                    return self._matches.pop(found)

        return found_match

    def _on_request(self, adapter, request, **kwargs):
        """Answer one request intercepted by the patched ``send``."""
        match = self._find_match(request)
        resp_callback = self.response_callback

        if match is None:
            if request.url.startswith(self.passthru_prefixes):
                logger.info("request.allowed-passthru", extra={"url": request.url})
                return _real_send(adapter, request, **kwargs)

            error_msg = (
                "Connection refused by Responses: {0} {1} doesn't "
                "match Responses Mock".format(request.method, request.url)
            )
            response = ConnectionError(error_msg)
            response.request = request

            self._calls.add(request, response)
            response = resp_callback(response) if resp_callback else response
            raise response

        try:
            response = adapter.build_response(request, match.get_response(request))
        except Exception as response:
            match.call_count += 1
            self._calls.add(request, response)
            # The callback's return value is discarded here: the bare ``raise``
            # re-raises the exception currently being handled.
            response = resp_callback(response) if resp_callback else response
            raise

        if not match.stream:
            response.content  # NOQA

        try:
            response.cookies = _cookies_from_headers(response.headers)
        except (KeyError, TypeError):
            # Best effort: e.g. no ``set-cookie`` header present.
            pass

        response = resp_callback(response) if resp_callback else response
        match.call_count += 1
        self._calls.add(request, response)
        return response

    def start(self):
        """Patch ``self.target`` so requests are routed to ``_on_request``."""

        def unbound_on_send(adapter, request, *a, **kwargs):
            return self._on_request(adapter, request, *a, **kwargs)

        self._patcher = std_mock.patch(target=self.target, new=unbound_on_send)
        self._patcher.start()

    def stop(self, allow_assert=True):
        """Undo the patch applied by ``start``.

        If ``assert_all_requests_are_fired`` is set and *allow_assert* is true,
        raise ``AssertionError`` listing registered responses never called.
        """
        self._patcher.stop()

        if not self.assert_all_requests_are_fired:
            return

        if not allow_assert:
            return

        not_called = [m for m in self._matches if m.call_count == 0]
        if not_called:
            raise AssertionError(
                "Not all requests have been executed {0!r}".format(
                    [(match.method, match.url) for match in not_called]
                )
            )
# Expose the default mock's public attributes (add, activate, calls, ...) as
# module-level names, so module-level calls operate on ``_default_mock``.
mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False)

__all__ = ["CallbackResponse", "Response", "RequestsMock"]

for __attr in dir(_default_mock):
    if __attr.startswith("_"):
        continue
    __all__.append(__attr)
    globals()[__attr] = getattr(_default_mock, __attr)