3475 lines
120 KiB
Python
3475 lines
120 KiB
Python
# Copyright (c) Twisted Matrix Laboratories.
|
|
# See LICENSE for details.
|
|
|
|
"""
|
|
Tests for L{twisted.web.client.Agent} and related new client APIs.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import zlib
|
|
from http.cookiejar import CookieJar
|
|
from io import BytesIO
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
|
|
from unittest import SkipTest, skipIf
|
|
|
|
from zope.interface.declarations import implementer
|
|
from zope.interface.verify import verifyObject
|
|
|
|
from incremental import Version
|
|
|
|
from twisted.internet import defer, task
|
|
from twisted.internet.address import IPv4Address, IPv6Address
|
|
from twisted.internet.defer import CancelledError, Deferred, succeed
|
|
from twisted.internet.endpoints import HostnameEndpoint, TCP4ClientEndpoint
|
|
from twisted.internet.error import (
|
|
ConnectionDone,
|
|
ConnectionLost,
|
|
ConnectionRefusedError,
|
|
)
|
|
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
|
|
from twisted.internet.protocol import Factory, Protocol
|
|
from twisted.internet.task import Clock
|
|
from twisted.internet.test.test_endpoints import deterministicResolvingReactor
|
|
from twisted.internet.testing import (
|
|
AccumulatingProtocol,
|
|
EventLoggingObserver,
|
|
MemoryReactorClock,
|
|
StringTransport,
|
|
)
|
|
from twisted.logger import globalLogPublisher
|
|
from twisted.python.components import proxyForInterface
|
|
from twisted.python.deprecate import getDeprecationWarningString
|
|
from twisted.python.failure import Failure
|
|
from twisted.test.iosim import FakeTransport, IOPump
|
|
from twisted.test.test_sslverify import certificatesForAuthorityAndServer
|
|
from twisted.trial.unittest import SynchronousTestCase, TestCase
|
|
from twisted.web import client, error, http_headers
|
|
from twisted.web._newclient import (
|
|
HTTP11ClientProtocol,
|
|
PotentialDataLoss,
|
|
RequestNotSent,
|
|
RequestTransmissionFailed,
|
|
Response,
|
|
ResponseFailed,
|
|
ResponseNeverReceived,
|
|
)
|
|
from twisted.web.client import (
|
|
URI,
|
|
BrowserLikePolicyForHTTPS,
|
|
FileBodyProducer,
|
|
HostnameCachingHTTPSPolicy,
|
|
HTTPConnectionPool,
|
|
Request,
|
|
ResponseDone,
|
|
_HTTP11ClientFactory,
|
|
)
|
|
from twisted.web.error import SchemeNotSupported
|
|
from twisted.web.http_headers import Headers
|
|
from twisted.web.iweb import (
|
|
UNKNOWN_LENGTH,
|
|
IAgent,
|
|
IAgentEndpointFactory,
|
|
IBodyProducer,
|
|
IPolicyForHTTPS,
|
|
IResponse,
|
|
)
|
|
from twisted.web.test.injectionhelpers import (
|
|
MethodInjectionTestsMixin,
|
|
URIInjectionTestsMixin,
|
|
)
|
|
|
|
# Mixins in this module are only ever combined with a TestCase, but mypy has
# no way to know that (don't use mixins).  So, creatively lie to it: at
# type-checking time mixins appear to inherit from TestCase, while at runtime
# they are plain objects and concrete test cases inherit TestCase directly.
if not TYPE_CHECKING:
    testMixinClass = object
    runtimeTestCase = TestCase
else:
    testMixinClass = TestCase
    runtimeTestCase = object
|
|
|
|
try:
    from twisted.internet import ssl as _ssl
except ImportError:
    # TLS support could not be imported; record its absence so that
    # TLS-dependent code in this module can check ``sslPresent``.
    sslPresent = False
    ssl = None
else:
    sslPresent = True
    ssl = _ssl
    from twisted.internet._sslverify import ClientTLSOptions, IOpenSSLTrustRoot
    from twisted.internet.ssl import optionsForClientTLS
    from twisted.protocols import tls
    from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol

    @implementer(IOpenSSLTrustRoot)
    class CustomOpenSSLTrustRoot:
        """
        A trust root which adds no certificates but remembers that it was
        asked to, and with which context.

        @ivar called: C{True} once C{_addCACertsToContext} has been invoked.
        @ivar context: The context passed to C{_addCACertsToContext}, or
            C{None} if it has not been invoked yet.
        """

        called = False
        context = None

        def _addCACertsToContext(self, context):
            """
            Record C{context} and flag that this method was invoked.

            @param context: The context the trust root is being applied to.
            """
            self.context = context
            self.called = True
|
|
|
|
|
|
class StubHTTPProtocol(Protocol):
    """
    A stand-in for L{HTTP11ClientProtocol} which speaks no HTTP at all; it
    merely remembers every request it is asked to make.

    @ivar requests: A C{list} of two-tuples, one per call to C{request}, each
        holding the request object and the L{Deferred} that call returned.
    @ivar state: Always C{"QUIESCENT"} unless a test changes it.
    """

    def __init__(self) -> None:
        self.state = "QUIESCENT"
        self.requests: List[Tuple[Request, Deferred[IResponse]]] = []

    def request(self, request):
        """
        Record C{request} so a test can look at it later.

        @param request: The request being issued.

        @return: A fresh L{Deferred}; nothing in this class ever fires it.
        """
        pending = Deferred()
        self.requests.append((request, pending))
        return pending
|
|
|
|
|
|
class FileConsumer:
    """
    A minimal consumer which forwards every chunk of data it receives to a
    file-like object.

    @ivar outputFile: The file-like object that receives the data.
    """

    def __init__(self, outputFile):
        self.outputFile = outputFile

    def write(self, bytes):
        """
        Forward C{bytes} to C{outputFile}.

        @param bytes: The chunk of data to write.
        """
        destination = self.outputFile
        destination.write(bytes)
|
|
|
|
|
|
class FileBodyProducerTests(TestCase):
    """
    Tests for the L{FileBodyProducer} which reads bytes from a file and writes
    them to an L{IConsumer}.
    """

    def _termination(self):
        """
        This method can be used as the C{terminationPredicateFactory} for a
        L{Cooperator}.  It returns a predicate which always returns C{True},
        telling the L{Cooperator} to stop doing work for the current iteration
        as soon as a single step has been taken.  This has the result of only
        allowing one iteration of a cooperative task to be run per
        L{Cooperator} iteration.
        """
        return lambda: True

    def setUp(self):
        """
        Create a L{Cooperator} hooked up to an easily controlled, deterministic
        scheduler to use with L{FileBodyProducer}.

        The scheduler only records each callable it is given in
        C{self._scheduled}; tests drive the cooperator one iteration at a time
        by popping and calling the oldest entry.
        """
        self._scheduled = []
        self.cooperator = task.Cooperator(self._termination, self._scheduled.append)

    def test_interface(self):
        """
        L{FileBodyProducer} instances provide L{IBodyProducer}.
        """
        self.assertTrue(verifyObject(IBodyProducer, FileBodyProducer(BytesIO(b""))))

    def test_unknownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object
        which lacks either a C{seek} or a C{tell} method, its C{length}
        attribute is set to C{UNKNOWN_LENGTH}.
        """

        class HasSeek:
            def seek(self, offset, whence):
                pass

        class HasTell:
            def tell(self):
                pass

        producer = FileBodyProducer(HasSeek())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)
        producer = FileBodyProducer(HasTell())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)

    def test_knownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object with
        both C{seek} and C{tell} methods, its C{length} attribute is set to the
        size of the file as determined by those methods.
        """
        inputBytes = b"here are some bytes"
        inputFile = BytesIO(inputBytes)
        inputFile.seek(5)
        producer = FileBodyProducer(inputFile)
        self.assertEqual(len(inputBytes) - 5, producer.length)
        # Measuring the length must not move the current file position.
        self.assertEqual(inputFile.tell(), 5)

    def test_defaultCooperator(self):
        """
        If no L{Cooperator} instance is passed to L{FileBodyProducer}, the
        global cooperator is used.
        """
        producer = FileBodyProducer(BytesIO(b""))
        self.assertEqual(task.cooperate, producer._cooperate)

    def test_startProducing(self):
        """
        L{FileBodyProducer.startProducing} starts writing bytes from the input
        file to the given L{IConsumer} and returns a L{Deferred} which fires
        when they have all been written.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        # Enough iterations to read every chunk and then reach end-of-file.
        for _ in range(len(expectedResult) // readSize + 1):
            self._scheduled.pop(0)()
        self.assertEqual([], self._scheduled)
        self.assertEqual(expectedResult, output.getvalue())
        self.assertIsNone(self.successResultOf(complete))

    def test_inputClosedAtEOF(self):
        """
        When L{FileBodyProducer} reaches end-of-file on the input file given to
        it, the input file is closed.
        """
        readSize = 4
        inputBytes = b"some friendly bytes"
        inputFile = BytesIO(inputBytes)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        consumer = FileConsumer(BytesIO())
        producer.startProducing(consumer)
        # Enough iterations to consume the whole input and reach end-of-file.
        for _ in range(len(inputBytes) // readSize + 2):
            self._scheduled.pop(0)()
        self.assertTrue(inputFile.closed)

    def test_failedReadWhileProducing(self):
        """
        If a read from the input file fails while producing bytes to the
        consumer, the L{Deferred} returned by
        L{FileBodyProducer.startProducing} fires with a L{Failure} wrapping
        that exception.
        """

        class BrokenFile:
            def read(self, count):
                raise OSError("Simulated bad thing")

        producer = FileBodyProducer(BrokenFile(), self.cooperator)
        complete = producer.startProducing(FileConsumer(BytesIO()))
        self._scheduled.pop(0)()
        self.failureResultOf(complete).trap(OSError)

    def test_cancelWhileProducing(self):
        """
        When the L{Deferred} returned by L{FileBodyProducer.startProducing} is
        cancelled, the input file is closed and the task is stopped.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        complete.cancel()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)

    def test_stopProducing(self):
        """
        L{FileBodyProducer.stopProducing} stops the underlying L{IPullProducer}
        and the cooperative task responsible for calling C{resumeProducing} and
        closes the input file but does not cause the L{Deferred} returned by
        C{startProducing} to fire.
        """
        expectedResult = b"hello, world"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)

    def test_pauseProducing(self):
        """
        L{FileBodyProducer.pauseProducing} temporarily suspends writing bytes
        from the input file to the given L{IConsumer}.
        """
        expectedResult = b"hello, world"
        readSize = 5
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(output.getvalue(), expectedResult[:5])
        producer.pauseProducing()

        # Sort of depends on an implementation detail of Cooperator: even
        # though the only task is paused, there's still a scheduled call.  If
        # this were to go away because Cooperator became smart enough to cancel
        # this call in this case, that would be fine.
        self._scheduled.pop(0)()

        # Since the producer is paused, no new data should be here.
        self.assertEqual(output.getvalue(), expectedResult[:5])
        self.assertEqual([], self._scheduled)
        self.assertNoResult(complete)

    def test_resumeProducing(self):
        """
        L{FileBodyProducer.resumeProducing} re-commences writing bytes from the
        input file to the given L{IConsumer} after it was previously paused
        with L{FileBodyProducer.pauseProducing}.
        """
        expectedResult = b"hello, world"
        readSize = 5
        output = BytesIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(BytesIO(expectedResult), self.cooperator, readSize)
        producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[:readSize], output.getvalue())
        producer.pauseProducing()
        producer.resumeProducing()
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[: readSize * 2], output.getvalue())

    def test_multipleStop(self):
        """
        L{FileBodyProducer.stopProducing} can be called more than once without
        raising an exception.
        """
        expectedResult = b"test"
        readSize = 3
        output = BytesIO()
        consumer = FileConsumer(output)
        inputFile = BytesIO(expectedResult)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual(b"", output.getvalue())
        self.assertNoResult(complete)
|
|
|
|
|
|
# Fixed addresses handed out by the deterministic resolver that
# FakeReactorAndConnectMixin.createReactor builds; tests compare the address
# an agent connected to against these.
EXAMPLE_COM_IP: str = "127.0.0.7"  # example.com
EXAMPLE_COM_V6_IP: str = "::7"  # ipv6.example.com
EXAMPLE_NET_IP: str = "127.0.0.8"  # example.net
EXAMPLE_ORG_IP: str = "127.0.0.9"  # example.org
FOO_LOCAL_IP: str = "127.0.0.10"  # foo
FOO_COM_IP: str = "127.0.0.11"  # foo.com
|
|
|
|
|
|
class FakeReactorAndConnectMixin:
    """
    A test mixin providing a testable C{Reactor} class and a dummy C{connect}
    method which allows instances to pretend to be endpoints.
    """

    def createReactor(self):
        """
        Create a L{MemoryReactorClock} and give it some hostnames it can
        resolve.

        @return: a L{MemoryReactorClock}-like object with a slightly limited
            interface (only C{advance} and C{tcpClients} in addition to its
            formally-declared reactor interfaces), which can resolve a fixed
            set of domains.
        """
        mrc = MemoryReactorClock()
        drr = deterministicResolvingReactor(
            mrc,
            hostMap={
                "example.com": [EXAMPLE_COM_IP],
                "ipv6.example.com": [EXAMPLE_COM_V6_IP],
                "example.net": [EXAMPLE_NET_IP],
                "example.org": [EXAMPLE_ORG_IP],
                "foo": [FOO_LOCAL_IP],
                "foo.com": [FOO_COM_IP],
                "127.0.0.7": ["127.0.0.7"],
                "::7": ["::7"],
            },
        )

        # Lots of tests were written expecting MemoryReactorClock and the
        # reactor seen by the SUT to be the same object.
        drr.tcpClients = mrc.tcpClients
        drr.advance = mrc.advance
        return drr

    class StubEndpoint:
        """
        Endpoint that wraps existing endpoint, substitutes StubHTTPProtocol, and
        resulting protocol instances are attached to the given test case.
        """

        def __init__(self, endpoint, testCase):
            self.endpoint = endpoint
            self.testCase = testCase

            def nothing(protocol):
                """
                A quiescent callback which does nothing.

                @param protocol: The protocol which became quiescent.  Quiescent
                    callbacks are invoked with the protocol as their only
                    argument, so it must be accepted even though it is ignored.
                """

            self.factory = _HTTP11ClientFactory(nothing, repr(self.endpoint))
            self.protocol = StubHTTPProtocol()
            # Whatever address the wrapped endpoint connects to, the factory
            # always hands back this one stub protocol.
            self.factory.buildProtocol = lambda addr: self.protocol

        def connect(self, ignoredFactory):
            """
            Connect the wrapped endpoint using the stub-producing factory and
            expose the stub protocol on the test case.

            @param ignoredFactory: The factory supplied by the caller; it is
                not used.

            @return: An already-fired L{Deferred} with the stub protocol.
            """
            self.testCase.protocol = self.protocol
            self.endpoint.connect(self.factory)
            return succeed(self.protocol)

    def buildAgentForWrapperTest(self, reactor):
        """
        Return an Agent suitable for use in tests that wrap the Agent and want
        both a fake reactor and StubHTTPProtocol.
        """
        agent = client.Agent(reactor)
        _oldGetEndpoint = agent._getEndpoint
        agent._getEndpoint = lambda *args: (
            self.StubEndpoint(_oldGetEndpoint(*args), self)
        )
        return agent

    def connect(self, factory):
        """
        Fake implementation of an endpoint which synchronously
        succeeds with an instance of L{StubHTTPProtocol} for ease of
        testing.

        The protocol is also stored as C{self.protocol} so tests can inspect
        the requests made through it.
        """
        protocol = StubHTTPProtocol()
        protocol.makeConnection(None)
        self.protocol = protocol
        return succeed(protocol)
|
|
|
|
|
|
class DummyEndpoint:
    """
    An endpoint whose connections are made over an in-memory
    L{StringTransport} instead of the network.
    """

    def connect(self, factory):
        """
        Build a protocol from C{factory} and connect it to a new fake
        transport.

        @param factory: The protocol factory to build a protocol from.

        @return: An already-fired L{Deferred} with the connected protocol.
        """
        connected = factory.buildProtocol(None)
        connected.makeConnection(StringTransport())
        return succeed(connected)
|
|
|
|
|
|
class BadEndpoint:
    """
    An endpoint which fails loudly if anything tries to connect through it;
    used where a cached connection is expected to be reused instead.
    """

    def connect(self, factory):
        """
        Refuse to connect.

        @param factory: Ignored.

        @raise RuntimeError: Always.
        """
        complaint = "This endpoint should not have been used."
        raise RuntimeError(complaint)
|
|
|
|
|
|
class DummyFactory(Factory):
    """
    A factory producing L{StubHTTPProtocol} instances, with the same
    constructor signature as the HTTP/1.1 client factory used by the pool.
    """

    protocol = StubHTTPProtocol

    def __init__(self, quiescentCallback, metadata):
        """
        Accept and discard the arguments the connection pool passes when it
        constructs its factory.

        @param quiescentCallback: Ignored.
        @param metadata: Ignored.
        """
|
|
|
|
|
|
class HTTPConnectionPoolTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for the L{HTTPConnectionPool} class.
    """

    def setUp(self):
        self.fakeReactor = self.createReactor()
        self.pool = HTTPConnectionPool(self.fakeReactor)
        self.pool._factory = DummyFactory
        # The retry code path is tested in HTTPConnectionPoolRetryTests:
        self.pool.retryAutomatically = False

    def _allCachedConnections(self):
        """
        Collect every protocol currently cached by C{self.pool}.

        C{self.pool._connections} maps each key to a C{list} of protocols, so a
        membership test against C{_connections.values()} would compare a
        protocol with whole lists and could never fail; flatten them first.

        @return: A C{list} of all cached protocols, across every key.
        """
        return [
            cached
            for protocols in self.pool._connections.values()
            for cached in protocols
        ]

    def test_getReturnsNewIfCacheEmpty(self):
        """
        If there are no cached connections,
        L{HTTPConnectionPool.getConnection} returns a new connection.
        """
        self.assertEqual(self.pool._connections, {})

        def gotConnection(conn):
            self.assertIsInstance(conn, StubHTTPProtocol)
            # The new connection is not stored in the pool:
            self.assertNotIn(conn, self._allCachedConnections())

        unknownKey = 12245
        d = self.pool.getConnection(unknownKey, DummyEndpoint())
        return d.addCallback(gotConnection)

    def test_putStartsTimeout(self):
        """
        If a connection is put back to the pool, a 240-sec timeout is started.

        When the timeout hits, the connection is closed and removed from the
        pool.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        self.pool._putConnection(("http", b"example.com", 80), protocol)

        # Connection is in pool, still not closed:
        self.assertEqual(protocol.transport.disconnecting, False)
        self.assertIn(protocol, self.pool._connections[("http", b"example.com", 80)])

        # Advance 239 seconds, still not closed:
        self.fakeReactor.advance(239)
        self.assertEqual(protocol.transport.disconnecting, False)
        self.assertIn(protocol, self.pool._connections[("http", b"example.com", 80)])
        self.assertIn(protocol, self.pool._timeouts)

        # Advance past 240 seconds, connection will be closed:
        self.fakeReactor.advance(1.1)
        self.assertEqual(protocol.transport.disconnecting, True)
        self.assertNotIn(protocol, self.pool._connections[("http", b"example.com", 80)])
        self.assertNotIn(protocol, self.pool._timeouts)

    def test_putExceedsMaxPersistent(self):
        """
        If an idle connection is put back in the cache and the max number of
        persistent connections has been exceeded, one of the connections is
        closed and removed from the cache.
        """
        pool = self.pool

        # We start out with two cached connections, the max:
        origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
        for p in origCached:
            p.makeConnection(StringTransport())
            pool._putConnection(("http", b"example.com", 80), p)
        self.assertEqual(pool._connections[("http", b"example.com", 80)], origCached)
        timeouts = pool._timeouts.copy()

        # Now we add another one:
        newProtocol = StubHTTPProtocol()
        newProtocol.makeConnection(StringTransport())
        pool._putConnection(("http", b"example.com", 80), newProtocol)

        # The oldest cached connections will be removed and disconnected:
        newCached = pool._connections[("http", b"example.com", 80)]
        self.assertEqual(len(newCached), 2)
        self.assertEqual(newCached, [origCached[1], newProtocol])
        self.assertEqual([p.transport.disconnecting for p in newCached], [False, False])
        self.assertEqual(origCached[0].transport.disconnecting, True)
        self.assertTrue(timeouts[origCached[0]].cancelled)
        self.assertNotIn(origCached[0], pool._timeouts)

    def test_maxPersistentPerHost(self):
        """
        C{maxPersistentPerHost} is enforced per C{(scheme, host, port)}:
        different keys have different max connections.
        """

        def addProtocol(scheme, host, port):
            p = StubHTTPProtocol()
            p.makeConnection(StringTransport())
            self.pool._putConnection((scheme, host, port), p)
            return p

        persistent = []
        persistent.append(addProtocol("http", b"example.com", 80))
        persistent.append(addProtocol("http", b"example.com", 80))
        addProtocol("https", b"example.com", 443)
        addProtocol("http", b"www2.example.com", 80)

        self.assertEqual(
            self.pool._connections[("http", b"example.com", 80)], persistent
        )
        self.assertEqual(len(self.pool._connections[("https", b"example.com", 443)]), 1)
        self.assertEqual(
            len(self.pool._connections[("http", b"www2.example.com", 80)]), 1
        )

    def test_getCachedConnection(self):
        """
        Getting an address which has a cached connection returns the cached
        connection, removes it from the cache and cancels its timeout.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        self.pool._putConnection(("http", b"example.com", 80), protocol)

        def gotConnection(conn):
            # We got the cached connection:
            self.assertIdentical(protocol, conn)
            self.assertNotIn(conn, self.pool._connections[("http", b"example.com", 80)])
            # And the timeout was cancelled:
            self.fakeReactor.advance(241)
            self.assertEqual(conn.transport.disconnecting, False)
            self.assertNotIn(conn, self.pool._timeouts)

        return self.pool.getConnection(
            ("http", b"example.com", 80),
            BadEndpoint(),
        ).addCallback(gotConnection)

    def test_newConnection(self):
        """
        The pool's C{_newConnection} method constructs a new connection.
        """
        # We start out with one cached connection:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        key = 12245
        self.pool._putConnection(key, protocol)

        def gotConnection(newConnection):
            # We got a new connection:
            self.assertNotIdentical(protocol, newConnection)
            # And the old connection is still there:
            self.assertIn(protocol, self.pool._connections[key])
            # While the new connection is not:
            self.assertNotIn(newConnection, self._allCachedConnections())

        d = self.pool._newConnection(key, DummyEndpoint())
        return d.addCallback(gotConnection)

    def test_getSkipsDisconnected(self):
        """
        When getting connections out of the cache, disconnected connections
        are removed and not returned.
        """
        pool = self.pool
        key = ("http", b"example.com", 80)

        # We start out with two cached connections, the max:
        origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
        for p in origCached:
            p.makeConnection(StringTransport())
            pool._putConnection(key, p)
        self.assertEqual(pool._connections[key], origCached)

        # We close the first one:
        origCached[0].state = "DISCONNECTED"

        # Now, when we retrieve connections we should get the *second* one:
        result = []
        self.pool.getConnection(key, BadEndpoint()).addCallback(result.append)
        self.assertIdentical(result[0], origCached[1])

        # And both the disconnected and removed connections should be out of
        # the cache:
        self.assertEqual(pool._connections[key], [])
        self.assertEqual(pool._timeouts, {})

    def test_putNotQuiescent(self):
        """
        If a non-quiescent connection is put back in the cache, an error is
        logged.
        """
        protocol = StubHTTPProtocol()
        # By default state is QUIESCENT
        self.assertEqual(protocol.state, "QUIESCENT")

        logObserver = EventLoggingObserver.createWithCleanup(self, globalLogPublisher)

        protocol.state = "NOTQUIESCENT"
        self.pool._putConnection(("http", b"example.com", 80), protocol)
        self.assertEqual(1, len(logObserver))

        event = logObserver[0]
        f = event["log_failure"]

        self.assertIsInstance(f.value, RuntimeError)
        self.assertEqual(
            f.getErrorMessage(), "BUG: Non-quiescent protocol added to connection pool."
        )
        self.assertIdentical(
            None, self.pool._connections.get(("http", b"example.com", 80))
        )
        self.flushLoggedErrors(RuntimeError)

    def test_getUsesQuiescentCallback(self):
        """
        When L{HTTPConnectionPool.getConnection} connects, it returns a
        C{Deferred} that fires with an instance of L{HTTP11ClientProtocol}
        that has the correct quiescent callback attached. When this callback
        is called the protocol is returned to the cache correctly, using the
        right key.
        """

        class StringEndpoint:
            def connect(self, factory):
                p = factory.buildProtocol(None)
                p.makeConnection(StringTransport())
                return succeed(p)

        pool = HTTPConnectionPool(self.fakeReactor, True)
        pool.retryAutomatically = False
        result = []
        key = "a key"
        pool.getConnection(key, StringEndpoint()).addCallback(result.append)
        protocol = result[0]
        self.assertIsInstance(protocol, HTTP11ClientProtocol)

        # Now that we have a protocol instance, let's try to put it back in the
        # pool:
        protocol._state = "QUIESCENT"
        protocol._quiescentCallback(protocol)

        # If we try to retrieve a connection to the same destination again, we
        # should get the same protocol, because it should've been added back
        # to the pool:
        result2 = []
        pool.getConnection(key, StringEndpoint()).addCallback(result2.append)
        self.assertIdentical(result2[0], protocol)

    def test_closeCachedConnections(self):
        """
        L{HTTPConnectionPool.closeCachedConnections} closes all cached
        connections and removes them from the cache. It returns a Deferred
        that fires when they have all lost their connections.
        """
        persistent = []

        def addProtocol(scheme, host, port):
            p = HTTP11ClientProtocol()
            p.makeConnection(StringTransport())
            self.pool._putConnection((scheme, host, port), p)
            persistent.append(p)

        addProtocol("http", b"example.com", 80)
        addProtocol("http", b"www2.example.com", 80)
        doneDeferred = self.pool.closeCachedConnections()

        # Connections have begun disconnecting:
        for p in persistent:
            self.assertEqual(p.transport.disconnecting, True)
        self.assertEqual(self.pool._connections, {})
        # All timeouts were cancelled and removed:
        for dc in self.fakeReactor.getDelayedCalls():
            self.assertEqual(dc.cancelled, True)
        self.assertEqual(self.pool._timeouts, {})

        # Returned Deferred fires when all connections have been closed:
        result = []
        doneDeferred.addCallback(result.append)
        self.assertEqual(result, [])
        persistent[0].connectionLost(Failure(ConnectionDone()))
        self.assertEqual(result, [])
        persistent[1].connectionLost(Failure(ConnectionDone()))
        self.assertEqual(result, [None])

    def test_cancelGetConnectionCancelsEndpointConnect(self):
        """
        Cancelling the C{Deferred} returned from
        L{HTTPConnectionPool.getConnection} cancels the C{Deferred} returned
        by opening a new connection with the given endpoint.
        """
        self.assertEqual(self.pool._connections, {})
        connectionResult = Deferred()

        class Endpoint:
            def connect(self, factory):
                return connectionResult

        d = self.pool.getConnection(12345, Endpoint())
        d.cancel()
        self.assertEqual(self.failureResultOf(connectionResult).type, CancelledError)
|
|
|
|
|
|
class AgentTestsMixin:
    """
    Tests applicable to any L{IAgent} implementation.

    Classes mixing this in supply C{makeAgent}, which builds the agent under
    test.
    """

    def test_interface(self):
        """
        Whatever C{makeAgent} returns provides L{IAgent}.
        """
        candidate = self.makeAgent()
        self.assertTrue(verifyObject(IAgent, candidate))
|
|
|
|
|
|
class IntegrationTestingMixin:
    """
    Transport-to-Agent integration tests for both HTTP and HTTPS.

    Classes mixing this in must also provide C{createReactor} and the usual
    test case methods (C{patch}, C{assertEqual}, C{successResultOf}, ...).
    """

    def test_integrationTestIPv4(self):
        """
        L{Agent} works over IPv4.
        """
        self.integrationTest(b"example.com", EXAMPLE_COM_IP, IPv4Address)

    def test_integrationTestIPv4Address(self):
        """
        L{Agent} works over IPv4 when hostname is an IPv4 address.
        """
        self.integrationTest(b"127.0.0.7", "127.0.0.7", IPv4Address)

    def test_integrationTestIPv6(self):
        """
        L{Agent} works over IPv6.
        """
        self.integrationTest(b"ipv6.example.com", EXAMPLE_COM_V6_IP, IPv6Address)

    def test_integrationTestIPv6Address(self):
        """
        L{Agent} works over IPv6 when hostname is an IPv6 address.
        """
        self.integrationTest(b"[::7]", "::7", IPv6Address)

    def integrationTest(
        self,
        hostName,
        expectedAddress,
        addressType,
        serverWrapper=lambda server, _: server,
        createAgent=client.Agent,
        scheme=b"http",
    ):
        """
        L{Agent} will make a TCP connection, send an HTTP request, and return a
        L{Deferred} that fires when the response has been received.

        @param hostName: The hostname to interpolate into the URL to be
            requested.
        @type hostName: L{bytes}

        @param expectedAddress: The address the agent is expected to connect
            to, as recorded by the fake reactor.
        @type expectedAddress: L{str}

        @param addressType: The class to construct an address out of.
        @type addressType: L{type}

        @param serverWrapper: A callable that takes a protocol factory and a
            ``Clock`` and returns a protocol factory; used to wrap the server /
            responder side in a TLS server.
        @type serverWrapper:
            serverWrapper(L{twisted.internet.interfaces.IProtocolFactory},
            L{Clock}) -> L{twisted.internet.interfaces.IProtocolFactory}

        @param createAgent: A callable that takes a reactor and produces an
            L{IAgent}; used to construct an agent with an appropriate trust
            root for TLS.
        @type createAgent: createAgent(reactor) -> L{IAgent}

        @param scheme: The scheme to test, C{http} or C{https}
        @type scheme: L{bytes}
        """
        reactor = self.createReactor()
        if sslPresent:
            # We have no way to tell the client to use our test reactor so we
            # have to patch it.
            self.patch(tls, "_get_default_clock", lambda: reactor)
        agent = createAgent(reactor)
        deferred = agent.request(b"GET", scheme + b"://" + hostName + b"/")
        host, port, factory, timeout, bind = reactor.tcpClients[0]
        self.assertEqual(host, expectedAddress)
        peerAddress = addressType("TCP", host, port)
        clientProtocol = factory.buildProtocol(peerAddress)
        clientTransport = FakeTransport(clientProtocol, False, peerAddress=peerAddress)
        clientProtocol.makeConnection(clientTransport)

        @Factory.forProtocol
        def accumulator():
            ap = AccumulatingProtocol()
            accumulator.currentProtocol = ap
            return ap

        accumulator.currentProtocol = None
        # NOTE(review): presumably consulted by AccumulatingProtocol when its
        # connection is made; None means nobody is waiting for it — confirm.
        accumulator.protocolConnectionMade = None
        wrapper = serverWrapper(accumulator, reactor).buildProtocol(None)
        serverTransport = FakeTransport(wrapper, True)
        wrapper.makeConnection(serverTransport)
        pump = IOPump(
            clientProtocol,
            wrapper,
            clientTransport,
            serverTransport,
            False,
            clock=reactor,
        )
        pump.flush()
        self.assertNoResult(deferred)
        lines = accumulator.currentProtocol.data.split(b"\r\n")
        self.assertTrue(lines[0].startswith(b"GET / HTTP"), lines[0])
        headers = dict(line.split(b": ", 1) for line in lines[1:] if line)
        self.assertEqual(headers[b"Host"], hostName)
        self.assertNoResult(deferred)
        # A well-formed response: the header section is terminated by exactly
        # one empty line, so Content-length is parsed as a header and delimits
        # the 12-byte body.
        accumulator.currentProtocol.transport.write(
            b"HTTP/1.1 200 OK\r\n"
            b"X-An-Header: an-value\r\n"
            b"Content-length: 12\r\n"
            b"\r\n"
            b"hello world!"
        )
        pump.flush()
        response = self.successResultOf(deferred)
        self.assertEqual(
            response.headers.getRawHeaders(b"x-an-header")[0], b"an-value"
        )
|
|
|
|
|
|
@implementer(IAgentEndpointFactory)
class StubEndpointFactory:
    """
    An L{IAgentEndpointFactory} for tests which, rather than returning a real
    endpoint, echoes back the parts of the URI it was asked about.
    """

    def endpointForURI(self, uri):
        """
        Report the connection-relevant parts of C{uri}.

        @param uri: A L{URI}.

        @return: The C{(scheme, host, port)} of C{uri}.  This deliberately
            violates the interface, which expects an endpoint, but it makes
            assertions in tests simple.
        @rtype: L{tuple}
        """
        scheme = uri.scheme
        host = uri.host
        port = uri.port
        return (scheme, host, port)
|
|
|
|
|
|
class AgentTests(
|
|
TestCase, FakeReactorAndConnectMixin, AgentTestsMixin, IntegrationTestingMixin
|
|
):
|
|
"""
|
|
Tests for the new HTTP client API provided by L{Agent}.
|
|
"""
|
|
|
|
def makeAgent(self):
    """
    Build the agent under test.

    @return: a new L{twisted.web.client.Agent} instance using C{self.reactor}.
    """
    reactor = self.reactor
    return client.Agent(reactor)
|
|
|
|
def setUp(self):
|
|
"""
|
|
Create an L{Agent} wrapped around a fake reactor.
|
|
"""
|
|
self.reactor = self.createReactor()
|
|
self.agent = self.makeAgent()
|
|
|
|
def test_defaultPool(self):
|
|
"""
|
|
If no pool is passed in, the L{Agent} creates a non-persistent pool.
|
|
"""
|
|
agent = client.Agent(self.reactor)
|
|
self.assertIsInstance(agent._pool, HTTPConnectionPool)
|
|
self.assertEqual(agent._pool.persistent, False)
|
|
self.assertIdentical(agent._reactor, agent._pool._reactor)
|
|
|
|
def test_persistent(self):
|
|
"""
|
|
If C{persistent} is set to C{True} on the L{HTTPConnectionPool} (the
|
|
default), C{Request}s are created with their C{persistent} flag set to
|
|
C{True}.
|
|
"""
|
|
pool = HTTPConnectionPool(self.reactor)
|
|
agent = client.Agent(self.reactor, pool=pool)
|
|
agent._getEndpoint = lambda *args: self
|
|
agent.request(b"GET", b"http://127.0.0.1")
|
|
self.assertEqual(self.protocol.requests[0][0].persistent, True)
|
|
|
|
def test_nonPersistent(self):
|
|
"""
|
|
If C{persistent} is set to C{False} when creating the
|
|
L{HTTPConnectionPool}, C{Request}s are created with their
|
|
C{persistent} flag set to C{False}.
|
|
|
|
Elsewhere in the tests for the underlying HTTP code we ensure that
|
|
this will result in the disconnection of the HTTP protocol once the
|
|
request is done, so that the connection will not be returned to the
|
|
pool.
|
|
"""
|
|
pool = HTTPConnectionPool(self.reactor, persistent=False)
|
|
agent = client.Agent(self.reactor, pool=pool)
|
|
agent._getEndpoint = lambda *args: self
|
|
agent.request(b"GET", b"http://127.0.0.1")
|
|
self.assertEqual(self.protocol.requests[0][0].persistent, False)
|
|
|
|
def test_connectUsesConnectionPool(self):
|
|
"""
|
|
When a connection is made by the Agent, it uses its pool's
|
|
C{getConnection} method to do so, with the endpoint returned by
|
|
C{self._getEndpoint}. The key used is C{(scheme, host, port)}.
|
|
"""
|
|
endpoint = DummyEndpoint()
|
|
|
|
class MyAgent(client.Agent):
|
|
def _getEndpoint(this, uri):
|
|
self.assertEqual(
|
|
(uri.scheme, uri.host, uri.port), (b"http", b"foo", 80)
|
|
)
|
|
return endpoint
|
|
|
|
class DummyPool:
|
|
connected = False
|
|
persistent = False
|
|
|
|
def getConnection(this, key, ep):
|
|
this.connected = True
|
|
self.assertEqual(ep, endpoint)
|
|
# This is the key the default Agent uses, others will have
|
|
# different keys:
|
|
self.assertEqual(key, (b"http", b"foo", 80))
|
|
return defer.succeed(StubHTTPProtocol())
|
|
|
|
pool = DummyPool()
|
|
agent = MyAgent(self.reactor, pool=pool)
|
|
self.assertIdentical(pool, agent._pool)
|
|
|
|
headers = http_headers.Headers()
|
|
headers.addRawHeader(b"host", b"foo")
|
|
bodyProducer = object()
|
|
agent.request(
|
|
b"GET", b"http://foo/", bodyProducer=bodyProducer, headers=headers
|
|
)
|
|
self.assertEqual(agent._pool.connected, True)
|
|
|
|
def test_nonBytesMethod(self):
|
|
"""
|
|
L{Agent.request} raises L{TypeError} when the C{method} argument isn't
|
|
L{bytes}.
|
|
"""
|
|
self.assertRaises(TypeError, self.agent.request, "GET", b"http://foo.example/")
|
|
|
|
def test_unsupportedScheme(self):
|
|
"""
|
|
L{Agent.request} returns a L{Deferred} which fails with
|
|
L{SchemeNotSupported} if the scheme of the URI passed to it is not
|
|
C{'http'}.
|
|
"""
|
|
return self.assertFailure(
|
|
self.agent.request(b"GET", b"mailto:alice@example.com"), SchemeNotSupported
|
|
)
|
|
|
|
def test_connectionFailed(self):
|
|
"""
|
|
The L{Deferred} returned by L{Agent.request} fires with a L{Failure} if
|
|
the TCP connection attempt fails.
|
|
"""
|
|
result = self.agent.request(b"GET", b"http://foo/")
|
|
# Cause the connection to be refused
|
|
host, port, factory = self.reactor.tcpClients.pop()[:3]
|
|
factory.clientConnectionFailed(None, Failure(ConnectionRefusedError()))
|
|
self.reactor.advance(10)
|
|
# ^ https://twistedmatrix.com/trac/ticket/8202
|
|
self.failureResultOf(result, ConnectionRefusedError)
|
|
|
|
def test_connectHTTP(self):
|
|
"""
|
|
L{Agent._getEndpoint} return a C{HostnameEndpoint} when passed a scheme
|
|
of C{'http'}.
|
|
"""
|
|
expectedHost = b"example.com"
|
|
expectedPort = 1234
|
|
endpoint = self.agent._getEndpoint(
|
|
URI.fromBytes(b"http://%b:%d" % (expectedHost, expectedPort))
|
|
)
|
|
self.assertEqual(endpoint._hostStr, "example.com")
|
|
self.assertEqual(endpoint._port, expectedPort)
|
|
self.assertIsInstance(endpoint, HostnameEndpoint)
|
|
|
|
def test_nonDecodableURI(self):
|
|
"""
|
|
L{Agent._getEndpoint} when given a non-ASCII decodable URI will raise a
|
|
L{ValueError} saying such.
|
|
"""
|
|
uri = URI.fromBytes(b"http://example.com:80")
|
|
uri.host = "\u2603.com".encode()
|
|
|
|
with self.assertRaises(ValueError) as e:
|
|
self.agent._getEndpoint(uri)
|
|
|
|
self.assertEqual(
|
|
e.exception.args[0],
|
|
(
|
|
"The host of the provided URI ({reprout}) contains "
|
|
"non-ASCII octets, it should be ASCII "
|
|
"decodable."
|
|
).format(reprout=repr(uri.host)),
|
|
)
|
|
|
|
def test_hostProvided(self):
|
|
"""
|
|
If L{None} is passed to L{Agent.request} for the C{headers} parameter,
|
|
a L{Headers} instance is created for the request and a I{Host} header
|
|
added to it.
|
|
"""
|
|
self.agent._getEndpoint = lambda *args: self
|
|
self.agent.request(b"GET", b"http://example.com/foo?bar")
|
|
|
|
req, res = self.protocol.requests.pop()
|
|
self.assertEqual(req.headers.getRawHeaders(b"host"), [b"example.com"])
|
|
|
|
def test_hostIPv6Bracketed(self):
|
|
"""
|
|
If an IPv6 address is used in the C{uri} passed to L{Agent.request},
|
|
the computed I{Host} header needs to be bracketed.
|
|
"""
|
|
self.agent._getEndpoint = lambda *args: self
|
|
self.agent.request(b"GET", b"http://[::1]/")
|
|
|
|
req, res = self.protocol.requests.pop()
|
|
self.assertEqual(req.headers.getRawHeaders(b"host"), [b"[::1]"])
|
|
|
|
def test_hostOverride(self):
|
|
"""
|
|
If the headers passed to L{Agent.request} includes a value for the
|
|
I{Host} header, that value takes precedence over the one which would
|
|
otherwise be automatically provided.
|
|
"""
|
|
headers = http_headers.Headers({b"foo": [b"bar"], b"host": [b"quux"]})
|
|
self.agent._getEndpoint = lambda *args: self
|
|
self.agent.request(b"GET", b"http://example.com/foo?bar", headers)
|
|
|
|
req, res = self.protocol.requests.pop()
|
|
self.assertEqual(req.headers.getRawHeaders(b"host"), [b"quux"])
|
|
|
|
def test_headersUnmodified(self):
|
|
"""
|
|
If a I{Host} header must be added to the request, the L{Headers}
|
|
instance passed to L{Agent.request} is not modified.
|
|
"""
|
|
headers = http_headers.Headers()
|
|
self.agent._getEndpoint = lambda *args: self
|
|
self.agent.request(b"GET", b"http://example.com/foo", headers)
|
|
|
|
protocol = self.protocol
|
|
|
|
# The request should have been issued.
|
|
self.assertEqual(len(protocol.requests), 1)
|
|
# And the headers object passed in should not have changed.
|
|
self.assertEqual(headers, http_headers.Headers())
|
|
|
|
def test_hostValueStandardHTTP(self):
|
|
"""
|
|
When passed a scheme of C{'http'} and a port of C{80},
|
|
L{Agent._computeHostValue} returns a string giving just
|
|
the host name passed to it.
|
|
"""
|
|
self.assertEqual(
|
|
self.agent._computeHostValue(b"http", b"example.com", 80), b"example.com"
|
|
)
|
|
|
|
def test_hostValueNonStandardHTTP(self):
|
|
"""
|
|
When passed a scheme of C{'http'} and a port other than C{80},
|
|
L{Agent._computeHostValue} returns a string giving the
|
|
host passed to it joined together with the port number by C{":"}.
|
|
"""
|
|
self.assertEqual(
|
|
self.agent._computeHostValue(b"http", b"example.com", 54321),
|
|
b"example.com:54321",
|
|
)
|
|
|
|
def test_hostValueStandardHTTPS(self):
|
|
"""
|
|
When passed a scheme of C{'https'} and a port of C{443},
|
|
L{Agent._computeHostValue} returns a string giving just
|
|
the host name passed to it.
|
|
"""
|
|
self.assertEqual(
|
|
self.agent._computeHostValue(b"https", b"example.com", 443), b"example.com"
|
|
)
|
|
|
|
def test_hostValueNonStandardHTTPS(self):
|
|
"""
|
|
When passed a scheme of C{'https'} and a port other than C{443},
|
|
L{Agent._computeHostValue} returns a string giving the
|
|
host passed to it joined together with the port number by C{":"}.
|
|
"""
|
|
self.assertEqual(
|
|
self.agent._computeHostValue(b"https", b"example.com", 54321),
|
|
b"example.com:54321",
|
|
)
|
|
|
|
def test_request(self):
|
|
"""
|
|
L{Agent.request} establishes a new connection to the host indicated by
|
|
the host part of the URI passed to it and issues a request using the
|
|
method, the path portion of the URI, the headers, and the body producer
|
|
passed to it. It returns a L{Deferred} which fires with an
|
|
L{IResponse} from the server.
|
|
"""
|
|
self.agent._getEndpoint = lambda *args: self
|
|
|
|
headers = http_headers.Headers({b"foo": [b"bar"]})
|
|
# Just going to check the body for identity, so it doesn't need to be
|
|
# real.
|
|
body = object()
|
|
self.agent.request(b"GET", b"http://example.com:1234/foo?bar", headers, body)
|
|
|
|
protocol = self.protocol
|
|
|
|
# The request should be issued.
|
|
self.assertEqual(len(protocol.requests), 1)
|
|
req, res = protocol.requests.pop()
|
|
self.assertIsInstance(req, Request)
|
|
self.assertEqual(req.method, b"GET")
|
|
self.assertEqual(req.uri, b"/foo?bar")
|
|
self.assertEqual(
|
|
req.headers,
|
|
http_headers.Headers({b"foo": [b"bar"], b"host": [b"example.com:1234"]}),
|
|
)
|
|
self.assertIdentical(req.bodyProducer, body)
|
|
|
|
def test_connectTimeout(self):
|
|
"""
|
|
L{Agent} takes a C{connectTimeout} argument which is forwarded to the
|
|
following C{connectTCP} agent.
|
|
"""
|
|
agent = client.Agent(self.reactor, connectTimeout=5)
|
|
agent.request(b"GET", b"http://foo/")
|
|
timeout = self.reactor.tcpClients.pop()[3]
|
|
self.assertEqual(5, timeout)
|
|
|
|
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
|
|
def test_connectTimeoutHTTPS(self):
|
|
"""
|
|
L{Agent} takes a C{connectTimeout} argument which is forwarded to the
|
|
following C{connectTCP} call.
|
|
"""
|
|
agent = client.Agent(self.reactor, connectTimeout=5)
|
|
agent.request(b"GET", b"https://foo/")
|
|
timeout = self.reactor.tcpClients.pop()[3]
|
|
self.assertEqual(5, timeout)
|
|
|
|
def test_bindAddress(self):
|
|
"""
|
|
L{Agent} takes a C{bindAddress} argument which is forwarded to the
|
|
following C{connectTCP} call.
|
|
"""
|
|
agent = client.Agent(self.reactor, bindAddress="192.168.0.1")
|
|
agent.request(b"GET", b"http://foo/")
|
|
address = self.reactor.tcpClients.pop()[4]
|
|
self.assertEqual("192.168.0.1", address)
|
|
|
|
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
|
|
def test_bindAddressSSL(self):
|
|
"""
|
|
L{Agent} takes a C{bindAddress} argument which is forwarded to the
|
|
following C{connectSSL} call.
|
|
"""
|
|
agent = client.Agent(self.reactor, bindAddress="192.168.0.1")
|
|
agent.request(b"GET", b"https://foo/")
|
|
address = self.reactor.tcpClients.pop()[4]
|
|
self.assertEqual("192.168.0.1", address)
|
|
|
|
def test_responseIncludesRequest(self):
|
|
"""
|
|
L{Response}s returned by L{Agent.request} have a reference to the
|
|
L{Request} that was originally issued.
|
|
"""
|
|
uri = b"http://example.com/"
|
|
agent = self.buildAgentForWrapperTest(self.reactor)
|
|
d = agent.request(b"GET", uri)
|
|
|
|
# The request should be issued.
|
|
self.assertEqual(len(self.protocol.requests), 1)
|
|
req, res = self.protocol.requests.pop()
|
|
self.assertIsInstance(req, Request)
|
|
|
|
resp = client.Response._construct(
|
|
(b"HTTP", 1, 1), 200, b"OK", Headers({}), None, req
|
|
)
|
|
res.callback(resp)
|
|
|
|
response = self.successResultOf(d)
|
|
self.assertEqual(
|
|
(
|
|
response.request.method,
|
|
response.request.absoluteURI,
|
|
response.request.headers,
|
|
),
|
|
(req.method, req.absoluteURI, req.headers),
|
|
)
|
|
|
|
def test_requestAbsoluteURI(self):
|
|
"""
|
|
L{Request.absoluteURI} is the absolute URI of the request.
|
|
"""
|
|
uri = b"http://example.com/foo;1234?bar#frag"
|
|
agent = self.buildAgentForWrapperTest(self.reactor)
|
|
agent.request(b"GET", uri)
|
|
|
|
# The request should be issued.
|
|
self.assertEqual(len(self.protocol.requests), 1)
|
|
req, res = self.protocol.requests.pop()
|
|
self.assertIsInstance(req, Request)
|
|
self.assertEqual(req.absoluteURI, uri)
|
|
|
|
def test_requestMissingAbsoluteURI(self):
|
|
"""
|
|
L{Request.absoluteURI} is L{None} if L{Request._parsedURI} is L{None}.
|
|
"""
|
|
request = client.Request(b"FOO", b"/", Headers(), None)
|
|
self.assertIdentical(request.absoluteURI, None)
|
|
|
|
def test_endpointFactory(self):
|
|
"""
|
|
L{Agent.usingEndpointFactory} creates an L{Agent} that uses the given
|
|
factory to create endpoints.
|
|
"""
|
|
factory = StubEndpointFactory()
|
|
agent = client.Agent.usingEndpointFactory(None, endpointFactory=factory)
|
|
uri = URI.fromBytes(b"http://example.com/")
|
|
returnedEndpoint = agent._getEndpoint(uri)
|
|
self.assertEqual(returnedEndpoint, (b"http", b"example.com", 80))
|
|
|
|
def test_endpointFactoryDefaultPool(self):
|
|
"""
|
|
If no pool is passed in to L{Agent.usingEndpointFactory}, a default
|
|
pool is constructed with no persistent connections.
|
|
"""
|
|
agent = client.Agent.usingEndpointFactory(self.reactor, StubEndpointFactory())
|
|
pool = agent._pool
|
|
self.assertEqual(
|
|
(pool.__class__, pool.persistent, pool._reactor),
|
|
(HTTPConnectionPool, False, agent._reactor),
|
|
)
|
|
|
|
def test_endpointFactoryPool(self):
|
|
"""
|
|
If a pool is passed in to L{Agent.usingEndpointFactory} it is used as
|
|
the L{Agent} pool.
|
|
"""
|
|
pool = object()
|
|
agent = client.Agent.usingEndpointFactory(
|
|
self.reactor, StubEndpointFactory(), pool
|
|
)
|
|
self.assertIs(pool, agent._pool)
|
|
|
|
|
|
class AgentMethodInjectionTests(
    FakeReactorAndConnectMixin,
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Check that L{client.Agent} rejects HTTP methods crafted to inject extra
    content into the request.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Issue a request through a fresh L{client.Agent} using C{method}.

        @param method: see L{MethodInjectionTestsMixin}
        """
        target = b"http://twisted.invalid"
        freshAgent = client.Agent(self.createReactor())
        freshAgent.request(method, target, Headers(), None)
|
|
|
|
|
|
class AgentURIInjectionTests(
    FakeReactorAndConnectMixin,
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Agent} against URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Attempt a C{GET} request with the provided URI.

        @param uri: see L{URIInjectionTestsMixin}
        """
        agent = client.Agent(self.createReactor())
        method = b"GET"
        agent.request(method, uri, Headers(), None)
|
|
|
|
|
|
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
class AgentHTTPSTests(TestCase, FakeReactorAndConnectMixin, IntegrationTestingMixin):
    """
    Tests for the new HTTP client API that depends on SSL.
    """

    def makeEndpoint(self, host=b"example.com", port=443):
        """
        Create an L{Agent} with an https scheme and return its endpoint
        created according to the arguments.

        @param host: The host for the endpoint.
        @type host: L{bytes}

        @param port: The port for the endpoint.
        @type port: L{int}

        @return: An endpoint of an L{Agent} constructed according to args: a
            C{_WrapperEndpoint} adding TLS on top of a L{HostnameEndpoint}.
        @rtype: L{twisted.internet.endpoints._WrapperEndpoint}
        """
        return client.Agent(self.createReactor())._getEndpoint(
            URI.fromBytes(b"https://%b:%d/" % (host, port))
        )

    def test_endpointType(self):
        """
        L{Agent._getEndpoint} returns a C{_WrapperEndpoint} wrapping a
        L{HostnameEndpoint} when passed a scheme of C{'https'}.
        """
        from twisted.internet.endpoints import _WrapperEndpoint

        endpoint = self.makeEndpoint()
        self.assertIsInstance(endpoint, _WrapperEndpoint)
        self.assertIsInstance(endpoint._wrappedEndpoint, HostnameEndpoint)

    def test_hostArgumentIsRespected(self):
        """
        If a host is passed, the endpoint respects it.
        """
        endpoint = self.makeEndpoint(host=b"example.com")
        self.assertEqual(endpoint._wrappedEndpoint._hostStr, "example.com")

    def test_portArgumentIsRespected(self):
        """
        If a port is passed, the endpoint respects it.
        """
        expectedPort = 4321
        endpoint = self.makeEndpoint(port=expectedPort)
        self.assertEqual(endpoint._wrappedEndpoint._port, expectedPort)

    def test_contextFactoryType(self):
        """
        L{Agent} wraps its connection creator creator and uses modern TLS APIs.
        """
        endpoint = self.makeEndpoint()
        contextFactory = endpoint._wrapperFactory(None)._connectionCreator
        self.assertIsInstance(contextFactory, ClientTLSOptions)
        self.assertEqual(contextFactory._hostname, "example.com")

    def test_connectHTTPSCustomConnectionCreator(self):
        """
        If a custom L{WebClientConnectionCreator}-like object is passed to
        L{Agent.__init__} it will be used to determine the SSL parameters for
        HTTPS requests. When an HTTPS request is made, the hostname and port
        number of the request URL will be passed to the connection creator's
        C{creatorForNetloc} method. The resulting context object will be used
        to establish the SSL connection.
        """
        expectedHost = b"example.org"
        expectedPort = 20443

        class JustEnoughConnection:
            handshakeStarted = False
            connectState = False

            def do_handshake(self):
                """
                The handshake started. Record that fact.
                """
                self.handshakeStarted = True

            def set_connect_state(self):
                """
                The connection started. Record that fact.
                """
                self.connectState = True

        contextArgs = []

        @implementer(IOpenSSLClientConnectionCreator)
        class JustEnoughCreator:
            def __init__(self, hostname, port):
                self.hostname = hostname
                self.port = port

            def clientConnectionForTLS(self, tlsProtocol):
                """
                Implement L{IOpenSSLClientConnectionCreator}.

                @param tlsProtocol: The TLS protocol.
                @type tlsProtocol: L{TLSMemoryBIOProtocol}

                @return: C{expectedConnection}
                """
                contextArgs.append((tlsProtocol, self.hostname, self.port))
                return expectedConnection

        expectedConnection = JustEnoughConnection()

        @implementer(IPolicyForHTTPS)
        class StubBrowserLikePolicyForHTTPS:
            def creatorForNetloc(self, hostname, port):
                """
                Emulate L{BrowserLikePolicyForHTTPS}.

                @param hostname: The hostname to verify.
                @type hostname: L{bytes}

                @param port: The port number.
                @type port: L{int}

                @return: a stub L{IOpenSSLClientConnectionCreator}
                @rtype: L{JustEnoughCreator}
                """
                return JustEnoughCreator(hostname, port)

        expectedCreatorCreator = StubBrowserLikePolicyForHTTPS()
        reactor = self.createReactor()
        agent = client.Agent(reactor, expectedCreatorCreator)
        endpoint = agent._getEndpoint(
            URI.fromBytes(b"https://%b:%d" % (expectedHost, expectedPort))
        )
        endpoint.connect(Factory.forProtocol(Protocol))
        tlsFactory = reactor.tcpClients[-1][2]
        tlsProtocol = tlsFactory.buildProtocol(None)
        tlsProtocol.makeConnection(StringTransport())
        tls = contextArgs[0][0]
        self.assertIsInstance(tls, TLSMemoryBIOProtocol)
        self.assertEqual(contextArgs[0][1:], (expectedHost, expectedPort))
        self.assertTrue(expectedConnection.handshakeStarted)
        self.assertTrue(expectedConnection.connectState)

    def test_deprecatedDuckPolicy(self):
        """
        Passing something that duck-types I{like} a L{web client context
        factory <twisted.web.client.WebClientContextFactory>} - something that
        does not provide L{IPolicyForHTTPS} - to L{Agent} emits a
        L{DeprecationWarning} even if you don't actually C{import
        WebClientContextFactory} to do it.
        """

        def warnMe():
            client.Agent(
                deterministicResolvingReactor(MemoryReactorClock()),
                "does-not-provide-IPolicyForHTTPS",
            )

        warnMe()
        warnings = self.flushWarnings([warnMe])
        self.assertEqual(len(warnings), 1)
        [warning] = warnings
        self.assertEqual(warning["category"], DeprecationWarning)
        self.assertEqual(
            warning["message"],
            "'does-not-provide-IPolicyForHTTPS' was passed as the HTTPS "
            "policy for an Agent, but it does not provide IPolicyForHTTPS. "
            "Since Twisted 14.0, you must pass a provider of IPolicyForHTTPS.",
        )

    def test_alternateTrustRoot(self):
        """
        L{BrowserLikePolicyForHTTPS.creatorForNetloc} returns an
        L{IOpenSSLClientConnectionCreator} provider which will add certificates
        from the given trust root.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        policy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        creator = policy.creatorForNetloc(b"thingy", 4321)
        self.assertTrue(trustRoot.called)
        connection = creator.clientConnectionForTLS(None)
        self.assertIs(trustRoot.context, connection.get_context())

    def integrationTest(self, hostName, expectedAddress, addressType):
        """
        Wrap L{IntegrationTestingMixin.integrationTest} with TLS.

        A certificate authority and a server certificate are generated for
        C{hostName} (with any IPv6 brackets removed); the server side is
        wrapped in a L{TLSMemoryBIOFactory} and the client L{Agent} is given a
        policy that trusts only that authority.

        @param hostName: The host name (or bracketed IPv6 literal) to request.
        @type hostName: L{bytes}

        @param expectedAddress: Passed through to the wrapped test.

        @param addressType: Passed through to the wrapped test.
        """
        # Certificates are issued for the bare host, without IPv6 brackets.
        certHostName = hostName.strip(b"[]")
        authority, server = certificatesForAuthorityAndServer(
            certHostName.decode("ascii")
        )

        def tlsify(serverFactory, reactor):
            return TLSMemoryBIOFactory(server.options(), False, serverFactory, reactor)

        def tlsagent(reactor):
            @implementer(IPolicyForHTTPS)
            class Policy:
                def creatorForNetloc(self, hostname, port):
                    return optionsForClientTLS(
                        hostname.decode("ascii"), trustRoot=authority
                    )

            return client.Agent(reactor, contextFactory=Policy())

        super().integrationTest(
            hostName,
            expectedAddress,
            addressType,
            serverWrapper=tlsify,
            createAgent=tlsagent,
            scheme=b"https",
        )
|
|
|
|
|
|
class WebClientContextFactoryTests(TestCase):
    """
    Tests for L{twisted.web.client.WebClientContextFactory}, the deprecated
    context factory wrapper used by web clients.
    """

    def setUp(self):
        """
        Import L{WebClientContextFactory}, collecting the deprecation warning
        that the import emits so individual tests can inspect it.
        """
        from twisted.web.client import WebClientContextFactory

        self.warned = self.flushWarnings([WebClientContextFactoryTests.setUp])
        self.webClientContextFactory = WebClientContextFactory

    def test_deprecated(self):
        """
        Importing L{twisted.web.client.WebClientContextFactory} emits exactly
        one L{DeprecationWarning} pointing at L{BrowserLikePolicyForHTTPS}.
        """
        # See https://twistedmatrix.com/trac/ticket/7242
        expectedMessage = getDeprecationWarningString(
            self.webClientContextFactory,
            Version("Twisted", 14, 0, 0),
            replacement=BrowserLikePolicyForHTTPS,
        ).replace(";", ":")

        self.assertEqual(len(self.warned), 1)
        (onlyWarning,) = self.warned
        self.assertEqual(onlyWarning["category"], DeprecationWarning)
        self.assertEqual(onlyWarning["message"], expectedMessage)

    @skipIf(sslPresent, "SSL Present.")
    def test_missingSSL(self):
        """
        Without SSL support, calling C{getContext} raises
        L{NotImplementedError}.
        """
        factory = self.webClientContextFactory()
        self.assertRaises(
            NotImplementedError,
            factory.getContext,
            b"example.com",
            443,
        )

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_returnsContext(self):
        """
        With SSL support, C{getContext} produces an L{OpenSSL.SSL.Context}.
        """
        factory = self.webClientContextFactory()
        context = factory.getContext("example.com", 443)
        self.assertIsInstance(context, ssl.SSL.Context)

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_setsTrustRootOnContextToDefaultTrustRoot(self):
        """
        The L{CertificateOptions} built by the factory use the platform's
        default trust roots.
        """
        factory = self.webClientContextFactory()
        options = factory._getCertificateOptions("example.com", 443)
        self.assertIsInstance(options.trustRoot, ssl.OpenSSLDefaultPaths)
|
|
|
|
|
|
class HTTPConnectionPoolRetryTests(TestCase, FakeReactorAndConnectMixin):
|
|
"""
|
|
L{client.HTTPConnectionPool}, by using
|
|
L{client._RetryingHTTP11ClientProtocol}, supports retrying requests done
|
|
against previously cached connections.
|
|
"""
|
|
|
|
def test_onlyRetryIdempotentMethods(self):
|
|
"""
|
|
Only GET, HEAD, OPTIONS, TRACE, DELETE methods cause a retry.
|
|
"""
|
|
pool = client.HTTPConnectionPool(None)
|
|
connection = client._RetryingHTTP11ClientProtocol(None, pool)
|
|
self.assertTrue(connection._shouldRetry(b"GET", RequestNotSent(), None))
|
|
self.assertTrue(connection._shouldRetry(b"HEAD", RequestNotSent(), None))
|
|
self.assertTrue(connection._shouldRetry(b"OPTIONS", RequestNotSent(), None))
|
|
self.assertTrue(connection._shouldRetry(b"TRACE", RequestNotSent(), None))
|
|
self.assertTrue(connection._shouldRetry(b"DELETE", RequestNotSent(), None))
|
|
self.assertFalse(connection._shouldRetry(b"POST", RequestNotSent(), None))
|
|
self.assertFalse(connection._shouldRetry(b"MYMETHOD", RequestNotSent(), None))
|
|
# This will be covered by a different ticket, since we need support
|
|
# for resettable body producers:
|
|
# self.assertTrue(connection._doRetry("PUT", RequestNotSent(), None))
|
|
|
|
def test_onlyRetryIfNoResponseReceived(self):
|
|
"""
|
|
Only L{RequestNotSent}, L{RequestTransmissionFailed} and
|
|
L{ResponseNeverReceived} exceptions cause a retry.
|
|
"""
|
|
pool = client.HTTPConnectionPool(None)
|
|
connection = client._RetryingHTTP11ClientProtocol(None, pool)
|
|
self.assertTrue(connection._shouldRetry(b"GET", RequestNotSent(), None))
|
|
self.assertTrue(
|
|
connection._shouldRetry(b"GET", RequestTransmissionFailed([]), None)
|
|
)
|
|
self.assertTrue(
|
|
connection._shouldRetry(b"GET", ResponseNeverReceived([]), None)
|
|
)
|
|
self.assertFalse(connection._shouldRetry(b"GET", ResponseFailed([]), None))
|
|
self.assertFalse(
|
|
connection._shouldRetry(b"GET", ConnectionRefusedError(), None)
|
|
)
|
|
|
|
def test_dontRetryIfFailedDueToCancel(self):
|
|
"""
|
|
If a request failed due to the operation being cancelled,
|
|
C{_shouldRetry} returns C{False} to indicate the request should not be
|
|
retried.
|
|
"""
|
|
pool = client.HTTPConnectionPool(None)
|
|
connection = client._RetryingHTTP11ClientProtocol(None, pool)
|
|
exception = ResponseNeverReceived([Failure(defer.CancelledError())])
|
|
self.assertFalse(connection._shouldRetry(b"GET", exception, None))
|
|
|
|
def test_retryIfFailedDueToNonCancelException(self):
|
|
"""
|
|
If a request failed with L{ResponseNeverReceived} due to some
|
|
arbitrary exception, C{_shouldRetry} returns C{True} to indicate the
|
|
request should be retried.
|
|
"""
|
|
pool = client.HTTPConnectionPool(None)
|
|
connection = client._RetryingHTTP11ClientProtocol(None, pool)
|
|
self.assertTrue(
|
|
connection._shouldRetry(
|
|
b"GET", ResponseNeverReceived([Failure(Exception())]), None
|
|
)
|
|
)
|
|
|
|
def test_wrappedOnPersistentReturned(self):
|
|
"""
|
|
If L{client.HTTPConnectionPool.getConnection} returns a previously
|
|
cached connection, it will get wrapped in a
|
|
L{client._RetryingHTTP11ClientProtocol}.
|
|
"""
|
|
pool = client.HTTPConnectionPool(Clock())
|
|
|
|
# Add a connection to the cache:
|
|
protocol = StubHTTPProtocol()
|
|
protocol.makeConnection(StringTransport())
|
|
pool._putConnection(123, protocol)
|
|
|
|
# Retrieve it, it should come back wrapped in a
|
|
# _RetryingHTTP11ClientProtocol:
|
|
d = pool.getConnection(123, DummyEndpoint())
|
|
|
|
def gotConnection(connection):
|
|
self.assertIsInstance(connection, client._RetryingHTTP11ClientProtocol)
|
|
self.assertIdentical(connection._clientProtocol, protocol)
|
|
|
|
return d.addCallback(gotConnection)
|
|
|
|
def test_notWrappedOnNewReturned(self):
|
|
"""
|
|
If L{client.HTTPConnectionPool.getConnection} returns a new
|
|
connection, it will be returned as is.
|
|
"""
|
|
pool = client.HTTPConnectionPool(None)
|
|
d = pool.getConnection(123, DummyEndpoint())
|
|
|
|
def gotConnection(connection):
|
|
# Don't want to use isinstance since potentially the wrapper might
|
|
# subclass it at some point:
|
|
self.assertIdentical(connection.__class__, HTTP11ClientProtocol)
|
|
|
|
return d.addCallback(gotConnection)
|
|
|
|
def retryAttempt(self, willWeRetry):
|
|
"""
|
|
Fail a first request, possibly retrying depending on argument.
|
|
"""
|
|
protocols = []
|
|
|
|
def newProtocol():
|
|
protocol = StubHTTPProtocol()
|
|
protocols.append(protocol)
|
|
return defer.succeed(protocol)
|
|
|
|
bodyProducer = object()
|
|
request = client.Request(b"FOO", b"/", Headers(), bodyProducer, persistent=True)
|
|
newProtocol()
|
|
protocol = protocols[0]
|
|
retrier = client._RetryingHTTP11ClientProtocol(protocol, newProtocol)
|
|
|
|
def _shouldRetry(m, e, bp):
|
|
self.assertEqual(m, b"FOO")
|
|
self.assertIdentical(bp, bodyProducer)
|
|
self.assertIsInstance(e, (RequestNotSent, ResponseNeverReceived))
|
|
return willWeRetry
|
|
|
|
retrier._shouldRetry = _shouldRetry
|
|
|
|
d = retrier.request(request)
|
|
|
|
# So far, one request made:
|
|
self.assertEqual(len(protocols), 1)
|
|
self.assertEqual(len(protocols[0].requests), 1)
|
|
|
|
# Fail the first request:
|
|
protocol.requests[0][1].errback(RequestNotSent())
|
|
return d, protocols
|
|
|
|
def test_retryIfShouldRetryReturnsTrue(self):
|
|
"""
|
|
L{client._RetryingHTTP11ClientProtocol} retries when
|
|
L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{True}.
|
|
"""
|
|
d, protocols = self.retryAttempt(True)
|
|
# We retried!
|
|
self.assertEqual(len(protocols), 2)
|
|
response = object()
|
|
protocols[1].requests[0][1].callback(response)
|
|
return d.addCallback(self.assertIdentical, response)
|
|
|
|
def test_dontRetryIfShouldRetryReturnsFalse(self):
|
|
"""
|
|
L{client._RetryingHTTP11ClientProtocol} does not retry when
|
|
L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{False}.
|
|
"""
|
|
d, protocols = self.retryAttempt(False)
|
|
# We did not retry:
|
|
self.assertEqual(len(protocols), 1)
|
|
return self.assertFailure(d, RequestNotSent)
|
|
|
|
def test_onlyRetryWithoutBody(self):
|
|
"""
|
|
L{_RetryingHTTP11ClientProtocol} only retries queries that don't have
|
|
a body.
|
|
|
|
This is an implementation restriction; if the restriction is fixed,
|
|
this test should be removed and PUT added to list of methods that
|
|
support retries.
|
|
"""
|
|
pool = client.HTTPConnectionPool(None)
|
|
connection = client._RetryingHTTP11ClientProtocol(None, pool)
|
|
self.assertTrue(connection._shouldRetry(b"GET", RequestNotSent(), None))
|
|
self.assertFalse(connection._shouldRetry(b"GET", RequestNotSent(), object()))
|
|
|
|
def test_onlyRetryOnce(self):
|
|
"""
|
|
If a L{client._RetryingHTTP11ClientProtocol} fails more than once on
|
|
an idempotent query before a response is received, it will not retry.
|
|
"""
|
|
d, protocols = self.retryAttempt(True)
|
|
self.assertEqual(len(protocols), 2)
|
|
# Fail the second request too:
|
|
protocols[1].requests[0][1].errback(ResponseNeverReceived([]))
|
|
# We didn't retry again:
|
|
self.assertEqual(len(protocols), 2)
|
|
return self.assertFailure(d, ResponseNeverReceived)
|
|
|
|
def test_dontRetryIfRetryAutomaticallyFalse(self):
|
|
"""
|
|
If L{HTTPConnectionPool.retryAutomatically} is set to C{False}, don't
|
|
wrap connections with retrying logic.
|
|
"""
|
|
pool = client.HTTPConnectionPool(Clock())
|
|
pool.retryAutomatically = False
|
|
|
|
# Add a connection to the cache:
|
|
protocol = StubHTTPProtocol()
|
|
protocol.makeConnection(StringTransport())
|
|
pool._putConnection(123, protocol)
|
|
|
|
# Retrieve it, it should come back unwrapped:
|
|
d = pool.getConnection(123, DummyEndpoint())
|
|
|
|
def gotConnection(connection):
|
|
self.assertIdentical(connection, protocol)
|
|
|
|
return d.addCallback(gotConnection)
|
|
|
|
def test_retryWithNewConnection(self):
    """
    L{client.HTTPConnectionPool} creates
    L{client._RetryingHTTP11ClientProtocol} with a new connection factory
    method that creates a new connection using the same key and endpoint
    as the wrapped connection.
    """
    pool = client.HTTPConnectionPool(Clock())
    key = 123
    endpoint = DummyEndpoint()
    newConnections = []

    # Override the pool's _newConnection so that calls are recorded:
    def newConnection(k, e):
        newConnections.append((k, e))

    pool._newConnection = newConnection

    # Add a connection to the cache:
    protocol = StubHTTPProtocol()
    protocol.makeConnection(StringTransport())
    pool._putConnection(key, protocol)

    # Retrieve it; it should come back wrapped in a
    # _RetryingHTTP11ClientProtocol:
    d = pool.getConnection(key, endpoint)

    def gotConnection(connection):
        self.assertIsInstance(connection, client._RetryingHTTP11ClientProtocol)
        self.assertIdentical(connection._clientProtocol, protocol)
        # Verify that the _newConnection method on the retrying connection
        # calls _newConnection on the pool with the same key and endpoint:
        self.assertEqual(newConnections, [])
        connection._newConnection()
        self.assertEqual(len(newConnections), 1)
        self.assertEqual(newConnections[0][0], key)
        self.assertIdentical(newConnections[0][1], endpoint)

    return d.addCallback(gotConnection)
|
|
|
|
|
|
class CookieTestsMixin:
    """
    Shared helpers for unit tests that exercise cookie handling.
    """

    def addCookies(
        self, cookieJar: CookieJar, uri: bytes, cookies: list[bytes]
    ) -> tuple[client._FakeStdlibRequest, client._FakeStdlibResponse]:
        """
        Feed C{Set-Cookie} header values into a cookie jar, as though they
        arrived in a 200 response to a request for C{uri}.

        @param cookieJar: The jar that receives the extracted cookies.
        @param uri: The URI of the simulated request.
        @param cookies: Raw C{Set-Cookie} header values.

        @return: The stdlib-compatible request and response wrappers that
            were passed to L{CookieJar.extract_cookies}.
        """
        setCookieHeaders = Headers({b"Set-Cookie": cookies})
        twistedResponse = client.Response(
            (b"HTTP", 1, 1), 200, b"OK", setCookieHeaders, None
        )
        fakeResponse = client._FakeStdlibResponse(twistedResponse)
        fakeRequest = client._FakeStdlibRequest(uri)
        cookieJar.extract_cookies(fakeResponse, fakeRequest)
        return fakeRequest, fakeResponse
|
|
|
|
|
|
class CookieJarTests(TestCase, CookieTestsMixin):
    """
    Check that L{CookieJar} cooperates with the stdlib adapters
    L{twisted.web.client._FakeStdlibRequest} and
    L{twisted.web.client._FakeStdlibResponse}.
    """

    def makeCookieJar(
        self,
    ) -> tuple[CookieJar, tuple[client._FakeStdlibRequest, client._FakeStdlibResponse]]:
        """
        Build a jar pre-populated with two sample cookies.

        @return: A 2-tuple of the L{CookieJar} and the (request, response)
            pair that was used to populate it.
        """
        jar = CookieJar()
        requestAndResponse = self.addCookies(
            jar,
            b"http://example.com:1234/foo?bar",
            [b"foo=1; cow=moo; Path=/foo; Comment=hello", b"bar=2; Comment=goodbye"],
        )
        return jar, requestAndResponse

    def test_extractCookies(self) -> None:
        """
        L{CookieJar.extract_cookies} reads cookie information through our
        stdlib-compatibility wrappers, L{client._FakeStdlibRequest} and
        L{client._FakeStdlibResponse}.
        """
        jar, _ = self.makeCookieJar()
        byName = {each.name: each for each in jar}

        foo = byName["foo"]
        self.assertEqual(
            (foo.version, foo.name, foo.value, foo.path, foo.comment),
            (0, "foo", "1", "/foo", "hello"),
        )
        self.assertEqual(foo.get_nonstandard_attr("cow"), "moo")

        bar = byName["bar"]
        self.assertEqual(
            (bar.version, bar.name, bar.value, bar.path, bar.comment),
            (0, "bar", "2", "/", "goodbye"),
        )
        self.assertIsNone(bar.get_nonstandard_attr("cow"))

    def test_sendCookie(self) -> None:
        """
        L{CookieJar.add_cookie_header} writes a C{Cookie} header onto a
        Twisted request through the L{client._FakeStdlibRequest} wrapper.
        """
        jar, (request, _response) = self.makeCookieJar()
        self.assertIsNone(request.get_header("Cookie", None))

        jar.add_cookie_header(request)

        rawHeaders = list(request._twistedHeaders.getAllRawHeaders())
        self.assertEqual(rawHeaders, [(b"Cookie", [b"foo=1; bar=2"])])
|
|
|
|
|
|
class CookieAgentTests(
    TestCase, CookieTestsMixin, FakeReactorAndConnectMixin, AgentTestsMixin
):
    """
    Tests for L{twisted.web.client.CookieAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.CookieAgent}
        """
        return client.CookieAgent(
            self.buildAgentForWrapperTest(self.reactor), CookieJar()
        )

    def setUp(self):
        self.reactor = self.createReactor()

    def test_emptyCookieJarRequest(self):
        """
        L{CookieAgent.request} does not insert any C{'Cookie'} header into the
        L{Request} object if there is no cookie in the cookie jar for the URI
        being requested. Cookies are extracted from the response and stored in
        the cookie jar.
        """
        cookieJar = CookieJar()
        self.assertEqual(list(cookieJar), [])

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        d = cookieAgent.request(b"GET", b"http://example.com:1234/foo?bar")

        def _checkCookie(ignored):
            cookies = list(cookieJar)
            self.assertEqual(len(cookies), 1)
            self.assertEqual(cookies[0].name, "foo")
            self.assertEqual(cookies[0].value, "1")

        d.addCallback(_checkCookie)

        req, res = self.protocol.requests.pop()
        self.assertIdentical(req.headers.getRawHeaders(b"cookie"), None)

        resp = client.Response(
            (b"HTTP", 1, 1),
            200,
            b"OK",
            Headers({b"Set-Cookie": [b"foo=1"]}),
            None,
        )
        res.callback(resp)

        return d

    def test_leaveExistingCookieHeader(self) -> None:
        """
        L{CookieAgent.request} will not insert a C{'Cookie'} header into the
        L{Request} object when there is already a C{'Cookie'} header in the
        request headers parameter.
        """
        uri = b"http://example.com:1234/foo?bar"
        cookie = b"foo=1"

        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request(b"GET", uri, Headers({"cookie": ["already-set"]}))

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [b"already-set"])

    def test_requestWithCookie(self):
        """
        L{CookieAgent.request} inserts a C{'Cookie'} header into the L{Request}
        object when there is a cookie matching the request URI in the cookie
        jar.
        """
        uri = b"http://example.com:1234/foo?bar"
        cookie = b"foo=1"

        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request(b"GET", uri)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [cookie])

    @skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
    def test_secureCookie(self):
        """
        L{CookieAgent} is able to handle secure cookies, i.e. cookies which
        should only be sent over HTTPS.
        """
        uri = b"https://example.com:1234/foo?bar"
        cookie = b"foo=1;secure"

        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request(b"GET", uri)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [b"foo=1"])

    def test_secureCookieOnInsecureConnection(self):
        """
        If a cookie is set up as secure, it is not sent with a request that
        is not made over HTTPS.
        """
        uri = b"http://example.com/foo?bar"
        cookie = b"foo=1;secure"

        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request(b"GET", uri)

        req, res = self.protocol.requests.pop()
        self.assertIdentical(None, req.headers.getRawHeaders(b"cookie"))

    def test_portCookie(self):
        """
        L{CookieAgent} supports cookies which restrict the port number they
        may be sent to.
        """
        uri = b"http://example.com:1234/foo?bar"
        cookie = b"foo=1;port=1234"

        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request(b"GET", uri)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders(b"cookie"), [b"foo=1"])

    def test_portCookieOnWrongPort(self):
        """
        When a cookie has a port directive, it is not added to the
        L{CookieJar} if the URI is on a different port.
        """
        uri = b"http://example.com:4567/foo?bar"
        cookie = b"foo=1;port=1234"

        cookieJar = CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 0)
|
|
|
|
|
|
class Decoder1(proxyForInterface(IResponse)):  # type: ignore[misc]
    """
    Stand-in content decoder for the L{client.ContentDecoderAgent} tests.
    Being a L{proxyForInterface} of L{IResponse}, it forwards every
    L{IResponse} attribute to the response it wraps.
    """
|
|
|
|
|
|
class Decoder2(Decoder1):
    """
    A second, distinct stand-in decoder for the L{client.ContentDecoderAgent}
    tests, so that assertions can tell which decoder was applied.
    """
|
|
|
|
|
|
class ContentDecoderAgentTests(TestCase, FakeReactorAndConnectMixin, AgentTestsMixin):
    """
    Tests for L{client.ContentDecoderAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.ContentDecoderAgent}
        """
        return client.ContentDecoderAgent(self.agent, [])

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.createReactor()
        self.agent = self.buildAgentForWrapperTest(self.reactor)

    def test_acceptHeaders(self):
        """
        L{client.ContentDecoderAgent} sets the I{Accept-Encoding} header to the
        names of the available decoder objects.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
        )

        agent.request(b"GET", b"http://example.com/foo")

        protocol = self.protocol

        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        self.assertEqual(
            req.headers.getRawHeaders(b"accept-encoding"), [b"decoder1,decoder2"]
        )

    def test_existingHeaders(self):
        """
        If there are existing I{Accept-Encoding} fields,
        L{client.ContentDecoderAgent} creates a new field for the decoders it
        knows about.
        """
        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"accept-encoding": [b"fizz"]}
        )
        agent = client.ContentDecoderAgent(
            self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
        )
        agent.request(b"GET", b"http://example.com/foo", headers=headers)

        protocol = self.protocol

        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        # sorted() already returns a list; wrapping it in list() is redundant.
        self.assertEqual(
            sorted(req.headers.getAllRawHeaders()),
            [
                (b"Accept-Encoding", [b"fizz", b"decoder1,decoder2"]),
                (b"Foo", [b"bar"]),
                (b"Host", [b"example.com"]),
            ],
        )

    def test_plainEncodingResponse(self):
        """
        If the response is not encoded despite the request I{Accept-Encoding}
        headers, L{client.ContentDecoderAgent} simply forwards the response.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
        )
        deferred = agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        response = Response((b"HTTP", 1, 1), 200, b"OK", http_headers.Headers(), None)
        res.callback(response)

        return deferred.addCallback(self.assertIdentical, response)

    def test_unsupportedEncoding(self):
        """
        If an encoding unknown to the L{client.ContentDecoderAgent} is found,
        the response is unchanged.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
        )
        deferred = agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"fizz"]}
        )
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        return deferred.addCallback(self.assertIdentical, response)

    def test_unknownEncoding(self):
        """
        When L{client.ContentDecoderAgent} encounters a decoder it doesn't know
        about, it stops decoding even if another encoding is known afterwards.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [(b"decoder1", Decoder1), (b"decoder2", Decoder2)]
        )
        deferred = agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"decoder1,fizz,decoder2"]}
        )
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
        res.callback(response)

        def check(result):
            self.assertNotIdentical(response, result)
            self.assertIsInstance(result, Decoder2)
            self.assertEqual(
                [b"decoder1,fizz"], result.headers.getRawHeaders(b"content-encoding")
            )

        return deferred.addCallback(check)
|
|
|
|
|
|
class SimpleAgentProtocol(Protocol):
    """
    A body-receiving L{Protocol} for use with L{client.Agent} responses.

    @ivar made: L{Deferred} fired with C{None} from C{connectionMade}.

    @ivar finished: L{Deferred} fired with C{None} from C{connectionLost}.

    @ivar received: C{list} of every chunk passed to C{dataReceived}, in
        arrival order.
    """

    def __init__(self):
        self.made, self.finished = Deferred(), Deferred()
        self.received = []

    def connectionMade(self):
        self.made.callback(None)

    def connectionLost(self, reason):
        self.finished.callback(None)

    def dataReceived(self, data):
        self.received.append(data)
|
|
|
|
|
|
class ContentDecoderAgentWithGzipTests(TestCase, FakeReactorAndConnectMixin):
    """
    Tests for L{client.ContentDecoderAgent} configured with
    L{client.GzipDecoder}.
    """

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.createReactor()
        agent = self.buildAgentForWrapperTest(self.reactor)
        self.agent = client.ContentDecoderAgent(agent, [(b"gzip", client.GzipDecoder)])

    def test_gzipEncodingResponse(self):
        """
        If the response has a C{gzip} I{Content-Encoding} header,
        L{GzipDecoder} wraps the response to return uncompressed data to the
        user.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"gzip"]}
        )
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        response.length = 12
        res.callback(response)

        # 16 + MAX_WBITS selects the gzip container format.
        compressor = zlib.compressobj(2, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
        data = (
            compressor.compress(b"x" * 6)
            + compressor.compress(b"y" * 4)
            + compressor.flush()
        )

        def checkResponse(result):
            self.assertNotIdentical(result, response)
            self.assertEqual(result.version, (b"HTTP", 1, 1))
            self.assertEqual(result.code, 200)
            self.assertEqual(result.phrase, b"OK")
            self.assertEqual(
                list(result.headers.getAllRawHeaders()), [(b"Foo", [b"bar"])]
            )
            self.assertEqual(result.length, UNKNOWN_LENGTH)
            self.assertRaises(AttributeError, getattr, result, "unknown")

            response._bodyDataReceived(data[:5])
            response._bodyDataReceived(data[5:])
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, [b"x" * 6 + b"y" * 4])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred

    def test_brokenContent(self):
        """
        If the data received by the L{GzipDecoder} isn't valid gzip-compressed
        data, the call to C{deliverBody} fails with a C{zlib.error}.
        """
        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {b"foo": [b"bar"], b"content-encoding": [b"gzip"]}
        )
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        response.length = 12
        res.callback(response)

        data = b"not gzipped content"

        def checkResponse(result):
            response._bodyDataReceived(data)

            result.deliverBody(Protocol())

        deferred.addCallback(checkResponse)
        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[0].trap(zlib.error)
            self.assertIsInstance(error.response, Response)

        return deferred.addCallback(checkFailure)

    def test_flushData(self):
        """
        When the connection with the server is lost, the gzip protocol calls
        C{flush} on the zlib decompressor object to get uncompressed data which
        may have been buffered.
        """

        class decompressobj:
            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return b"x"

            def flush(self):
                return b"y"

        # self.patch restores the original attribute during test cleanup.
        self.patch(zlib, "decompressobj", decompressobj)

        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"content-encoding": [b"gzip"]})
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        res.callback(response)

        def checkResponse(result):
            response._bodyDataReceived(b"data")
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, [b"x", b"y"])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred

    def test_flushError(self):
        """
        If the C{flush} call in C{connectionLost} fails, the C{zlib.error}
        exception is caught and turned into a L{ResponseFailed}. Data that was
        decompressed before the failure is still delivered.
        """

        class decompressobj:
            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return b"x"

            def flush(self):
                raise zlib.error()

        # self.patch restores the original attribute during test cleanup.
        self.patch(zlib, "decompressobj", decompressobj)

        deferred = self.agent.request(b"GET", b"http://example.com/foo")

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({b"content-encoding": [b"gzip"]})
        transport = StringTransport()
        response = Response((b"HTTP", 1, 1), 200, b"OK", headers, transport)
        res.callback(response)

        # Created up front so checkFailure can inspect what was delivered.
        protocol = SimpleAgentProtocol()

        def checkResponse(result):
            response._bodyDataReceived(b"data")
            response._bodyDataFinished()

            # The failing flush() makes deliverBody raise ResponseFailed, which
            # errbacks the Deferred; nothing after this call would run.
            result.deliverBody(protocol)

        deferred.addCallback(checkResponse)

        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[1].trap(zlib.error)
            self.assertIsInstance(error.response, Response)
            # The chunk decompressed before flush() failed was passed on, and
            # nothing from the failed flush was.
            self.assertEqual(protocol.received, [b"x"])

        return deferred.addCallback(checkFailure)
