# Source code for spack.oci.opener

# Copyright Spack Project Developers. See COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

"""All the logic for OCI fetching and authentication"""

import base64
import json
import re
import urllib.error
import urllib.parse
import urllib.request
from enum import Enum, auto
from http.client import HTTPResponse
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple
from urllib.request import Request

import spack.config
import spack.llnl.util.lang
import spack.mirrors.mirror
import spack.tokenize
import spack.util.web

from .image import ImageReference


def _urlopen():
    """Build the OCI-aware opener once and return an ``open``-style callable.

    The returned function takes ``(fullurl, data=None, timeout=None)``. When
    ``timeout`` is falsy it falls back to the ``config:connect_timeout`` setting
    (default 10).
    """
    oci_opener = create_opener()

    def dispatch_open(fullurl, data=None, timeout=None):
        if not timeout:
            timeout = spack.config.get("config:connect_timeout", 10)
        return oci_opener.open(fullurl, data, timeout)

    return dispatch_open


#: Type of an ``urlopen``-like callable: any arguments, returns an HTTP response
OpenType = Callable[..., HTTPResponse]
#: An ``OpenType`` callable, or None
MaybeOpen = Optional[OpenType]

#: Opener that automatically uses OCI authentication based on mirror config.
#: Wrapped in ``Singleton`` so ``_urlopen`` (and thus ``create_opener``) presumably
#: runs lazily, on first use — see ``spack.llnl.util.lang.Singleton`` to confirm.
urlopen: OpenType = spack.llnl.util.lang.Singleton(_urlopen)


# Regex fragments for the HTTP grammar of RFC 7230 (section 3.2.3 for whitespace,
# section 3.2.6 for tokens and quoted strings), used to tokenize WWW-Authenticate.
SP = r" "
OWS = r"[ \t]*"  # optional whitespace
BWS = OWS  # "bad" whitespace: same syntax as OWS
HTAB = r"\t"
VCHAR = r"\x21-\x7E"  # a character-class *range*, meant to be embedded in [...]
tchar = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]"
token = rf"{tchar}+"
obs_text = r"\x80-\xFF"  # a character-class *range*, meant to be embedded in [...]
# One character of a quoted string: HTAB, SP, VCHAR except DQUOTE and backslash, or obs-text.
qdtext = rf"[{HTAB}{SP}\x21\x23-\x5B\x5D-\x7E{obs_text}]"
# A backslash escape; group 1 is the escaped character (used for unquoting).
quoted_pair = rf"\\([{HTAB}{SP}{VCHAR}{obs_text}])"
# Each repetition consumes exactly one qdtext character or one quoted-pair, and the two
# alternatives can never start with the same character, so matching is linear. Do not
# nest a quantifier inside the repetition: ``(?:X*)*`` backtracks exponentially on an
# unterminated quoted string, and this header comes from a remote server.
quoted_string = rf'"(?:{qdtext}|{quoted_pair})*"'


class WwwAuthenticateTokens(spack.tokenize.TokenBase):
    """Token kinds for tokenizing a ``WWW-Authenticate`` header value (RFC 7235 section 4.1)."""

    # NOTE(review): member order presumably sets match priority in the tokenizer's
    # alternation (AUTH_PARAM must win over TOKEN for "key=value") — confirm in
    # spack.tokenize before reordering.

    #: A complete ``key=value`` auth-param; the value is a token or a quoted-string
    AUTH_PARAM = rf"({token}){BWS}={BWS}({token}|{quoted_string})"
    # TOKEN68 = r"([A-Za-z0-9\-._~+/]+=*)"
    # todo... support this?
    #: A bare token, e.g. an auth-scheme name
    TOKEN = rf"{tchar}+"
    EQUALS = rf"{BWS}={BWS}"
    #: List separator, including surrounding optional whitespace
    COMMA = rf"{OWS},{OWS}"
    #: Separator between an auth-scheme and its parameter list
    SPACE = r" +"
    #: End of input
    EOF = r"$"
    #: Any other single character
    ANY = r"."
