# Copyright Spack Project Developers. See COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""All the logic for OCI fetching and authentication"""
import base64
import json
import re
import urllib.error
import urllib.parse
import urllib.request
from enum import Enum, auto
from http.client import HTTPResponse
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple
from urllib.request import Request
import spack.config
import spack.llnl.util.lang
import spack.mirrors.mirror
import spack.tokenize
import spack.util.web
from .image import ImageReference
def _urlopen():
    """Build the callable that backs the module-level ``urlopen``.

    Returns:
        A function ``dispatch_open(fullurl, data=None, timeout=None)`` that forwards to
        the opener returned by ``create_opener()``. A falsy ``timeout`` is replaced by
        the ``config:connect_timeout`` setting, with 10 as the fallback.
    """
    oci_opener = create_opener()

    def dispatch_open(fullurl, data=None, timeout=None):
        # Falsy timeouts (None, 0) fall back to the configured connect timeout.
        if not timeout:
            timeout = spack.config.get("config:connect_timeout", 10)
        return oci_opener.open(fullurl, data, timeout)

    return dispatch_open
#: Type of ``urlopen``-like callables: arbitrary arguments in, an ``HTTPResponse`` out.
OpenType = Callable[..., HTTPResponse]
#: An ``OpenType`` callable, or None.
MaybeOpen = Optional[OpenType]
#: Opener that automatically uses OCI authentication based on mirror config
urlopen: OpenType = spack.llnl.util.lang.Singleton(_urlopen)
# Regex fragments for the HTTP header grammar (RFC 7230 section 3.2.6) that the
# WWW-Authenticate tokenizer is assembled from. Fragments without surrounding brackets
# (SP, HTAB, VCHAR, obs_text) are meant to be embedded inside a character class.
SP = r" "
OWS = r"[ \t]*"  # optional whitespace
BWS = OWS  # "bad" whitespace: same pattern as OWS
HTAB = r"\t"
VCHAR = r"\x21-\x7E"  # visible ASCII characters
tchar = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]"
token = rf"{tchar}+"
obs_text = r"\x80-\xFF"
# Any character allowed unescaped inside a quoted-string: everything except '"' and '\'.
qdtext = rf"[{HTAB}{SP}\x21\x23-\x5B\x5D-\x7E{obs_text}]"
# Backslash escape; group 1 captures the escaped character so it can be unescaped with \1.
quoted_pair = rf"\\([{HTAB}{SP}{VCHAR}{obs_text}])"
# Consume exactly one qdtext character or one quoted-pair per iteration. Repeating a
# starred qdtext run inside the outer "*" would let the engine split a run of characters
# in exponentially many ways, so an unterminated quoted string in a header sent by a
# server could stall the parser (catastrophic backtracking).
quoted_string = rf'"(?:{qdtext}|{quoted_pair})*"'
[docs]
class WwwAuthenticateTokens(spack.tokenize.TokenBase):
    """Token kinds used to tokenize a WWW-Authenticate header value.

    Every member's value is a regular expression fragment for that kind.
    NOTE(review): member order presumably sets match priority in
    ``spack.tokenize.Tokenizer`` -- keep AUTH_PARAM ahead of TOKEN; confirm before reordering.
    """

    # auth-param = token BWS "=" BWS ( token / quoted-string )
    AUTH_PARAM = rf"({token}){BWS}={BWS}({token}|{quoted_string})"
    # TOKEN68 = r"([A-Za-z0-9\-._~+/]+=*)" # todo... support this?
    TOKEN = token
    EQUALS = BWS + "=" + BWS
    COMMA = OWS + "," + OWS
    SPACE = r" +"
    EOF = r"$"
    # Catch-all for any single character not covered above.
    ANY = r"."