|
|
|
|
|
|
class ProxyAgentTests(TestCase, FakeReactorAndConnectMixin, AgentTestsMixin):
    """
    Tests for L{client.ProxyAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.ProxyAgent} whose proxy endpoint
            is TCP C{127.0.0.1:1234} on the test reactor.
        """
        proxyEndpoint = TCP4ClientEndpoint(self.reactor, "127.0.0.1", 1234)
        return client.ProxyAgent(proxyEndpoint, self.reactor)

    def setUp(self):
        self.reactor = self.createReactor()
        proxyEndpoint = TCP4ClientEndpoint(self.reactor, "bar", 5678)
        self.agent = client.ProxyAgent(proxyEndpoint, self.reactor)
        # Wrap the real proxy endpoint with this test case's StubEndpoint.
        self.agent._proxyEndpoint = self.StubEndpoint(self.agent._proxyEndpoint, self)

    def test_nonBytesMethod(self):
        """
        A C{method} that is not L{bytes} makes L{ProxyAgent.request} raise
        L{TypeError}.
        """
        self.assertRaises(TypeError, self.agent.request, "GET", b"http://foo.example/")

    def test_proxyRequest(self):
        """
        Calling C{request} on L{client.ProxyAgent} sends the request to the
        proxy, using the full target URI as the request path.
        """
        headers = http_headers.Headers({b"foo": [b"bar"]})
        # Only the body's identity is checked, so a bare object suffices.
        body = object()
        targetURI = b"http://example.com:1234/foo?bar"
        self.agent.request(b"GET", targetURI, headers, body)

        host, port, factory = self.reactor.tcpClients.pop()[:3]
        self.assertEqual((host, port), ("bar", 5678))
        self.assertIsInstance(factory._wrappedFactory, client._HTTP11ClientFactory)

        # Exactly one request must have been issued.
        pending = self.protocol.requests
        self.assertEqual(len(pending), 1)
        sent, _ = pending.pop()
        self.assertIsInstance(sent, Request)
        self.assertEqual(sent.method, b"GET")
        self.assertEqual(sent.uri, targetURI)
        expectedHeaders = http_headers.Headers(
            {b"foo": [b"bar"], b"host": [b"example.com:1234"]}
        )
        self.assertEqual(sent.headers, expectedHeaders)
        self.assertIdentical(sent.bodyProducer, body)

    def test_nonPersistent(self):
        """
        By default, the connections made by C{ProxyAgent} are not persistent.
        """
        self.assertEqual(self.agent._pool.persistent, False)

    def test_connectUsesConnectionPool(self):
        """
        C{ProxyAgent} obtains its connections from its pool's
        C{getConnection}. It passes the endpoint it was constructed with, and
        a key of C{("http-proxy", endpoint)}.
        """
        endpoint = DummyEndpoint()
        testCase = self

        class DummyPool:
            connected = False
            persistent = False

            def getConnection(self, key, ep):
                self.connected = True
                testCase.assertIdentical(ep, endpoint)
                # The key depends only on the proxy's address (where *we*
                # connect), not on the final destination of the request.
                testCase.assertEqual(key, ("http-proxy", endpoint))
                return defer.succeed(StubHTTPProtocol())

        pool = DummyPool()
        agent = client.ProxyAgent(endpoint, self.reactor, pool=pool)
        self.assertIdentical(pool, agent._pool)

        agent.request(b"GET", b"http://foo/")
        self.assertEqual(agent._pool.connected, True)