#: Tokenizer for ``WWW-Authenticate`` header values, built from ``WwwAuthenticateTokens``
WWW_AUTHENTICATE_TOKENIZER = spack.tokenize.Tokenizer(WwwAuthenticateTokens)
class State(Enum):
    """States of the ``parse_www_authenticate`` state machine."""

    #: Expecting an auth-scheme token
    CHALLENGE = auto()
    #: A scheme was read; expecting a space (params follow), a comma, or end of input
    AUTH_PARAM_LIST_START = auto()
    #: After the space following a scheme; expecting a ``key=value`` auth-param
    AUTH_PARAM = auto()
    #: After an auth-param; expecting a comma or end of input
    NEXT_IN_LIST = auto()
    #: After a comma in a param list; expecting another auth-param or a new scheme
    AUTH_PARAM_OR_SCHEME = auto()
class Challenge:
    """A single authentication challenge from a ``WWW-Authenticate`` header.

    Attributes:
        scheme: the auth-scheme name (``parse_www_authenticate`` stores it lower-cased)
        params: the auth-params as ``(key, value)`` pairs, in header order
    """

    __slots__ = ["scheme", "params"]

    def __init__(
        self, scheme: Optional[str] = None, params: Optional[List[Tuple[str, str]]] = None
    ) -> None:
        self.scheme = scheme or ""
        self.params = params or []

    def __repr__(self) -> str:
        return f"Challenge({self.scheme}, {self.params})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, Challenge)
            and self.scheme == other.scheme
            and self.params == other.params
        )

    def matches_scheme(self, scheme: str) -> bool:
        """Checks whether the challenge matches the given scheme, case-insensitive."""
        # Normalize both sides: a Challenge constructed directly (not by the parser)
        # may hold a mixed-case scheme.
        return self.scheme.lower() == scheme.lower()

    def get_param(self, key: str) -> Optional[str]:
        """Get the value of an auth param by key (case-insensitive), or None if not found."""
        wanted = key.lower()
        return next((v for k, v in self.params if k.lower() == wanted), None)
def parse_www_authenticate(input: str) -> List[Challenge]:
    """Very basic parsing of a ``WWW-Authenticate`` header value (RFC 7235 section 4.1).

    Notice: this omits token68 support.

    Args:
        input: the raw header value

    Returns:
        The challenges in header order. Schemes and param keys are lower-cased, and
        quoted param values are unquoted.

    Raises:
        ValueError: if the value does not follow the grammar below.
    """
    # auth-scheme      = token
    # auth-param       = token BWS "=" BWS ( token / quoted-string )
    # challenge        = auth-scheme [ 1*SP ( token68 / #auth-param ) ]
    # WWW-Authenticate = 1#challenge
    challenges: List[Challenge] = []

    _unquote = re.compile(quoted_pair).sub

    def unquote(s: str) -> str:
        """Drop the surrounding double quotes and resolve backslash escapes."""
        return _unquote(r"\1", s[1:-1])

    def extract_auth_param(param: str) -> Tuple[str, str]:
        """Split an AUTH_PARAM token into a lower-cased key and an unquoted value."""
        key, value = param.split("=", 1)
        key = key.rstrip().lower()
        value = value.lstrip()
        if value.startswith('"'):
            value = unquote(value)
        return key, value

    mode: State = State.CHALLENGE
    tokens = WWW_AUTHENTICATE_TOKENIZER.tokenize(input)
    current_challenge = Challenge()

    while True:
        tok: Optional[spack.tokenize.Token] = next(tokens, None)
        if tok is None:
            # The token stream ended without an EOF token. Report it as ValueError (what
            # callers handle) instead of letting StopIteration escape.
            raise ValueError(f"Unexpected end of WWW-Authenticate header: {input!r}")

        if mode == State.CHALLENGE:
            # Expect an auth-scheme.
            if tok.kind == WwwAuthenticateTokens.TOKEN:
                current_challenge.scheme = tok.value.lower()
                mode = State.AUTH_PARAM_LIST_START
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM_LIST_START:
            if tok.kind == WwwAuthenticateTokens.EOF:
                challenges.append(current_challenge)
                break
            elif tok.kind == WwwAuthenticateTokens.COMMA:
                # Challenge without param list, followed by another challenge.
                challenges.append(current_challenge)
                current_challenge = Challenge()
                mode = State.CHALLENGE
            elif tok.kind == WwwAuthenticateTokens.SPACE:
                # A space means it must be followed by param list
                mode = State.AUTH_PARAM
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM:
            if tok.kind == WwwAuthenticateTokens.AUTH_PARAM:
                current_challenge.params.append(extract_auth_param(tok.value))
                mode = State.NEXT_IN_LIST
            else:
                raise ValueError(tok)

        elif mode == State.NEXT_IN_LIST:
            if tok.kind == WwwAuthenticateTokens.EOF:
                challenges.append(current_challenge)
                break
            elif tok.kind == WwwAuthenticateTokens.COMMA:
                mode = State.AUTH_PARAM_OR_SCHEME
            else:
                raise ValueError(tok)

        elif mode == State.AUTH_PARAM_OR_SCHEME:
            if tok.kind == WwwAuthenticateTokens.TOKEN:
                # A bare token after a comma starts the next challenge.
                challenges.append(current_challenge)
                current_challenge = Challenge(tok.value.lower())
                mode = State.AUTH_PARAM_LIST_START
            elif tok.kind == WwwAuthenticateTokens.AUTH_PARAM:
                current_challenge.params.append(extract_auth_param(tok.value))
                mode = State.NEXT_IN_LIST
            elif tok.kind == WwwAuthenticateTokens.COMMA:
                # Empty list element such as "a=b, , c=d": RFC 7230 section 7 asks
                # recipients to accept and ignore these.
                pass
            else:
                # Anything else (a trailing comma before EOF, stray characters, ...) is
                # malformed; reject it like every other state does.
                raise ValueError(tok)

    return challenges
