You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

273 lines
9.3 KiB

4 days ago
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for Agent Identity credentials."""
import base64
import hashlib
import logging
import os
import re
import time
from urllib.parse import quote, urlparse
from google.auth import environment_vars
from google.auth import exceptions
_LOGGER = logging.getLogger(__name__)

# Message used when the optional cryptography dependency is missing. The
# first fragment ends with a space because adjacent string literals are
# concatenated with no separator.
CRYPTOGRAPHY_NOT_FOUND_ERROR = (
    "The cryptography library is required for certificate-based authentication. "
    "Please install it with `pip install google-auth[cryptography]`."
)

# SPIFFE trust domain patterns for Agent Identities
# (agents.global.org-<number>... and agents.global.proj-<number>...).
_AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS = [
    r"^agents\.global\.org-\d+\.system\.id\.goog$",
    r"^agents\.global\.proj-\d+\.system\.id\.goog$",
]

# Fallback certificate location, checked on every polling attempt.
_WELL_KNOWN_CERT_PATH = "/var/run/secrets/workload-spiffe-credentials/certificates.pem"

# Constants for polling the certificate file: poll quickly at first, then
# back off to a slower interval until _TOTAL_TIMEOUT seconds are used up.
_FAST_POLL_CYCLES = 50
_FAST_POLL_INTERVAL = 0.1  # 100ms
_SLOW_POLL_INTERVAL = 0.5  # 500ms
_TOTAL_TIMEOUT = 30  # seconds

# Number of slow poll cycles that fill the remaining time budget. round()
# instead of int() so floating-point error (e.g. 49.999...) cannot silently
# drop a cycle if the constants above are tuned; max() guards against a
# fast phase that already exceeds the total timeout.
_SLOW_POLL_CYCLES = max(
    0,
    round(
        (_TOTAL_TIMEOUT - _FAST_POLL_CYCLES * _FAST_POLL_INTERVAL)
        / _SLOW_POLL_INTERVAL
    ),
)

# Sleep durations between successive polling attempts.
_POLLING_INTERVALS = [_FAST_POLL_INTERVAL] * _FAST_POLL_CYCLES + [
    _SLOW_POLL_INTERVAL
] * _SLOW_POLL_CYCLES
def _is_certificate_file_ready(path):
    """Checks if ``path`` names an existing, non-empty regular file.

    Args:
        path (Optional[str | bytes | os.PathLike]): The candidate file path.
            Falsy values and non-path objects are treated as "not ready".

    Returns:
        bool: True only if the path is a regular file with a size above zero.
    """
    # Reject non-path values explicitly: os.stat() would interpret an int as
    # an open file descriptor, and other types raise TypeError.
    if not path or not isinstance(path, (str, bytes, os.PathLike)):
        return False
    try:
        # isfile() rejects directories, whose reported size is non-zero.
        return os.path.isfile(path) and os.path.getsize(path) > 0
    except OSError:
        # The file may vanish between the isfile() and getsize() calls.
        return False
def _read_workload_cert_path(cert_config_path):
    """Reads ``cert_configs.workload.cert_path`` from a certificate config file.

    Args:
        cert_config_path (str): Path to the JSON certificate config file.

    Returns:
        Optional[str]: The configured certificate path, or None if the JSON
        does not have the expected nested-object shape or the value is not a
        string.

    Raises:
        IOError: If the file cannot be opened or read.
        ValueError: If the file is not valid UTF-8 encoded JSON.
    """
    import json

    with open(cert_config_path, "r", encoding="utf-8") as f:
        node = json.load(f)
    # Walk the nested keys defensively: valid JSON of the wrong shape
    # (e.g. ``"cert_configs": null`` or a top-level list) must not raise
    # AttributeError in the middle of polling.
    for key in ("cert_configs", "workload", "cert_path"):
        if not isinstance(node, dict):
            return None
        node = node.get(key)
    return node if isinstance(node, str) else None


def get_agent_identity_certificate_path():
    """Gets the certificate path from the certificate config file.

    The path to the certificate config file is read from the
    GOOGLE_API_CERTIFICATE_CONFIG environment variable. This function
    implements a retry mechanism to handle cases where the environment
    variable is set before the files are available on the filesystem. On
    every attempt the well-known certificate path is also checked as a
    fallback.

    Returns:
        Optional[str]: The path to the leaf certificate file, or None if the
        GOOGLE_API_CERTIFICATE_CONFIG environment variable is not set.

    Raises:
        google.auth.exceptions.RefreshError: If the certificate config file
            or the certificate file cannot be found after retries.
    """
    cert_config_path = os.environ.get(environment_vars.GOOGLE_API_CERTIFICATE_CONFIG)
    if not cert_config_path:
        return None

    has_logged_warning = False
    # The trailing None marks one last attempt after the final sleep, so the
    # last wait is not wasted.
    for interval in (*_POLLING_INTERVALS, None):
        try:
            cert_path = _read_workload_cert_path(cert_config_path)
            if _is_certificate_file_ready(cert_path):
                return cert_path
        except (IOError, ValueError):
            # Warn only once to avoid flooding the log while polling.
            if not has_logged_warning:
                _LOGGER.warning(
                    "Certificate config file not found at %s (from %s environment "
                    "variable). Retrying for up to %s seconds.",
                    cert_config_path,
                    environment_vars.GOOGLE_API_CERTIFICATE_CONFIG,
                    _TOTAL_TIMEOUT,
                )
                has_logged_warning = True

        # As a fallback, check the well-known certificate path.
        if _is_certificate_file_ready(_WELL_KNOWN_CERT_PATH):
            return _WELL_KNOWN_CERT_PATH

        # Either the config could not be read or the certificate is not yet
        # available; wait before the next attempt.
        if interval is not None:
            time.sleep(interval)

    raise exceptions.RefreshError(
        "Certificate config or certificate file not found after multiple retries. "
        "Token binding protection is failing. You can turn off this protection by setting "
        f"{environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES} to false "
        "to fall back to unbound tokens."
    )
def get_and_parse_agent_identity_certificate():
    """Loads and parses the agent identity certificate unless opted out.

    The user opts out of certificate-bound tokens by setting
    GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to "false"
    (case-insensitive). Otherwise the certificate path is resolved, the file
    is read as bytes, and the contents are parsed as a PEM certificate.

    Returns:
        The parsed certificate object, or None when the user opted out or no
        certificate path could be resolved.

    Raises:
        google.auth.exceptions.RefreshError: Propagated from
            ``get_agent_identity_certificate_path`` when polling times out.
    """
    sharing_setting = os.environ.get(
        environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES,
        "true",
    )
    # Opted-out users fall back to unbound tokens, so skip the lookup.
    if sharing_setting.lower() == "false":
        return None

    path = get_agent_identity_certificate_path()
    if not path:
        return None

    with open(path, "rb") as pem_file:
        pem_bytes = pem_file.read()
    return parse_certificate(pem_bytes)
def parse_certificate(cert_bytes):
    """Parses PEM-encoded certificate bytes into a certificate object.

    Args:
        cert_bytes (bytes): The PEM-encoded certificate bytes.

    Returns:
        cryptography.x509.Certificate: The parsed certificate object.

    Raises:
        ImportError: If the ``cryptography`` library is not installed.
    """
    try:
        from cryptography import x509 as crypto_x509

        certificate = crypto_x509.load_pem_x509_certificate(cert_bytes)
    except ImportError as err:
        raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from err
    return certificate
def _is_agent_identity_certificate(cert):
    """Checks if a certificate is an Agent Identity certificate.

    The certificate qualifies when its Subject Alternative Name (SAN)
    extension contains a ``spiffe`` URI whose authority (the SPIFFE trust
    domain) fully matches one of the Agent Identity trust domain patterns.

    Args:
        cert (cryptography.x509.Certificate): The parsed certificate object.

    Returns:
        bool: True if the certificate is an Agent Identity certificate,
        False otherwise (including when it has no SAN extension).

    Raises:
        ImportError: If the ``cryptography`` library is not installed.
    """
    try:
        from cryptography import x509
        from cryptography.x509.oid import ExtensionOID
    except ImportError as e:
        raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from e

    try:
        ext = cert.extensions.get_extension_for_oid(
            ExtensionOID.SUBJECT_ALTERNATIVE_NAME
        )
    except x509.ExtensionNotFound:
        return False

    for uri in ext.value.get_values_for_type(x509.UniformResourceIdentifier):
        parsed_uri = urlparse(uri)
        if parsed_uri.scheme != "spiffe":
            continue
        trust_domain = parsed_uri.netloc
        # fullmatch requires the entire trust domain to match; with re.match
        # the "$" anchor would also accept a trailing newline.
        if any(
            re.fullmatch(pattern, trust_domain)
            for pattern in _AGENT_IDENTITY_SPIFFE_TRUST_DOMAIN_PATTERNS
        ):
            return True
    return False
def calculate_certificate_fingerprint(cert):
    """Computes the certificate fingerprint used for token binding.

    The fingerprint is the SHA-256 digest of the DER-encoded certificate,
    encoded with standard base64, stripped of '=' padding, and then passed
    through ``urllib.parse.quote``.

    Args:
        cert (cryptography.x509.Certificate): The parsed certificate object.

    Returns:
        str: The URL-quoted, unpadded, base64-encoded SHA-256 fingerprint.

    Raises:
        ImportError: If the ``cryptography`` library is not installed.
    """
    try:
        from cryptography.hazmat.primitives import serialization

        der_bytes = cert.public_bytes(serialization.Encoding.DER)
        digest = hashlib.sha256(der_bytes).digest()
        # Standard (not URL-safe) base64 with padding removed, to match what
        # GFE expects. quote() then percent-encodes '+' as %2B; its default
        # safe set leaves '/' unescaped.
        unpadded = base64.b64encode(digest).decode("utf-8").rstrip("=")
        return quote(unpadded)
    except ImportError as err:
        raise ImportError(CRYPTOGRAPHY_NOT_FOUND_ERROR) from err
def should_request_bound_token(cert):
    """Decides whether a certificate-bound token should be requested.

    A bound token is requested only when ``cert`` is an Agent Identity
    certificate and GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES
    equals "true" (case-insensitive). When the variable is unset, its value
    defaults to "true".

    Args:
        cert (cryptography.x509.Certificate): The parsed certificate object.

    Returns:
        bool: True if a bound token should be requested, False otherwise.
    """
    # Evaluated unconditionally, so certificate inspection errors (such as a
    # missing cryptography library) surface whatever the setting is.
    agent_cert = _is_agent_identity_certificate(cert)
    sharing_setting = os.environ.get(
        environment_vars.GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES,
        "true",
    )
    return agent_cert and sharing_setting.lower() == "true"
def get_cached_cert_fingerprint(cached_cert):
    """Returns the fingerprint of the cached certificate.

    Args:
        cached_cert (bytes): PEM-encoded certificate bytes.

    Returns:
        str: The fingerprint computed by ``calculate_certificate_fingerprint``.

    Raises:
        ValueError: If ``cached_cert`` is empty or None, meaning mTLS has not
            been configured.
    """
    if not cached_cert:
        raise ValueError("mTLS connection is not configured.")
    return calculate_certificate_fingerprint(parse_certificate(cached_cert))