# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from selenium.webdriver.common.bidi.common import command_builder
from selenium.webdriver.remote.websocket_connection import WebSocketConnection
[docs]
class NetworkEvent:
"""Represents a network event."""
def __init__(self, event_class: str, **kwargs: Any) -> None:
self.event_class = event_class
self.params = kwargs
[docs]
@classmethod
def from_json(cls, json: dict[str, Any]) -> NetworkEvent:
return cls(event_class=json.get("event_class", ""), **json)
[docs]
class Network:
EVENTS = {
"before_request": "network.beforeRequestSent",
"response_started": "network.responseStarted",
"response_completed": "network.responseCompleted",
"auth_required": "network.authRequired",
"fetch_error": "network.fetchError",
"continue_request": "network.continueRequest",
"continue_auth": "network.continueWithAuth",
}
PHASES = {
"before_request": "beforeRequestSent",
"response_started": "responseStarted",
"auth_required": "authRequired",
}
def __init__(self, conn: WebSocketConnection) -> None:
self.conn = conn
self.intercepts: list[str] = []
self.callbacks: dict[str | int, Any] = {}
self.subscriptions: dict[str, list[int]] = {}
def _add_intercept(
self,
phases: list[str] | None = None,
contexts: list[str] | None = None,
url_patterns: list[Any] | None = None,
) -> dict[str, Any]:
"""Add an intercept to the network.
Args:
phases: A list of phases to intercept. Default is None (empty list).
contexts: A list of contexts to intercept. Default is None.
url_patterns: A list of URL patterns to intercept. Default is None.
Returns:
str: intercept id
"""
if phases is None:
phases = []
params = {}
if contexts is not None:
params["contexts"] = contexts
if url_patterns is not None:
params["urlPatterns"] = url_patterns
if len(phases) > 0:
params["phases"] = phases
else:
params["phases"] = ["beforeRequestSent"]
cmd = command_builder("network.addIntercept", params)
result: dict[str, Any] = self.conn.execute(cmd)
self.intercepts.append(result["intercept"])
return result
def _remove_intercept(self, intercept: str | None = None) -> None:
"""Remove a specific intercept, or all intercepts.
Args:
intercept: The intercept to remove. Default is None.
Raises:
ValueError: If intercept is not found.
Note:
If intercept is None, all intercepts will be removed.
"""
if intercept is None:
intercepts_to_remove = self.intercepts.copy() # create a copy before iterating
for intercept_id in intercepts_to_remove: # remove all intercepts
self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id}))
self.intercepts.remove(intercept_id)
else:
try:
self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept}))
self.intercepts.remove(intercept)
except Exception as e:
raise Exception(f"Exception: {e}")
def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int:
"""Set a callback function to subscribe to a network event.
Args:
event_name: The event to subscribe to.
callback: The callback function to execute on event.
Takes Request object as argument.
Returns:
int: callback id
"""
event = NetworkEvent(event_name)
def _callback(event_data: NetworkEvent) -> None:
request = Request(
network=self,
request_id=event_data.params["request"].get("request", None),
body_size=event_data.params["request"].get("bodySize", None),
cookies=event_data.params["request"].get("cookies", None),
resource_type=event_data.params["request"].get("goog:resourceType", None),
headers=event_data.params["request"].get("headers", None),
headers_size=event_data.params["request"].get("headersSize", None),
timings=event_data.params["request"].get("timings", None),
url=event_data.params["request"].get("url", None),
)
callback(request)
callback_id: int = self.conn.add_callback(event, _callback)
if event_name in self.callbacks:
self.callbacks[event_name].append(callback_id)
else:
self.callbacks[event_name] = [callback_id]
return callback_id
[docs]
def add_request_handler(
self,
event: str,
callback: Callable[[Request], Any],
url_patterns: list[Any] | None = None,
contexts: list[str] | None = None,
) -> int:
"""Add a request handler to the network.
Args:
event: The event to subscribe to.
callback: The callback function to execute on request interception.
Takes Request object as argument.
url_patterns: A list of URL patterns to intercept. Default is None.
contexts: A list of contexts to intercept. Default is None.
Returns:
int: callback id
"""
try:
event_name = self.EVENTS[event]
phase_name = self.PHASES[event]
except KeyError:
raise Exception(f"Event {event} not found")
result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts)
callback_id = self._on_request(event_name, callback)
if event_name in self.subscriptions:
self.subscriptions[event_name].append(callback_id)
else:
params: dict[str, Any] = {}
params["events"] = [event_name]
self.conn.execute(command_builder("session.subscribe", params))
self.subscriptions[event_name] = [callback_id]
self.callbacks[callback_id] = result["intercept"]
return callback_id
[docs]
def remove_request_handler(self, event: str, callback_id: int) -> None:
"""Remove a request handler from the network.
Args:
event: The event to unsubscribe from.
callback_id: The callback id to remove.
"""
try:
event_name = self.EVENTS[event]
except KeyError:
raise Exception(f"Event {event} not found")
net_event = NetworkEvent(event_name)
self.conn.remove_callback(net_event, callback_id)
self._remove_intercept(self.callbacks[callback_id])
del self.callbacks[callback_id]
self.subscriptions[event_name].remove(callback_id)
if len(self.subscriptions[event_name]) == 0:
params: dict[str, Any] = {}
params["events"] = [event_name]
self.conn.execute(command_builder("session.unsubscribe", params))
del self.subscriptions[event_name]
[docs]
def clear_request_handlers(self) -> None:
"""Clear all request handlers from the network."""
for event_name in self.subscriptions:
net_event = NetworkEvent(event_name)
for callback_id in self.subscriptions[event_name]:
self.conn.remove_callback(net_event, callback_id)
self._remove_intercept(self.callbacks[callback_id])
del self.callbacks[callback_id]
params: dict[str, Any] = {}
params["events"] = [event_name]
self.conn.execute(command_builder("session.unsubscribe", params))
self.subscriptions = {}
[docs]
def add_auth_handler(self, username: str, password: str) -> int:
"""Add an authentication handler to the network.
Args:
username: The username to authenticate with.
password: The password to authenticate with.
Returns:
int: callback id
"""
event = "auth_required"
def _callback(request: Request) -> None:
request._continue_with_auth(username, password)
return self.add_request_handler(event, _callback)
[docs]
def remove_auth_handler(self, callback_id: int) -> None:
"""Remove an authentication handler from the network.
Args:
callback_id: The callback id to remove.
"""
event = "auth_required"
self.remove_request_handler(event, callback_id)
[docs]
class Request:
"""Represents an intercepted network request."""
def __init__(
self,
network: Network,
request_id: Any,
body_size: int | None = None,
cookies: Any = None,
resource_type: str | None = None,
headers: Any = None,
headers_size: int | None = None,
method: str | None = None,
timings: Any = None,
url: str | None = None,
) -> None:
self.network = network
self.request_id = request_id
self.body_size = body_size
self.cookies = cookies
self.resource_type = resource_type
self.headers = headers
self.headers_size = headers_size
self.method = method
self.timings = timings
self.url = url
[docs]
def fail_request(self) -> None:
"""Fail this request."""
if not self.request_id:
raise ValueError("Request not found.")
params: dict[str, Any] = {"request": self.request_id}
self.network.conn.execute(command_builder("network.failRequest", params))
[docs]
def continue_request(
self,
body: Any = None,
method: str | None = None,
headers: Any = None,
cookies: Any = None,
url: str | None = None,
) -> None:
"""Continue after intercepting this request."""
if not self.request_id:
raise ValueError("Request not found.")
params: dict[str, Any] = {"request": self.request_id}
if body is not None:
params["body"] = body
if method is not None:
params["method"] = method
if headers is not None:
params["headers"] = headers
if cookies is not None:
params["cookies"] = cookies
if url is not None:
params["url"] = url
self.network.conn.execute(command_builder("network.continueRequest", params))
def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None:
"""Continue with authentication.
Args:
username: The username to authenticate with.
password: The password to authenticate with.
Note:
If username or password is None, it attempts auth with no credentials.
"""
params: dict[str, Any] = {}
params["request"] = self.request_id
if not username or not password: # no credentials is valid option
params["action"] = "default"
else:
params["action"] = "provideCredentials"
params["credentials"] = {"type": "password", "username": username, "password": password}
self.network.conn.execute(command_builder("network.continueWithAuth", params))