466 lines
17 KiB
Python
466 lines
17 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from typing import Dict
|
|
from typing import Optional
|
|
from typing import Type
|
|
|
|
import base64
|
|
from datetime import datetime
|
|
import hashlib
|
|
import hmac
|
|
import time
|
|
from hashlib import sha256
|
|
|
|
try:
|
|
import simplejson as json
|
|
except ImportError:
|
|
import json # type: ignore
|
|
|
|
from libcloud.utils.py3 import ET
|
|
from libcloud.utils.py3 import _real_unicode
|
|
from libcloud.utils.py3 import basestring
|
|
from libcloud.common.base import ConnectionUserAndKey, XmlResponse, BaseDriver
|
|
from libcloud.common.base import JsonResponse
|
|
from libcloud.common.types import InvalidCredsError, MalformedResponseError
|
|
from libcloud.utils.py3 import b, httplib, urlquote
|
|
from libcloud.utils.xml import findtext, findall
|
|
|
|
# Public names exported by a wildcard import of this module.
__all__ = [
    # Response classes
    'AWSBaseResponse',
    'AWSGenericResponse',

    # Connection classes
    'AWSTokenConnection',
    'SignedAWSConnection',

    # Request signers
    'AWSRequestSignerAlgorithmV2',
    'AWSRequestSignerAlgorithmV4',

    # Drivers
    'AWSDriver',
]
|
|
|
|
# Signature algorithm SignedAWSConnection uses when the caller does not pass
# ``signature_version`` explicitly.
DEFAULT_SIGNATURE_VERSION = '2'

# Literal placed where a payload hash would go when the request body is not
# hashed (see AWSRequestSignerAlgorithmV4._get_payload_hash).
UNSIGNED_PAYLOAD = 'UNSIGNED-PAYLOAD'

# Message template for the ValueError raised by
# SignedAWSConnection.add_default_params when a parameter value is not a simple
# scalar. It takes three positional substitutions: the parameter name, its
# value and its type.
PARAMS_NOT_STRING_ERROR_MSG = """
"params" dictionary contains an attribute "%s" which value (%s, %s) is not a
string.

Parameters are sent via query parameters and not via request body and as such,
all the values need to be of a simple type (string, int, bool).

For arrays and other complex types, you should use notation similar to this
one:

params['TagSpecification.1.Tag.Value'] = 'foo'
params['TagSpecification.2.Tag.Value'] = 'bar'

See https://docs.aws.amazon.com/AWSEC2/latest/APIReference/Query-Requests.html
for details.
""".strip()
|
|
|
|
|
|
class AWSBaseResponse(XmlResponse):
    """
    Base XML response class shared by the AWS drivers.
    """

    # XML namespace handed to ``findtext`` when reading error details;
    # subclasses may override it.
    namespace = None

    def _parse_error_details(self, element):
        """
        Read the error code and message from an XML error element.

        :param element: Element whose ``Code`` and ``Message`` children are
                        read (looked up with ``self.namespace``).

        :return: ``tuple`` with two elements: (code, message)
        :rtype: ``tuple``
        """
        lookup = dict(element=element, namespace=self.namespace)
        error_code = findtext(xpath='Code', **lookup)
        error_message = findtext(xpath='Message', **lookup)
        return error_code, error_message