#: Tokenizer built from the WwwAuthenticateTokens kinds; consumed by parse_www_authenticate.
WWW_AUTHENTICATE_TOKENIZER = spack.tokenize.Tokenizer(WwwAuthenticateTokens)
[docs]
class State(Enum):
    """States of the ``parse_www_authenticate`` state machine."""

    #: Expecting the auth-scheme token that opens a challenge.
    CHALLENGE = 1
    #: Just read a scheme: expecting end of input, a comma, or the space before its params.
    AUTH_PARAM_LIST_START = 2
    #: Expecting an auth-param.
    AUTH_PARAM = 3
    #: Just read an auth-param: expecting end of input or a comma.
    NEXT_IN_LIST = 4
    #: After a comma that follows an auth-param: expecting another auth-param or a new scheme.
    AUTH_PARAM_OR_SCHEME = 5
[docs]
class Challenge:
    """One challenge of a WWW-Authenticate header: an auth scheme and its auth-params.

    Attributes are stored exactly as given; ``matches_scheme`` and ``get_param`` compare
    case-insensitively regardless of how the instance was constructed.
    """

    __slots__ = ["scheme", "params"]

    def __init__(
        self, scheme: Optional[str] = None, params: Optional[List[Tuple[str, str]]] = None
    ) -> None:
        """
        Args:
            scheme: the auth scheme name; None or empty becomes ``""``.
            params: list of ``(key, value)`` auth-params; None or empty becomes a new list.
        """
        self.scheme = scheme or ""
        self.params = params or []

    def __repr__(self) -> str:
        return f"Challenge({self.scheme}, {self.params})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, Challenge)
            and self.scheme == other.scheme
            and self.params == other.params
        )

    def matches_scheme(self, scheme: str) -> bool:
        """Checks whether the challenge matches the given scheme, case-insensitive."""
        # Lower-case both sides: the stored scheme is only lower-cased by the parser, not
        # when a Challenge is constructed directly.
        return self.scheme.lower() == scheme.lower()

    def get_param(self, key: str) -> Optional[str]:
        """Get the value of an auth param by key (case-insensitive), or None if not found.

        When a key occurs more than once, the first occurrence wins.
        """
        wanted = key.lower()
        return next((v for k, v in self.params if k.lower() == wanted), None)
[docs]
def parse_www_authenticate(input: str):
    """Very basic parsing of a WWW-Authenticate header value (RFC7235 section 4.1).

    Notice: this omits token68 support.

    Args:
        input: the raw header value.

    Returns:
        A list of ``Challenge`` objects in header order. Scheme names and auth-param keys
        are lower-cased; quoted-string values are unquoted and unescaped.

    Raises:
        ValueError: if an unexpected token is encountered, or if the token stream ends
            before an end-of-input token was seen.
    """
    # auth-scheme      = token
    # auth-param       = token BWS "=" BWS ( token / quoted-string )
    # challenge        = auth-scheme [ 1*SP ( token68 / #auth-param ) ]
    # WWW-Authenticate = 1#challenge
    strip_escapes = re.compile(quoted_pair).sub
    kinds = WwwAuthenticateTokens

    def extract_auth_param(text: str) -> Tuple[str, str]:
        """Split a ``key = value`` token into a lower-cased key and an unquoted value."""
        key, value = text.split("=", 1)
        key = key.rstrip().lower()
        value = value.lstrip()
        if value.startswith('"'):
            # Drop the surrounding quotes, then turn every \X escape back into X.
            value = strip_escapes(r"\1", value[1:-1])
        return key, value

    challenges: List[Challenge] = []
    current_challenge = Challenge()
    mode: State = State.CHALLENGE

    for tok in WWW_AUTHENTICATE_TOKENIZER.tokenize(input):
        if mode == State.CHALLENGE:
            # A challenge always starts with its scheme; EOF is an error here too.
            if tok.kind != kinds.TOKEN:
                raise ValueError(tok)
            current_challenge.scheme = tok.value.lower()
            mode = State.AUTH_PARAM_LIST_START

        elif mode == State.AUTH_PARAM_LIST_START:
            if tok.kind == kinds.EOF:
                challenges.append(current_challenge)
                break
            elif tok.kind == kinds.COMMA:
                # Challenge without param list, followed by another challenge.
                challenges.append(current_challenge)
                current_challenge = Challenge()
                mode = State.CHALLENGE
            elif tok.kind == kinds.SPACE:
                # A space means it must be followed by param list
                mode = State.AUTH_PARAM
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM:
            if tok.kind != kinds.AUTH_PARAM:
                raise ValueError(tok)
            current_challenge.params.append(extract_auth_param(tok.value))
            mode = State.NEXT_IN_LIST

        elif mode == State.NEXT_IN_LIST:
            if tok.kind == kinds.EOF:
                challenges.append(current_challenge)
                break
            elif tok.kind == kinds.COMMA:
                mode = State.AUTH_PARAM_OR_SCHEME
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM_OR_SCHEME:
            if tok.kind == kinds.TOKEN:
                # A bare token after a comma starts the next challenge.
                challenges.append(current_challenge)
                current_challenge = Challenge(tok.value.lower())
                mode = State.AUTH_PARAM_LIST_START
            elif tok.kind == kinds.AUTH_PARAM:
                current_challenge.params.append(extract_auth_param(tok.value))
                mode = State.NEXT_IN_LIST
            elif tok.kind == kinds.COMMA:
                # Empty list element ("a=b, , c=d"): recipients are expected to ignore it.
                continue
            else:
                # Includes EOF right after a comma. Anything else is rejected rather
                # than silently skipped, like in every other state.
                raise ValueError(tok)
    else:
        # The token stream ran dry without an EOF token (e.g. the tokenizer produced
        # nothing for this input). Report it as a parse error instead of leaking
        # StopIteration to callers that only handle ValueError.
        raise ValueError(f"Unexpected end of WWW-Authenticate header: {input!r}")

    return challenges