|
|
|
|
|
|
# Lowercased names of HTTP headers that carry authentication or cookie data.
SENSITIVE_HEADERS = [
    b"authorization", b"cookie", b"cookie2", b"proxy-authorization", b"www-authenticate",
]
|
|
|
|
|
|
class _RedirectAgentTestsMixin(testMixinClass):
|
|
"""
|
|
Test cases mixin for L{RedirectAgentTests} and
|
|
L{BrowserLikeRedirectAgentTests}.
|
|
"""
|
|
|
|
agent: IAgent
|
|
reactor: MemoryReactorClock
|
|
protocol: StubHTTPProtocol
|
|
|
|
def test_noRedirect(self):
    """
    A response that is not a redirect passes through L{client.RedirectAgent}
    untouched, just as it would with a plain L{client.Agent}.
    """
    deferred = self.agent.request(b"GET", b"http://example.com/foo")
    _, pending = self.protocol.requests.pop()

    plain = Response((b"HTTP", 1, 1), 200, b"OK", http_headers.Headers(), None)
    pending.callback(plain)

    # No follow-up request was issued.
    self.assertEqual(0, len(self.protocol.requests))
    outcome = self.successResultOf(deferred)
    self.assertIdentical(plain, outcome)
    self.assertIdentical(outcome.previousResponse, None)