|
|
|
|
|
|
class AWSGenericResponse(AWSBaseResponse):
    """
    XML response class that converts AWS error documents into either a
    driver-specific exception or a human-readable error string.
    """

    # There are multiple error messages in AWS, but they all have an Error node
    # with Code and Message child nodes. Xpath to select them.
    # None if the root node *is* the Error node.
    xpath = None

    # This dict maps <Error><Code>CodeName</Code></Error> to a specific
    # exception class that is raised immediately.
    # If a custom exception class is not defined, errors are accumulated and
    # returned from the parse_error method.
    exceptions = {}  # type: Dict[str, Type[Exception]]

    def success(self):
        """
        :return: True when the HTTP status is 200, 201 or 202.
        :rtype: ``bool``
        """
        return self.status in [httplib.OK, httplib.CREATED, httplib.ACCEPTED]

    def parse_error(self):
        """
        Parse an AWS XML error response.

        :raises InvalidCredsError: when the HTTP status is 403.
        :raises MalformedResponseError: when the body is not valid XML.
        :raises Exception: the class mapped in ``self.exceptions`` for the
                           first error code that has a mapping.

        :return: Newline-separated ``"code: message"`` entries, one per error
                 element. If no error element is found, the raw body is
                 returned so the error text is not lost.
        :rtype: ``str``
        """
        context = self.connection.context
        status = int(self.status)

        # FIXME: Probably ditch this as the forbidden message will have
        # corresponding XML.
        if status == httplib.FORBIDDEN:
            if not self.body:
                raise InvalidCredsError(str(self.status) + ': ' + self.error)
            raise InvalidCredsError(self.body)

        try:
            body = ET.XML(self.body)
        except Exception:
            raise MalformedResponseError('Failed to parse XML',
                                         body=self.body,
                                         driver=self.connection.driver)

        if self.xpath:
            errs = findall(element=body, xpath=self.xpath,
                           namespace=self.namespace)
        else:
            errs = [body]

        msgs = []
        for err in errs:
            code, message = self._parse_error_details(element=err)
            exception_cls = self.exceptions.get(code, None)

            if exception_cls is None:
                msgs.append('%s: %s' % (code, message))
                continue

            # Custom exception class is defined, immediately throw an
            # exception. Only the constructor kwargs the class declares are
            # copied from the connection context.
            params = {}
            if hasattr(exception_cls, 'kwargs'):
                for key in exception_cls.kwargs:
                    if key in context:
                        params[key] = context[key]

            raise exception_cls(value=message, driver=self.connection.driver,
                                **params)

        if not msgs:
            # The xpath matched no error element; an empty string would hide
            # the actual error, so surface the raw body instead.
            return self.body

        return "\n".join(msgs)
|
|
|
|
|
|
class AWSTokenConnection(ConnectionUserAndKey):
    """
    Connection that attaches an optional AWS session token to every request,
    both as a query parameter and as a header.
    """

    def __init__(self, user_id, key, secure=True,
                 host=None, port=None, url=None, timeout=None, proxy_url=None,
                 token=None, retry_delay=None, backoff=None):
        # The token is stored before the parent constructor runs.
        self.token = token
        super(AWSTokenConnection, self).__init__(
            user_id, key,
            secure=secure, host=host, port=port, url=url,
            timeout=timeout, retry_delay=retry_delay, backoff=backoff,
            proxy_url=proxy_url)

    def _attach_token(self, mapping):
        """
        Set ``x-amz-security-token`` on *mapping* (in place) when a token is
        configured, and return *mapping*.
        """
        if self.token:
            mapping['x-amz-security-token'] = self.token
        return mapping

    def add_default_params(self, params):
        # The token also goes into the headers; it is repeated in the params
        # so that it becomes part of the signature.
        parent = super(AWSTokenConnection, self)
        return parent.add_default_params(self._attach_token(params))

    def add_default_headers(self, headers):
        parent = super(AWSTokenConnection, self)
        return parent.add_default_headers(self._attach_token(headers))
|
|
|
|
|
|
class AWSRequestSigner(object):
    """
    Base class for objects that sign outgoing AWS requests.

    This base implementation passes parameters and headers through unchanged;
    subclasses implement a concrete signature algorithm.
    """

    def __init__(self, access_key, access_secret, version, connection):
        """
        :param access_key: Access key.
        :type access_key: ``str``

        :param access_secret: Access secret.
        :type access_secret: ``str``

        :param version: API version.
        :type version: ``str``

        :param connection: Connection instance.
        :type connection: :class:`Connection`
        """
        self.access_key, self.access_secret = access_key, access_secret
        self.version = version
        # TODO: Remove cycling dependency between connection and signer
        self.connection = connection

    def get_request_params(self, params, method='GET', path='/'):
        """
        Return the (possibly signed) query parameters; here *params* itself.
        """
        return params

    def get_request_headers(self, params, headers, method='GET', path='/',
                            data=None):
        """
        Return the ``(params, headers)`` pair to send; here unchanged.
        """
        return params, headers