[docs]
class RealmServiceScope(NamedTuple):
    """The ``realm``, ``service`` and ``scope`` auth-params of a Bearer challenge."""

    # Parsed as a URL and used as the token endpoint by OCIAuthHandler.
    realm: str
    # Sent as the "service" query parameter of the token request.
    service: str
    # Sent as the "scope" query parameter of the token request.
    scope: str
[docs]
class UsernamePassword(NamedTuple):
    """A username/password pair that can render itself as a Basic auth header."""

    username: str
    password: str

    @property
    def basic_auth_header(self) -> str:
        """The ``Authorization`` header value: ``Basic `` + base64 of ``username:password``."""
        raw = f"{self.username}:{self.password}".encode("utf-8")
        return "Basic " + base64.b64encode(raw).decode("utf-8")
def _get_bearer_challenge(challenges: List[Challenge]) -> Optional[RealmServiceScope]:
    """Return the realm/service/scope for a Bearer auth challenge, or None if not found.

    Only the first Bearer challenge is considered; None is also returned when that
    challenge lacks any of the three params.
    """
    for candidate in challenges:
        if candidate.matches_scheme("Bearer"):
            break
    else:
        return None

    # Get realm / service / scope from challenge
    values = [candidate.get_param(name) for name in ("realm", "service", "scope")]
    if any(value is None for value in values):
        return None
    return RealmServiceScope(*values)
def _get_basic_challenge(challenges: List[Challenge]) -> Optional[str]:
    """Return the realm for a Basic auth challenge, or None if not found.

    Only the first Basic challenge is inspected; None is also returned when it has no realm.
    """
    for candidate in challenges:
        if candidate.matches_scheme("Basic"):
            return candidate.get_param("realm")
    return None