|
|
|
|
def _testRedirectDefault(
    self,
    code: int,
    crossScheme: bool = False,
    crossDomain: bool = False,
    crossPort: bool = False,
    requestHeaders: Optional[Headers] = None,
) -> Request:
    """
    When getting a redirect, L{client.RedirectAgent} follows the URL
    specified in the I{Location} header field and makes a new request.

    The initial request is a I{GET} for C{example.com/foo}. It uses HTTPS
    when TLS support is available and HTTP otherwise.

    @param code: HTTP status code of the redirect response.

    @param crossScheme: If C{True}, redirect to the other scheme (HTTP to
        HTTPS or vice versa) and its default port.

    @param crossDomain: If C{True}, redirect to C{example.net} instead of
        C{example.com}.

    @param crossPort: If C{True}, redirect to port 8443, written explicitly
        in the I{Location} header.

    @param requestHeaders: Passed as C{headers} to the initial request.

    @return: The L{Request} issued while following the redirect.

    @raise SkipTest: If C{crossScheme} is requested but TLS support is
        unavailable.
    """
    startDomain = b"example.com"
    startScheme = b"https" if ssl is not None else b"http"
    startPort = 80 if startScheme == b"http" else 443
    self.agent.request(
        b"GET", startScheme + b"://" + startDomain + b"/foo", headers=requestHeaders
    )

    host, port = self.reactor.tcpClients.pop()[:2]
    self.assertEqual(EXAMPLE_COM_IP, host)
    self.assertEqual(startPort, port)

    req, res = self.protocol.requests.pop()

    # By default, redirect to the same scheme, domain and port. A
    # cross-scheme redirect (crossScheme=True) needs TLS support; without
    # it the test is skipped.
    targetScheme = startScheme
    targetDomain = startDomain
    targetPort = startPort

    if crossScheme:
        if ssl is None:
            raise SkipTest(
                "Cross-scheme redirects can't be tested without TLS support."
            )
        targetScheme = b"https" if startScheme == b"http" else b"http"
        targetPort = 443 if startPort == 80 else 80

    portSyntax = b""
    if crossPort:
        targetPort = 8443
        portSyntax = b":8443"
    targetDomain = b"example.net" if crossDomain else startDomain
    locationValue = targetScheme + b"://" + targetDomain + portSyntax + b"/bar"
    headers = http_headers.Headers({b"location": [locationValue]})
    response = Response((b"HTTP", 1, 1), code, b"OK", headers, None)
    res.callback(response)

    req2, res2 = self.protocol.requests.pop()
    self.assertEqual(b"GET", req2.method)
    self.assertEqual(b"/bar", req2.uri)

    host, port = self.reactor.tcpClients.pop()[:2]
    self.assertEqual(EXAMPLE_NET_IP if crossDomain else EXAMPLE_COM_IP, host)
    self.assertEqual(targetPort, port)
    return req2