|
|
|
|
|
|
class AWSRequestSignerAlgorithmV2(AWSRequestSigner):
    """
    Signer implementing AWS Signature Version 2 (HmacSHA256). The signature
    and its companion fields are added to the query parameters.
    """

    def get_request_params(self, params, method='GET', path='/'):
        """
        Add the Signature Version 2 fields and the signature to *params*.

        :param params: Query parameters; mutated in place and returned.
        :param method: HTTP method the request is sent with. It is part of
                       the string to sign, so it must match the real verb.
        :param path: Request path, also part of the string to sign.

        :return: ``params``
        """
        params['SignatureVersion'] = '2'
        params['SignatureMethod'] = 'HmacSHA256'
        params['AWSAccessKeyId'] = self.access_key
        params['Version'] = self.version
        params['Timestamp'] = time.strftime('%Y-%m-%dT%H:%M:%SZ',
                                            time.gmtime())
        params['Signature'] = self._get_aws_auth_param(
            params=params,
            secret_key=self.access_secret,
            path=path,
            method=method)
        return params

    def _get_aws_auth_param(self, params, secret_key, path='/', method='GET'):
        """
        Creates the signature required for AWS, per
        http://bit.ly/aR7GaQ [docs.amazonwebservices.com]:

        StringToSign = HTTPVerb + "\n" +
                       ValueOfHostHeaderInLowercase + "\n" +
                       HTTPRequestURI + "\n" +
                       CanonicalizedQueryString <from the preceding step>

        :param method: HTTP verb used in the string to sign. Defaults to
                       ``'GET'`` (also used when a falsy value is passed).

        :return: Base64-encoded HMAC-SHA256 signature.
        :rtype: ``str``
        """
        connection = self.connection

        pairs = []
        for key in sorted(params.keys()):
            value = str(params[key])
            pairs.append(urlquote(key, safe='') + '=' +
                         urlquote(value, safe='-_~'))
        qs = '&'.join(pairs)

        # The host header value must be lowercase in the string to sign; a
        # non-default port is part of the Host header, so it is included.
        hostname = connection.host.lower()
        if (connection.secure and connection.port != 443) or \
                (not connection.secure and connection.port != 80):
            hostname += ':' + str(connection.port)

        verb = (method or 'GET').upper()
        string_to_sign = '\n'.join((verb, hostname, path, qs))

        b64_hmac = base64.b64encode(
            hmac.new(b(secret_key), b(string_to_sign),
                     digestmod=sha256).digest()
        )

        return b64_hmac.decode('utf-8')