[docs]
class OCIAuthHandler(urllib.request.BaseHandler):
    """urllib handler that logs in to OCI registries in response to 401 challenges.

    On a 401 it parses the WWW-Authenticate header, tries Bearer (token endpoint) and then
    Basic authentication, caches the resulting Authorization header per registry netloc,
    and replays the request once. Cached headers are added eagerly to later https requests.
    """

    def __init__(self, credentials_provider: Callable[[str], Optional[UsernamePassword]]):
        """
        Args:
            credentials_provider: A function that takes a domain and may return a UsernamePassword.
        """
        self.credentials_provider = credentials_provider

        # Cached authorization headers for a given domain.
        self.cached_auth_headers: Dict[str, str] = {}

    def https_request(self, req: Request):
        """Pre-process an https request: attach a cached Authorization header, if any."""
        # Eagerly add the bearer token to the request if no
        # auth header is set yet, to avoid 401s in multiple
        # requests to the same registry.

        # Use has_header, not .headers, since there are two
        # types of headers (redirected and unredirected)
        if req.has_header("Authorization"):
            return req

        parsed = urllib.parse.urlparse(req.full_url)
        auth_header = self.cached_auth_headers.get(parsed.netloc)

        if not auth_header:
            return req

        req.add_unredirected_header("Authorization", auth_header)
        return req

    def _try_bearer_challenge(
        self,
        challenges: List[Challenge],
        credentials: Optional[UsernamePassword],
        timeout: Optional[float],
    ) -> Optional[str]:
        """Answer a Bearer challenge by requesting a token from its realm.

        Args:
            challenges: parsed WWW-Authenticate challenges.
            credentials: sent as Basic auth to the token endpoint when not None.
            timeout: timeout passed on to the token request.

        Returns:
            ``"Bearer <token>"``, or None when there is no usable Bearer challenge.

        Raises:
            ValueError: if the realm is not an https URL, or the token response is malformed.
        """
        # Check whether a Bearer challenge is present in the WWW-Authenticate header
        challenge = _get_bearer_challenge(challenges)

        if not challenge:
            return None

        # Get the token from the auth handler
        query = urllib.parse.urlencode(
            {"service": challenge.service, "scope": challenge.scope, "client_id": "spack"}
        )

        parsed = urllib.parse.urlparse(challenge.realm)._replace(
            query=query, fragment="", params=""
        )

        # Don't send credentials over insecure transport.
        if parsed.scheme != "https":
            raise ValueError(f"Cannot login over insecure {parsed.scheme} connection")

        request = Request(urllib.parse.urlunparse(parsed), method="GET")

        if credentials is not None:
            request.add_unredirected_header("Authorization", credentials.basic_auth_header)

        # Do a GET request.
        response = self.parent.open(request, timeout=timeout)
        try:
            response_json = json.load(response)
            token = response_json.get("token")
            if token is None:
                token = response_json.get("access_token")
            # Explicit check instead of ``assert``: assertions are stripped under ``-O``,
            # which would let a missing token through as the header "Bearer None".
            if not isinstance(token, str):
                raise TypeError(f"expected a string token, got {type(token).__name__}")
        except Exception as e:
            raise ValueError(f"Malformed token response from {challenge.realm}") from e
        finally:
            # The body has been consumed (or is unusable); release the connection.
            response.close()
        return f"Bearer {token}"

    def _try_basic_challenge(
        self, challenges: List[Challenge], credentials: UsernamePassword
    ) -> Optional[str]:
        """Return the Basic auth header for ``credentials`` if a Basic challenge with a
        realm is present, else None."""
        # Check whether a Basic challenge is present in the WWW-Authenticate header
        # A realm is required for Basic auth, although we don't use it here. Leave this as a
        # validation step.
        realm = _get_basic_challenge(challenges)

        if not realm:
            return None

        return credentials.basic_auth_header

    def http_error_401(self, req: Request, fp, code, msg, headers):
        """Handle a 401 response: authenticate, then retry the request once.

        Returns:
            The response of re-opening ``req`` with an Authorization header attached.

        Raises:
            spack.util.web.DetailedHTTPError: if login was already attempted for this
                request, the WWW-Authenticate header is missing or malformed, obtaining
                credentials failed, or no supported scheme could be satisfied.
        """
        # Login failed, avoid infinite recursion where we go back and
        # forth between auth server and registry
        if hasattr(req, "login_attempted"):
            raise spack.util.web.DetailedHTTPError(
                req, code, f"Failed to login: {msg}", headers, fp
            )

        # On 401 Unauthorized, parse the WWW-Authenticate header
        # to determine what authentication is required
        if "WWW-Authenticate" not in headers:
            raise spack.util.web.DetailedHTTPError(
                req, code, "Cannot login to registry, missing WWW-Authenticate header", headers, fp
            )

        www_auth_str = headers["WWW-Authenticate"]

        try:
            challenges = parse_www_authenticate(www_auth_str)
        except ValueError as e:
            raise spack.util.web.DetailedHTTPError(
                req,
                code,
                f"Cannot login to registry, malformed WWW-Authenticate header: {www_auth_str}",
                headers,
                fp,
            ) from e

        registry = urllib.parse.urlparse(req.get_full_url()).netloc

        credentials = self.credentials_provider(registry)

        # First try Bearer, then Basic
        try:
            auth_header = self._try_bearer_challenge(challenges, credentials, req.timeout)
            if not auth_header and credentials:
                auth_header = self._try_basic_challenge(challenges, credentials)
        except Exception as e:
            raise spack.util.web.DetailedHTTPError(
                req, code, f"Cannot login to registry: {e}", headers, fp
            ) from e

        if not auth_header:
            raise spack.util.web.DetailedHTTPError(
                req,
                code,
                f"Cannot login to registry, unsupported authentication scheme: {www_auth_str}",
                headers,
                fp,
            )

        # Remember the header so https_request can attach it to later requests.
        self.cached_auth_headers[registry] = auth_header

        # Add the authorization header to the request
        req.add_unredirected_header("Authorization", auth_header)
        # Mark the request so that a second 401 is reported instead of retried.
        setattr(req, "login_attempted", True)

        return self.parent.open(req, timeout=req.timeout)