|
|
|
|
def test_redirect301(self):
    """
    A 301 response makes L{client.RedirectAgent} follow the redirect.
    """
    self._testRedirectDefault(code=301)
|
|
|
|
def test_redirect301Scheme(self):
    """
    L{client.RedirectAgent} follows a redirect that switches scheme.
    """
    self._testRedirectDefault(301, crossScheme=True)
|
|
|
|
def test_redirect302(self):
    """
    A 302 response makes L{client.RedirectAgent} follow the redirect.
    """
    self._testRedirectDefault(code=302)
|
|
|
|
def test_redirect307(self):
    """
    A 307 response makes L{client.RedirectAgent} follow the redirect.
    """
    self._testRedirectDefault(code=307)
|
|
|
|
def test_redirect308(self):
    """
    A 308 response makes L{client.RedirectAgent} follow the redirect.
    """
    self._testRedirectDefault(code=308)
|
|
|
|
def _sensitiveHeadersTest(
    self, expectedHostHeader: bytes = b"example.com", **crossKwargs: bool
) -> None:
    """
    L{client.RedirectAgent} keeps every request header on a same-origin
    redirect, but scrubs sensitive headers when redirecting between
    differing origins.

    @param expectedHostHeader: The I{Host} header value expected on the
        request that follows the cross-origin redirect.

    @param crossKwargs: Keyword arguments (C{crossScheme}, C{crossDomain}
        and/or C{crossPort}) forwarded to C{_testRedirectDefault} to choose
        how the redirect target's origin differs.
    """
    # NOTE(review): x-custom-sensitive is expected to be scrubbed as well.
    # This presumably relies on the agent under test being configured with
    # it as an extra sensitive header name; confirm against setUp.
    sensitiveHeaderValues = {
        b"authorization": [b"sensitive-authnz"],
        b"cookie": [b"sensitive-cookie-data"],
        b"cookie2": [b"sensitive-cookie2-data"],
        b"proxy-authorization": [b"sensitive-proxy-auth"],
        b"wWw-auThentiCate": [b"sensitive-authn"],
        b"x-custom-sensitive": [b"sensitive-custom"],
    }
    otherHeaderValues = {b"x-random-header": [b"x-random-value"]}
    allHeaders = Headers({**sensitiveHeaderValues, **otherHeaderValues})
    redirected = self._testRedirectDefault(301, requestHeaders=allHeaders)

    def normHeaders(headers: Headers) -> Dict[bytes, Sequence[bytes]]:
        # Lowercase names so comparisons ignore header-name casing.
        return {k.lower(): v for (k, v) in headers.getAllRawHeaders()}

    # Same origin: every header survives.
    sameOriginHeaders = normHeaders(redirected.headers)
    self.assertEqual(
        sameOriginHeaders,
        {
            b"host": [b"example.com"],
            **normHeaders(allHeaders),
        },
    )

    # Different origin: only the non-sensitive headers survive.
    redirectedElsewhere = self._testRedirectDefault(
        301,
        **crossKwargs,
        requestHeaders=Headers({**sensitiveHeaderValues, **otherHeaderValues}),
    )
    otherOriginHeaders = normHeaders(redirectedElsewhere.headers)
    self.assertEqual(
        otherOriginHeaders,
        {
            b"host": [expectedHostHeader],
            **normHeaders(Headers(otherHeaderValues)),
        },
    )