|
|
|
|
|
|
class AWSRequestSignerAlgorithmV4(AWSRequestSigner):
    """
    Signer implementing AWS Signature Version 4 (AWS4-HMAC-SHA256). The
    signature is sent in the ``Authorization`` header.
    """

    def get_request_params(self, params, method='GET', path='/'):
        """
        Add the API ``Version`` to *params* for GET requests.

        :return: ``params`` (mutated in place).
        """
        if method == 'GET':
            params['Version'] = self.version
        return params

    def get_request_headers(self, params, headers, method='GET', path='/',
                            data=None):
        """
        Add ``X-AMZ-Date``, ``X-AMZ-Content-SHA256`` and ``Authorization``
        to *headers* (in place).

        :return: ``(params, headers)``
        """
        now = datetime.utcnow()
        headers['X-AMZ-Date'] = now.strftime('%Y%m%dT%H%M%SZ')
        headers['X-AMZ-Content-SHA256'] = self._get_payload_hash(method, data)
        # Computed after the two headers above are set, so they are signed.
        headers['Authorization'] = \
            self._get_authorization_v4_header(params=params, headers=headers,
                                              dt=now, method=method, path=path,
                                              data=data)

        return params, headers

    def _get_authorization_v4_header(self, params, headers, dt, method='GET',
                                     path='/', data=None):
        """Build the value of the ``Authorization`` header."""
        credentials_scope = self._get_credential_scope(dt=dt)
        signed_headers = self._get_signed_headers(headers=headers)
        signature = self._get_signature(params=params, headers=headers,
                                        dt=dt, method=method, path=path,
                                        data=data)

        return 'AWS4-HMAC-SHA256 Credential=%(u)s/%(c)s, ' \
               'SignedHeaders=%(sh)s, Signature=%(s)s' % {
                   'u': self.access_key,
                   'c': credentials_scope,
                   'sh': signed_headers,
                   's': signature
               }

    def _get_signature(self, params, headers, dt, method, path, data):
        """Return the hex HMAC of the string to sign with the derived key."""
        key = self._get_key_to_sign_with(dt)
        string_to_sign = self._get_string_to_sign(params=params,
                                                  headers=headers, dt=dt,
                                                  method=method, path=path,
                                                  data=data)
        return _sign(key=key, msg=string_to_sign, hex=True)

    def _get_key_to_sign_with(self, dt):
        """
        Derive the signing key by chaining HMACs over the date, region,
        service and the literal ``aws4_request``.
        """
        return _sign(
            _sign(
                _sign(
                    _sign(('AWS4' + self.access_secret),
                          dt.strftime('%Y%m%d')),
                    self.connection.driver.region_name),
                self.connection.service_name),
            'aws4_request')

    def _get_string_to_sign(self, params, headers, dt, method, path, data):
        """Build the SigV4 string to sign from the canonical request."""
        canonical_request = self._get_canonical_request(params=params,
                                                        headers=headers,
                                                        method=method,
                                                        path=path,
                                                        data=data)

        return '\n'.join(['AWS4-HMAC-SHA256',
                          dt.strftime('%Y%m%dT%H%M%SZ'),
                          self._get_credential_scope(dt),
                          _hash(canonical_request)])

    def _get_credential_scope(self, dt):
        """Return ``date/region/service/aws4_request``."""
        return '/'.join([dt.strftime('%Y%m%d'),
                         self.connection.driver.region_name,
                         self.connection.service_name,
                         'aws4_request'])

    def _sorted_header_items(self, headers):
        """
        Return ``(lowercase_name, value)`` pairs sorted by lowercase name.

        SigV4 orders headers by their lowercased names. Sorting the original
        names instead puts e.g. ``Host`` before ``content-type`` (uppercase
        sorts first), which yields a signature AWS rejects.
        """
        return sorted(((k.lower(), v) for k, v in headers.items()),
                      key=lambda item: item[0])

    def _get_signed_headers(self, headers):
        """Return the ``;``-separated list of lowercase header names."""
        return ';'.join([name for name, _ in
                         self._sorted_header_items(headers)])

    def _get_canonical_headers(self, headers):
        """Return ``name:value`` lines (newline-terminated) for all headers."""
        return '\n'.join([':'.join([name, str(value).strip()])
                          for name, value in
                          self._sorted_header_items(headers)]) + '\n'

    def _get_payload_hash(self, method, data=None):
        """
        Return the payload hash used in the canonical request.

        - ``UNSIGNED_PAYLOAD`` when *data* is ``UnsignedPayloadSentinel``, or
          for POST/PUT with empty data or iterator data (a file upload is
          not read just to hash it).
        - SHA-256 hex of *data* for other POST/PUT requests.
        - SHA-256 hex of the empty string for every other method.
        """
        if data is UnsignedPayloadSentinel:
            return UNSIGNED_PAYLOAD
        if method in ('POST', 'PUT'):
            if data:
                if hasattr(data, 'next') or hasattr(data, '__next__'):
                    # File upload; don't try to read the entire payload
                    return UNSIGNED_PAYLOAD
                return _hash(data)
            else:
                return UNSIGNED_PAYLOAD
        else:
            return _hash('')

    def _get_request_params(self, params):
        """
        Return the canonical query string.

        Names and values are URI-encoded (``~`` left unencoded, as an
        unreserved character) and the pairs are then sorted by encoded name,
        as SigV4 requires.
        """
        # For self.method == GET
        encoded = sorted((urlquote(k, safe='~'), urlquote(str(v), safe='~'))
                         for k, v in params.items())
        return '&'.join(['%s=%s' % pair for pair in encoded])

    def _get_canonical_request(self, params, headers, method, path, data):
        """Assemble the SigV4 canonical request."""
        return '\n'.join([
            method,
            path,
            self._get_request_params(params),
            self._get_canonical_headers(headers),
            self._get_signed_headers(headers),
            self._get_payload_hash(method, data)
        ])
|
|
|
|
|
|
class UnsignedPayloadSentinel:
    """
    Marker used as request ``data`` to request an unsigned payload.

    The class object itself is compared by identity in
    ``AWSRequestSignerAlgorithmV4._get_payload_hash``, which then returns
    ``UNSIGNED_PAYLOAD`` instead of hashing the body.
    """