[docs]
def credentials_from_mirrors(
    domain: str, *, mirrors: Optional[Iterable[spack.mirrors.mirror.Mirror]] = None
) -> Optional[UsernamePassword]:
    """Filter out OCI registry credentials from a list of mirrors.

    Args:
        domain: registry domain to find credentials for.
        mirrors: mirrors to search; when falsy, the values of a default
            ``MirrorCollection()`` are used.

    Returns:
        The first access pair whose mirror URL parses to an image reference on ``domain``,
        or None when there is no such mirror.
    """
    candidates = mirrors or spack.mirrors.mirror.MirrorCollection().values()

    for mirror in candidates:
        # Prefer push credentials over fetch. Unlikely that those are different
        # but our config format allows it.
        for direction in ("push", "fetch"):
            access_pair = mirror.get_credentials(direction).get("access_pair")
            if access_pair:
                mirror_url = mirror.get_url(direction)
                try:
                    image_ref = ImageReference.from_url(mirror_url)
                except ValueError:
                    # URL is not a valid image reference; skip it.
                    continue
                if image_ref.domain == domain:
                    return UsernamePassword(*access_pair)
    return None
[docs]
def create_opener():
    """Create an opener that can handle OCI authentication.

    Returns:
        A ``urllib.request.OpenerDirector`` with proxy, http(s), error, redirect and
        ``OCIAuthHandler`` handlers installed; credentials come from
        ``credentials_from_mirrors``.
    """
    director = urllib.request.OpenerDirector()
    handlers = (
        urllib.request.ProxyHandler(),
        urllib.request.UnknownHandler(),
        urllib.request.HTTPHandler(),
        spack.util.web.SpackHTTPSHandler(context=spack.util.web.ssl_create_default_context()),
        spack.util.web.SpackHTTPDefaultErrorHandler(),
        urllib.request.HTTPRedirectHandler(),
        urllib.request.HTTPErrorProcessor(),
        OCIAuthHandler(credentials_from_mirrors),
    )
    for handler in handlers:
        director.add_handler(handler)
    return director
[docs]
def ensure_status(request: urllib.request.Request, response: HTTPResponse, status: int):
    """Raise an error if the response status is not the expected one.

    Raises:
        spack.util.web.DetailedHTTPError: built from the request and the response's
            status, reason and headers when ``response.status`` differs from ``status``.
    """
    if response.status != status:
        raise spack.util.web.DetailedHTTPError(
            request, response.status, response.reason, response.info(), None
        )
#: Module-level alias for ``spack.util.web.retry_on_transient_error``.
default_retry = spack.util.web.retry_on_transient_error