|
|
|
|
def test_crossDomainHeaders(self) -> None:
    """
    Sensitive headers are dropped by L{client.RedirectAgent} when a redirect
    points at a different domain.
    """
    self._sensitiveHeadersTest(expectedHostHeader=b"example.net", crossDomain=True)
|
|
|
|
def test_crossPortHeaders(self) -> None:
    """
    Sensitive headers are dropped by L{client.RedirectAgent} when a redirect
    points at a different port.
    """
    self._sensitiveHeadersTest(expectedHostHeader=b"example.com:8443", crossPort=True)
|
|
|
|
def test_crossSchemeHeaders(self) -> None:
    """
    Sensitive headers are dropped by L{client.RedirectAgent} when a redirect
    switches scheme.
    """
    self._sensitiveHeadersTest(crossScheme=True)
|
|
|
|
def _testRedirectToGet(self, code, method):
    """
    Assert that a C{code} redirect in reply to a C{method} request makes
    L{client.RedirectAgent} re-issue the request as I{GET}.

    @param code: HTTP status code of the redirect response.

    @param method: HTTP method of the original request.
    """
    self.agent.request(method, b"http://example.com/foo")
    _, pending = self.protocol.requests.pop()

    location = http_headers.Headers({b"location": [b"http://example.com/bar"]})
    pending.callback(Response((b"HTTP", 1, 1), code, b"OK", location, None))

    followUp, _ = self.protocol.requests.pop()
    self.assertEqual(b"GET", followUp.method)
    self.assertEqual(b"/bar", followUp.uri)