#: Parameters of a Bearer challenge: the URL to request a token from (``realm``) and
#: the ``service`` and ``scope`` to request it for.
RealmServiceScope = NamedTuple(
    "RealmServiceScope", [("realm", str), ("service", str), ("scope", str)]
)
class UsernamePassword(NamedTuple):
    """A username/password pair used to authenticate against a registry."""

    username: str
    password: str

    @property
    def basic_auth_header(self) -> str:
        """The ``Authorization`` header value for HTTP Basic auth with these credentials."""
        raw_credentials = f"{self.username}:{self.password}".encode("utf-8")
        return "Basic " + base64.b64encode(raw_credentials).decode("utf-8")
def _get_bearer_challenge(challenges: List[Challenge]) -> Optional[RealmServiceScope]:
    """Find the first Bearer challenge and return its realm, service and scope.

    Returns None when no Bearer challenge exists, or when the first one is missing
    any of those three params.
    """
    bearer = next(filter(lambda c: c.matches_scheme("Bearer"), challenges), None)
    if bearer is None:
        return None

    realm, service, scope = (bearer.get_param(k) for k in ("realm", "service", "scope"))
    if realm is None or service is None or scope is None:
        return None
    return RealmServiceScope(realm, service, scope)


def _get_basic_challenge(challenges: List[Challenge]) -> Optional[str]:
    """Return the realm of the first Basic challenge, or None if there is no such
    challenge (or it carries no realm)."""
    for candidate in challenges:
        if candidate.matches_scheme("Basic"):
            return candidate.get_param("realm")
    return None