|
|
|
|
|
|
class SignedAWSConnection(AWSTokenConnection):
    """
    AWS connection that signs each request with a Signature Version 2 or 4
    signer chosen at construction time.
    """

    version = None  # type: Optional[str]

    def __init__(self, user_id, key, secure=True, host=None, port=None,
                 url=None, timeout=None, proxy_url=None, token=None,
                 retry_delay=None, backoff=None,
                 signature_version=DEFAULT_SIGNATURE_VERSION):
        """
        :param signature_version: ``'2'`` or ``'4'`` (ints are accepted and
                                  converted with ``str``).

        :raises ValueError: for any other signature version.
        """
        super(SignedAWSConnection, self).__init__(
            user_id=user_id, key=key, secure=secure, host=host, port=port,
            url=url, timeout=timeout, token=token, retry_delay=retry_delay,
            backoff=backoff, proxy_url=proxy_url)

        self.signature_version = str(signature_version)

        available_signers = {
            '2': AWSRequestSignerAlgorithmV2,
            '4': AWSRequestSignerAlgorithmV4,
        }
        signer_cls = available_signers.get(self.signature_version)
        if signer_cls is None:
            raise ValueError('Unsupported signature_version: %s' %
                             (signature_version))

        self.signer = signer_cls(access_key=self.user_id,
                                 access_secret=self.key,
                                 version=self.version,
                                 connection=self)

    def add_default_params(self, params):
        """
        Let the signer add its fields, then reject non-scalar values.

        :raises ValueError: if any value is not a string, int or bool.
        """
        params = self.signer.get_request_params(params=params,
                                                method=self.method,
                                                path=self.action)

        # Params are sent via the query string, so only simple scalar values
        # are supported; nested structures must be flattened by the caller.
        scalar_types = (_real_unicode, basestring, int, bool)
        for name, value in params.items():
            if isinstance(value, scalar_types):
                continue
            raise ValueError(PARAMS_NOT_STRING_ERROR_MSG %
                             (name, value, type(value)))

        return params

    def pre_connect_hook(self, params, headers):
        """Delegate header signing to the configured signer."""
        signed_params, signed_headers = self.signer.get_request_headers(
            params=params,
            headers=headers,
            method=self.method,
            path=self.action,
            data=self.data)
        return signed_params, signed_headers
|
|
|
|
|
|
class AWSJsonResponse(JsonResponse):
    """
    Amazon ECS response class.
    ECS API uses JSON unlike the s3, elb drivers
    """

    def parse_error(self):
        """
        Build an error string from a JSON error body.

        The error type is read from ``__type`` and the message from
        ``Message``, falling back to ``message``.

        :return: ``"<type>: <message>"`` (missing fields render as ``None``),
                 or the raw body when it is not a JSON object.
        :rtype: ``str``
        """
        try:
            response = json.loads(self.body)
        except ValueError:
            # Not JSON (e.g. an HTML error page from a proxy); keep the body
            # as the message rather than masking the real error.
            return self.body

        if not isinstance(response, dict):
            return self.body

        code = response.get('__type')
        # Both lookups use .get: a plain ``response['message']`` default is
        # evaluated eagerly and raises KeyError whenever only ``Message``
        # is present.
        message = response.get('Message', response.get('message'))
        return '%s: %s' % (code, message)
|
|
|
|
|
|
def _sign(key, msg, hex=False):
    """
    Compute an HMAC-SHA256 of *msg* keyed with *key*.

    Both arguments are converted with ``b()`` before hashing.

    :param hex: When true, return the hexadecimal digest (``str``);
                otherwise return the raw digest (``bytes``).
    """
    mac = hmac.new(b(key), b(msg), hashlib.sha256)
    return mac.hexdigest() if hex else mac.digest()
|
|
|
|
|
|
def _hash(msg):
    """Return the hex-encoded SHA-256 digest of ``b(msg)``."""
    digest = hashlib.sha256()
    digest.update(b(msg))
    return digest.hexdigest()
|
|
|
|
|
|
class AWSDriver(BaseDriver):
    """
    Base driver for AWS services that forwards an optional session token to
    its connection class.
    """

    def __init__(self, key, secret=None, secure=True, host=None, port=None,
                 api_version=None, region=None, token=None, **kwargs):
        # Assigned before the parent constructor runs; presumably because
        # BaseDriver.__init__ builds the connection using
        # _ex_connection_class_kwargs (which reads self.token) — TODO confirm.
        self.token = token
        super(AWSDriver, self).__init__(key, secret=secret, secure=secure,
                                        host=host, port=port,
                                        api_version=api_version,
                                        region=region, token=token, **kwargs)

    def _ex_connection_class_kwargs(self):
        """Extend the parent's connection kwargs with ``token``."""
        connection_kwargs = super(AWSDriver, self)._ex_connection_class_kwargs()
        connection_kwargs.update(token=self.token)
        return connection_kwargs
|