|
|
|
|
def test_redirect303(self):
    """
    A 303 response to a I{POST} makes L{client.RedirectAgent} switch to
    I{GET} for the follow-up request.
    """
    self._testRedirectToGet(code=303, method=b"POST")
|
|
|
|
def test_noLocationField(self):
    """
    A redirect without a I{Location} header makes L{client.RedirectAgent}
    fail with L{ResponseFailed}, wrapping L{error.RedirectWithNoLocation}.
    """
    deferred = self.agent.request(b"GET", b"http://example.com/foo")
    _, pending = self.protocol.requests.pop()

    noLocation = Response((b"HTTP", 1, 1), 301, b"OK", http_headers.Headers(), None)
    pending.callback(noLocation)

    failure = self.failureResultOf(deferred, client.ResponseFailed)
    reason = failure.value.reasons[0]
    reason.trap(error.RedirectWithNoLocation)
    self.assertEqual(b"http://example.com/foo", reason.value.uri)
    self.assertEqual(301, failure.value.response.code)
|
|
|
|
def _testPageRedirectFailure(self, code, method):
    """
    Assert that a C{code} redirect in reply to a C{method} request, which
    L{client.RedirectAgent} does not follow, fails with L{ResponseFailed}
    wrapping L{error.PageRedirect}.

    @param code: HTTP status code of the redirect response.

    @param method: HTTP method of the original request.
    """
    requestURI = b"http://example.com/foo"
    deferred = self.agent.request(method, requestURI)
    _, pending = self.protocol.requests.pop()

    redirect = Response((b"HTTP", 1, 1), code, b"OK", http_headers.Headers(), None)
    pending.callback(redirect)

    failure = self.failureResultOf(deferred, client.ResponseFailed)
    reason = failure.value.reasons[0]
    reason.trap(error.PageRedirect)
    self.assertEqual(requestURI, reason.value.location)
    self.assertEqual(code, failure.value.response.code)
|
|
|
|
def test_307OnPost(self):
    """
    A 307 redirect in reply to a I{POST} is not followed. Instead,
    L{client.RedirectAgent} fails with L{ResponseFailed} wrapping
    L{error.PageRedirect}.
    """
    self._testPageRedirectFailure(code=307, method=b"POST")
|
|
|
|
def test_redirectLimit(self):
    """
    If the limit of redirects specified to L{client.RedirectAgent} is
    reached, the deferred fires with a L{ResponseFailed} error wrapping an
    L{error.InfiniteRedirection} exception whose location is the URI of the
    original request.
    """
    agent = self.buildAgentForWrapperTest(self.reactor)
    # A limit of 1: the first redirect is followed, the second one fails.
    redirectAgent = client.RedirectAgent(agent, 1)

    deferred = redirectAgent.request(b"GET", b"http://example.com/foo")

    req, res = self.protocol.requests.pop()

    headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
    response = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
    res.callback(response)

    # The first redirect was followed, so a second request is pending.
    req2, res2 = self.protocol.requests.pop()

    response2 = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
    res2.callback(response2)

    fail = self.failureResultOf(deferred, client.ResponseFailed)

    fail.value.reasons[0].trap(error.InfiniteRedirection)
    # The reported location is the original request URI, not the
    # redirect target.
    self.assertEqual(
        b"http://example.com/foo", fail.value.reasons[0].value.location
    )
    self.assertEqual(302, fail.value.response.code)
|
|
|
|
def _testRedirectURI(self, uri, location, finalURI):
    """
    Issue a I{GET} for C{uri}, answer it with a 302 whose I{Location} is
    C{location}, and check that the follow-up request is a I{GET} for
    C{finalURI}, i.e. that a relative location was resolved against the
    request URI.

    @param uri: URI of the initial request.

    @param location: Value of the I{Location} header in the redirect.

    @param finalURI: The absolute URI the follow-up request must target.
    """
    self.agent.request(b"GET", uri)

    _, firstResult = self.protocol.requests.pop()
    firstResult.callback(
        Response(
            (b"HTTP", 1, 1),
            302,
            b"OK",
            http_headers.Headers({b"location": [location]}),
            None,
        )
    )

    followUp, _ = self.protocol.requests.pop()
    self.assertEqual(b"GET", followUp.method)
    self.assertEqual(finalURI, followUp.absoluteURI)
|
|
|
|
def test_relativeURI(self):
    """
    Relative I{URI}s found in redirects are resolved and followed by
    L{client.RedirectAgent}, and any query string is kept.
    """
    base = b"http://example.com/foo/bar"
    cases = [
        (b"baz", b"http://example.com/foo/baz"),
        (b"/baz", b"http://example.com/baz"),
        (b"/baz?a", b"http://example.com/baz?a"),
    ]
    for location, expected in cases:
        self._testRedirectURI(base, location, expected)
|
|
|
|
def test_relativeURIPreserveFragments(self):
    """
    When following relative I{URI}s in redirects, L{client.RedirectAgent}
    handles fragments as described by the HTTP 1.1 bis draft: the request
    fragment is inherited unless the redirect supplies its own.

    @see: U{https://tools.ietf.org/html/draft-ietf-httpbis-p2-semantics-22#section-7.1.2}
    """
    cases = [
        (
            b"http://example.com/foo/bar#frag",
            b"/baz?a",
            b"http://example.com/baz?a#frag",
        ),
        (
            b"http://example.com/foo/bar",
            b"/baz?a#frag2",
            b"http://example.com/baz?a#frag2",
        ),
    ]
    for requestURI, location, expected in cases:
        self._testRedirectURI(requestURI, location, expected)