class OCIAuthHandler(urllib.request.BaseHandler):
    """urllib handler that logs in to an OCI registry when it answers 401 Unauthorized.

    On a 401 it parses the ``WWW-Authenticate`` header and tries a Bearer token exchange
    first, then HTTP Basic auth. The resulting ``Authorization`` header is cached per
    registry host (netloc) and attached up front to later HTTPS requests to that host.
    """

    def __init__(self, credentials_provider: Callable[[str], Optional[UsernamePassword]]):
        """
        Args:
            credentials_provider: A function that takes a domain and may return a
                UsernamePassword.
        """
        self.credentials_provider = credentials_provider

        # Cached authorization headers for a given domain.
        self.cached_auth_headers: Dict[str, str] = {}

    def https_request(self, req: Request):
        # Eagerly add the cached Authorization header (Bearer or Basic) if no auth
        # header is set yet, to avoid 401s in multiple requests to the same registry.

        # Use has_header, not .headers, since there are two
        # types of headers (redirected and unredirected)
        if req.has_header("Authorization"):
            return req

        parsed = urllib.parse.urlparse(req.full_url)
        auth_header = self.cached_auth_headers.get(parsed.netloc)

        if not auth_header:
            return req

        req.add_unredirected_header("Authorization", auth_header)
        return req

    def _try_bearer_challenge(
        self,
        challenges: List[Challenge],
        credentials: Optional[UsernamePassword],
        timeout: Optional[float],
    ) -> Optional[str]:
        """Exchange a Bearer challenge for a token at the challenge's realm.

        Returns:
            The ``Authorization`` header value, or None if no usable Bearer challenge
            is present.

        Raises:
            ValueError: if the realm is not https, or the token response is malformed.
        """
        # Check whether a Bearer challenge is present in the WWW-Authenticate header
        challenge = _get_bearer_challenge(challenges)
        if not challenge:
            return None

        # Get the token from the auth handler
        query = urllib.parse.urlencode(
            {"service": challenge.service, "scope": challenge.scope, "client_id": "spack"}
        )
        parsed = urllib.parse.urlparse(challenge.realm)._replace(
            query=query, fragment="", params=""
        )

        # Don't send credentials over insecure transport.
        if parsed.scheme != "https":
            raise ValueError(f"Cannot login over insecure {parsed.scheme} connection")

        request = Request(urllib.parse.urlunparse(parsed), method="GET")
        if credentials is not None:
            request.add_unredirected_header("Authorization", credentials.basic_auth_header)

        # Do a GET request, and close the response once its body is read so the
        # underlying connection is not leaked.
        with self.parent.open(request, timeout=timeout) as response:
            try:
                response_json = json.load(response)
            except Exception as e:
                raise ValueError(f"Malformed token response from {challenge.realm}") from e

        # Accept the token under either the "token" or the "access_token" key.
        token = None
        if isinstance(response_json, dict):
            token = response_json.get("token")
            if token is None:
                token = response_json.get("access_token")

        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if not isinstance(token, str):
            raise ValueError(f"Malformed token response from {challenge.realm}")

        return f"Bearer {token}"

    def _try_basic_challenge(
        self, challenges: List[Challenge], credentials: UsernamePassword
    ) -> Optional[str]:
        """Return a Basic ``Authorization`` header if a Basic challenge with a realm exists."""
        # Check whether a Basic challenge is present in the WWW-Authenticate header
        # A realm is required for Basic auth, although we don't use it here. Leave this as a
        # validation step.
        realm = _get_basic_challenge(challenges)
        if not realm:
            return None

        return credentials.basic_auth_header

    def http_error_401(self, req: Request, fp, code, msg, headers):
        """Handle 401 Unauthorized by logging in and re-sending ``req`` once.

        Raises:
            spack.util.web.DetailedHTTPError: if login was already attempted for ``req``,
                the ``WWW-Authenticate`` header is missing or malformed, obtaining a token
                fails, or no supported challenge is offered.
        """
        # Login failed, avoid infinite recursion where we go back and
        # forth between auth server and registry
        if hasattr(req, "login_attempted"):
            raise spack.util.web.DetailedHTTPError(
                req, code, f"Failed to login: {msg}", headers, fp
            )

        # On 401 Unauthorized, parse the WWW-Authenticate header
        # to determine what authentication is required
        if "WWW-Authenticate" not in headers:
            raise spack.util.web.DetailedHTTPError(
                req, code, "Cannot login to registry, missing WWW-Authenticate header", headers, fp
            )

        www_auth_str = headers["WWW-Authenticate"]

        try:
            challenges = parse_www_authenticate(www_auth_str)
        except ValueError as e:
            raise spack.util.web.DetailedHTTPError(
                req,
                code,
                f"Cannot login to registry, malformed WWW-Authenticate header: {www_auth_str}",
                headers,
                fp,
            ) from e

        registry = urllib.parse.urlparse(req.get_full_url()).netloc

        credentials = self.credentials_provider(registry)

        # First try Bearer, then Basic
        try:
            auth_header = self._try_bearer_challenge(challenges, credentials, req.timeout)
            if not auth_header and credentials:
                auth_header = self._try_basic_challenge(challenges, credentials)
        except Exception as e:
            raise spack.util.web.DetailedHTTPError(
                req, code, f"Cannot login to registry: {e}", headers, fp
            ) from e

        if not auth_header:
            raise spack.util.web.DetailedHTTPError(
                req,
                code,
                f"Cannot login to registry, unsupported authentication scheme: {www_auth_str}",
                headers,
                fp,
            )

        self.cached_auth_headers[registry] = auth_header

        # Add the authorization header to the request
        req.add_unredirected_header("Authorization", auth_header)
        setattr(req, "login_attempted", True)

        return self.parent.open(req, timeout=req.timeout)
def credentials_from_mirrors(
    domain: str, *, mirrors: Optional[Iterable[spack.mirrors.mirror.Mirror]] = None
) -> Optional[UsernamePassword]:
    """Filter out OCI registry credentials from a list of mirrors.

    Args:
        domain: registry domain to find credentials for. It is compared against the
            domain of each mirror URL parsed as an OCI image reference.
        mirrors: mirrors to search. When ``None`` (the default), the configured mirrors
            are used; an explicitly empty iterable means "no mirrors".

    Returns:
        The first matching ``access_pair`` as a UsernamePassword, or None.
    """
    # Compare against None rather than relying on truthiness: ``[]`` would otherwise
    # fall back to the configured mirrors while an empty generator would not.
    if mirrors is None:
        mirrors = spack.mirrors.mirror.MirrorCollection().values()

    for mirror in mirrors:
        # Prefer push credentials over fetch. Unlikely that those are different
        # but our config format allows it.
        for direction in ("push", "fetch"):
            pair = mirror.get_credentials(direction).get("access_pair")
            if not pair:
                continue

            url = mirror.get_url(direction)
            try:
                parsed = ImageReference.from_url(url)
            except ValueError:
                # Not an OCI URL; this mirror cannot supply registry credentials.
                continue

            if parsed.domain == domain:
                return UsernamePassword(*pair)

    return None
def create_opener():
    """Return an ``OpenerDirector`` set up for OCI registry access.

    Besides proxy, HTTP/HTTPS, redirect and error handling, it installs an
    ``OCIAuthHandler`` that looks up credentials via ``credentials_from_mirrors``.
    """
    director = urllib.request.OpenerDirector()
    handlers = (
        urllib.request.ProxyHandler(),
        urllib.request.UnknownHandler(),
        urllib.request.HTTPHandler(),
        spack.util.web.SpackHTTPSHandler(context=spack.util.web.ssl_create_default_context()),
        spack.util.web.SpackHTTPDefaultErrorHandler(),
        urllib.request.HTTPRedirectHandler(),
        urllib.request.HTTPErrorProcessor(),
        OCIAuthHandler(credentials_from_mirrors),
    )
    for h in handlers:
        director.add_handler(h)
    return director
def ensure_status(request: urllib.request.Request, response: HTTPResponse, status: int):
    """Check that ``response`` carries the expected HTTP status code.

    Raises:
        spack.util.web.DetailedHTTPError: if ``response.status`` differs from ``status``.
    """
    if response.status != status:
        raise spack.util.web.DetailedHTTPError(
            request, response.status, response.reason, response.info(), None
        )
#: Default retry helper for OCI requests: an alias of
#: ``spack.util.web.retry_on_transient_error`` (presumably a decorator/wrapper that
#: retries on transient network errors — see that function for its exact contract).
default_retry = spack.util.web.retry_on_transient_error