|
|
|
|
def test_relativeURISchemeRelative(self):
    """
    Scheme-relative I{URI}s (C{//host/path}) in redirects are resolved and
    followed by L{client.RedirectAgent}, with the host and, when given, the
    port taken from the redirect.
    """
    base = b"http://example.com/foo/bar"
    cases = [
        (b"//foo.com/baz", b"http://foo.com/baz"),
        (b"//foo.com:81/baz", b"http://foo.com:81/baz"),
    ]
    for location, expected in cases:
        self._testRedirectURI(base, location, expected)
|
|
|
|
def test_responseHistory(self):
    """
    L{Response.previousResponse} references the previous L{Response} from
    a redirect, or L{None} if there was no previous response.
    """
    agent = self.buildAgentForWrapperTest(self.reactor)
    redirectAgent = client.RedirectAgent(agent)

    deferred = redirectAgent.request(b"GET", b"http://example.com/foo")

    redirectReq, redirectRes = self.protocol.requests.pop()

    headers = http_headers.Headers({b"location": [b"http://example.com/bar"]})
    redirectResponse = Response((b"HTTP", 1, 1), 302, b"OK", headers, None)
    redirectRes.callback(redirectResponse)

    # The redirect was followed, so the request to the new location is
    # now pending.
    req, res = self.protocol.requests.pop()

    response = Response((b"HTTP", 1, 1), 200, b"OK", headers, None)
    res.callback(response)

    finalResponse = self.successResultOf(deferred)
    # The final response links back to the redirect, which in turn has no
    # predecessor because it answered the very first request.
    self.assertIdentical(finalResponse.previousResponse, redirectResponse)
    self.assertIdentical(redirectResponse.previousResponse, None)
|
|
|
|
|
|
class RedirectAgentTests(
    FakeReactorAndConnectMixin,
    _RedirectAgentTestsMixin,
    AgentTestsMixin,
    runtimeTestCase,
):
    """
    Tests for L{client.RedirectAgent}.
    """

    def makeAgent(self):
        """
        Build the agent under test.

        @return: a new L{twisted.web.client.RedirectAgent} wrapping the
            agent produced by C{buildAgentForWrapperTest}.
        """
        wrapped = self.buildAgentForWrapperTest(self.reactor)
        return client.RedirectAgent(
            wrapped, sensitiveHeaderNames=[b"X-Custom-sensitive"]
        )

    def setUp(self):
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_301OnPost(self):
        """
        A 301 redirect received for a I{POST} request makes
        L{client.RedirectAgent} fail with a L{ResponseFailed} error that
        wraps an L{error.PageRedirect} exception.
        """
        self._testPageRedirectFailure(code=301, method=b"POST")

    def test_302OnPost(self):
        """
        A 302 redirect received for a I{POST} request makes
        L{client.RedirectAgent} fail with a L{ResponseFailed} error that
        wraps an L{error.PageRedirect} exception.
        """
        self._testPageRedirectFailure(code=302, method=b"POST")
|
|
|
|
|
|
class BrowserLikeRedirectAgentTests(
    FakeReactorAndConnectMixin,
    _RedirectAgentTestsMixin,
    AgentTestsMixin,
    runtimeTestCase,
):
    """
    Tests for L{client.BrowserLikeRedirectAgent}.
    """

    def makeAgent(self):
        """
        @return: a new L{twisted.web.client.BrowserLikeRedirectAgent}
        """
        return client.BrowserLikeRedirectAgent(
            self.buildAgentForWrapperTest(self.reactor),
            sensitiveHeaderNames=[b"x-Custom-sensitive"],
        )

    def setUp(self):
        self.reactor = self.createReactor()
        self.agent = self.makeAgent()

    def test_redirectToGet301(self):
        """
        L{client.BrowserLikeRedirectAgent} changes the method to I{GET} when
        getting a 301 redirect on a I{POST} request.
        """
        self._testRedirectToGet(301, b"POST")

    def test_redirectToGet302(self):
        """
        L{client.BrowserLikeRedirectAgent} changes the method to I{GET} when
        getting a 302 redirect on a I{POST} request.
        """
        self._testRedirectToGet(302, b"POST")
|
|
|
|
|
|
class AbortableStringTransport(StringTransport):
    """
    L{StringTransport} variant that additionally implements
    C{abortConnection}.
    """

    # Intended to be superseded by a shared implementation (see #6530).
    aborting = False

    def abortConnection(self):
        """
        Test double for C{ITCPTransport.abortConnection}: note that an abort
        happened by setting C{aborting}, then close the connection.

        Aborting is treated as a special case of closing, which is why
        C{loseConnection} is invoked as well.
        """
        self.aborting = True
        self.loseConnection()
|
|
|
|
|
|
class DummyResponse:
    """
    Stand-in L{IResponse} used by the readBody tests. It remembers the
    protocol handed to C{deliverBody} and connects that protocol to its own
    transport.

    @ivar protocol: Set by C{deliverBody} to the protocol it received.

    @ivar transport: The object returned by C{transportFactory}; it becomes
        the transport of L{DummyResponse.protocol}.
    """

    code = 200
    phrase = b"OK"

    def __init__(self, headers=None, transportFactory=AbortableStringTransport):
        """
        @param headers: Response headers; when L{None}, a fresh empty
            L{Headers} is used.
        @type headers: L{Headers}

        @param transportFactory: Zero-argument callable producing the
            transport.
        """
        self.headers = Headers() if headers is None else headers
        self.transport = transportFactory()

    def deliverBody(self, protocol):
        """
        Store C{protocol} and connect it to L{DummyResponse.transport}.
        """
        self.protocol = protocol
        protocol.makeConnection(self.transport)
|
|
|
|
|
|
class AlreadyCompletedDummyResponse(DummyResponse):
    """
    L{DummyResponse} whose protocol loses its transport as soon as the body
    delivery has started, as if the transport were already closed.
    """

    def deliverBody(self, protocol):
        """
        Connect C{protocol} the usual way, then clear its C{transport}.
        """
        super().deliverBody(protocol)
        protocol.transport = None
|
|
|
|
|
|
class ReadBodyTests(TestCase):
    """
    Tests for L{client.readBody}.
    """

    def test_success(self):
        """
        L{client.readBody} returns a L{Deferred} which fires with the complete
        body of the L{IResponse} provider passed to it.
        """
        response = DummyResponse()
        d = client.readBody(response)
        response.protocol.dataReceived(b"first")
        response.protocol.dataReceived(b"second")
        response.protocol.connectionLost(Failure(ResponseDone()))
        self.assertEqual(self.successResultOf(d), b"firstsecond")

    def test_cancel(self):
        """
        When cancelling the L{Deferred} returned by L{client.readBody}, the
        connection to the server will be aborted.
        """
        response = DummyResponse()
        deferred = client.readBody(response)
        deferred.cancel()
        self.failureResultOf(deferred, defer.CancelledError)
        self.assertTrue(response.transport.aborting)

    def test_withPotentialDataLoss(self):
        """
        If the full body of the L{IResponse} passed to L{client.readBody} is
        not definitely received, the L{Deferred} returned by L{client.readBody}
        fires with a L{Failure} wrapping L{client.PartialDownloadError} with
        the content that was received.
        """
        response = DummyResponse()
        d = client.readBody(response)
        response.protocol.dataReceived(b"first")
        response.protocol.dataReceived(b"second")
        response.protocol.connectionLost(Failure(PotentialDataLoss()))
        failure = self.failureResultOf(d)
        failure.trap(client.PartialDownloadError)
        self.assertEqual(
            {
                "status": failure.value.status,
                "message": failure.value.message,
                "body": failure.value.response,
            },
            {
                "status": b"200",
                "message": b"OK",
                "body": b"firstsecond",
            },
        )

    def test_otherErrors(self):
        """
        If there is an exception other than L{client.PotentialDataLoss} while
        L{client.readBody} is collecting the response body, the L{Deferred}
        returned by L{client.readBody} fires with that exception.
        """
        response = DummyResponse()
        d = client.readBody(response)
        response.protocol.dataReceived(b"first")
        response.protocol.connectionLost(Failure(ConnectionLost("mystery problem")))
        reason = self.failureResultOf(d)
        reason.trap(ConnectionLost)
        self.assertEqual(reason.value.args, ("mystery problem",))

    def test_deprecatedTransport(self):
        """
        Calling L{client.readBody} with a transport that does not implement
        L{twisted.internet.interfaces.ITCPTransport} produces a deprecation
        warning, but no exception when cancelling.
        """
        response = DummyResponse(transportFactory=StringTransport)
        # Replace abortConnection with None so the transport offers no usable
        # abortConnection method; this is what triggers the warning below.
        response.transport.abortConnection = None
        d = self.assertWarns(
            DeprecationWarning,
            "Using readBody with a transport that does not have an "
            "abortConnection method",
            __file__,
            lambda: client.readBody(response),
        )
        d.cancel()
        self.failureResultOf(d, defer.CancelledError)

    def test_deprecatedTransportNoWarning(self):
        """
        Calling L{client.readBody} with a response that has already had its
        transport closed (e.g. for a very small response) will not trigger a
        deprecation warning.
        """
        response = AlreadyCompletedDummyResponse()
        client.readBody(response)

        # Comparing against [] shows any unexpected warnings on failure.
        self.assertEqual(self.flushWarnings(), [])
|
|
|
|
|
|
@skipIf(not sslPresent, "SSL not present, cannot run SSL tests.")
class HostnameCachingHTTPSPolicyTests(TestCase):
    """
    Tests for L{HostnameCachingHTTPSPolicy}.
    """

    def test_cacheIsUsed(self):
        """
        Verify that the connection creator is added to the
        policy's cache, and that it is reused on subsequent calls
        to creatorForNetloc.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy)
        creator = policy.creatorForNetloc(b"foo", 1589)
        self.assertTrue(trustRoot.called)
        # Reset the flag so the second lookup below can show that the trust
        # root was not consulted again.
        trustRoot.called = False
        self.assertEqual(1, len(policy._cache))
        connection = creator.clientConnectionForTLS(None)
        self.assertIs(trustRoot.context, connection.get_context())

        policy.creatorForNetloc(b"foo", 1589)
        self.assertFalse(trustRoot.called)

    def test_cacheRemovesOldest(self):
        """
        Verify that when the cache is full, and a new entry is added,
        the least recently used entry is removed.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy)
        for i in range(20):
            hostname = f"host{i}"
            policy.creatorForNetloc(hostname.encode("ascii"), 8675)

        # Force host0, which was the first, to be the most recently used
        host0 = "host0"
        policy.creatorForNetloc(host0.encode("ascii"), 309)
        self.assertIn(host0, policy._cache)
        self.assertEqual(20, len(policy._cache))

        hostn = "new"
        policy.creatorForNetloc(hostn.encode("ascii"), 309)

        # host1 is now the least recently used entry, so it is evicted.
        host1 = "host1"
        self.assertNotIn(host1, policy._cache)
        self.assertEqual(20, len(policy._cache))

        self.assertIn(hostn, policy._cache)
        self.assertIn(host0, policy._cache)

        # Accessing an item repeatedly does not corrupt the LRU.
        for _ in range(20):
            policy.creatorForNetloc(host0.encode("ascii"), 8675)

        hostNPlus1 = "new1"

        policy.creatorForNetloc(hostNPlus1.encode("ascii"), 800)

        self.assertNotIn("host2", policy._cache)
        self.assertEqual(20, len(policy._cache))

        self.assertIn(hostNPlus1, policy._cache)
        self.assertIn(hostn, policy._cache)
        self.assertIn(host0, policy._cache)

    def test_changeCacheSize(self):
        """
        Verify that changing the cache size results in a policy that
        respects the new cache size and not the default.
        """
        trustRoot = CustomOpenSSLTrustRoot()
        wrappedPolicy = BrowserLikePolicyForHTTPS(trustRoot=trustRoot)
        policy = HostnameCachingHTTPSPolicy(wrappedPolicy, cacheSize=5)
        for i in range(5):
            hostname = f"host{i}"
            policy.creatorForNetloc(hostname.encode("ascii"), 8675)

        first = "host0"
        self.assertIn(first, policy._cache)
        self.assertEqual(5, len(policy._cache))

        hostn = "new"
        policy.creatorForNetloc(hostn.encode("ascii"), 309)
        self.assertNotIn(first, policy._cache)
        self.assertEqual(5, len(policy._cache))

        self.assertIn(hostn, policy._cache)
|
|
|
|
|
|
class RequestMethodInjectionTests(
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Exercise the L{client.Request} constructor with malicious HTTP methods.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Construct a L{client.Request} that uses C{method}.

        @param method: see L{MethodInjectionTestsMixin}
        """
        emptyHeaders = http_headers.Headers()
        client.Request(
            method=method,
            uri=b"http://twisted.invalid",
            headers=emptyHeaders,
            bodyProducer=None,
        )
|
|
|
|
|
|
class RequestWriteToMethodInjectionTests(
    MethodInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Exercise L{client.Request.writeTo} with malicious HTTP methods.
    """

    def attemptRequestWithMaliciousMethod(self, method):
        """
        Build a well-formed I{GET} request, swap in C{method}, and serialize
        the request.

        @param method: see L{MethodInjectionTestsMixin}
        """
        request = client.Request(
            method=b"GET",
            uri=b"http://twisted.invalid",
            headers=http_headers.Headers({b"Host": [b"twisted.invalid"]}),
            bodyProducer=None,
        )
        # Assign the method after construction, so the malicious value
        # reaches writeTo rather than the constructor.
        request.method = method
        request.writeTo(StringTransport())
|
|
|
|
|
|
class RequestURIInjectionTests(
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request} against HTTP URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Attempt a request with the provided URI.

        @param uri: see L{URIInjectionTestsMixin}
        """
        client.Request(
            method=b"GET",
            uri=uri,
            headers=http_headers.Headers(),
            bodyProducer=None,
        )
|
|
|
|
|
|
class RequestWriteToURIInjectionTests(
    URIInjectionTestsMixin,
    SynchronousTestCase,
):
    """
    Test L{client.Request.writeTo} against HTTP URI injections.
    """

    def attemptRequestWithMaliciousURI(self, uri):
        """
        Attempt a request with the provided URI.

        A valid request is constructed first and C{uri} is assigned
        afterwards, so the malicious value reaches C{writeTo} rather than the
        constructor.

        @param uri: see L{URIInjectionTestsMixin}
        """
        headers = http_headers.Headers({b"Host": [b"twisted.invalid"]})
        req = client.Request(
            method=b"GET",
            uri=b"http://twisted.invalid",
            headers=headers,
            bodyProducer=None,
        )
        req.uri = uri
        req.writeTo(StringTransport())
|