# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import dataclasses
import json
import logging
import threading
from ssl import CERT_NONE
from threading import Thread
from time import sleep
from websocket import WebSocketApp
from selenium.common import WebDriverException
def _snake_to_camel(name: str) -> str:
"""Convert snake_case field name to camelCase for BiDi protocol."""
parts = name.split("_")
return parts[0] + "".join(p.title() for p in parts[1:])
class _BiDiEncoder(json.JSONEncoder):
"""JSON encoder for BiDi dataclass instances.
Converts snake_case field names to camelCase, strips ``None`` values,
and flattens a ``properties`` field (e.g. ``PointerCommonProperties``)
directly into its parent action dict as required by the BiDi spec.
"""
def _convert(self, value):
"""Recursively convert a value, handling nested dataclasses, lists, and dicts."""
if dataclasses.is_dataclass(value) and not isinstance(value, type):
return self.default(value)
if isinstance(value, list):
return [self._convert(item) for item in value]
if isinstance(value, dict):
return {k: self._convert(v) for k, v in value.items()}
return value
def default(self, o):
if dataclasses.is_dataclass(o) and not isinstance(o, type):
result = {}
for f in dataclasses.fields(o):
value = getattr(o, f.name)
# Skip None values unless the field is explicitly marked
# retain_none=True in its metadata (e.g. for required-but-nullable
# BiDi fields that must be sent as JSON null rather than omitted).
if value is None and not f.metadata.get("retain_none"):
continue
camel_key = _snake_to_camel(f.name)
# Flatten PointerCommonProperties fields inline into the parent
if camel_key == "properties" and dataclasses.is_dataclass(value):
for pf in dataclasses.fields(value):
pv = getattr(value, pf.name)
if pv is not None:
result[_snake_to_camel(pf.name)] = self._convert(pv)
else:
result[camel_key] = self._convert(value)
return result
return super().default(o)
# Module-level logger; used below for debug output of sent/received messages and socket errors.
logger = logging.getLogger(__name__)
class WebSocketConnection:
    """Command/response and event channel over a single WebSocket.

    A ``WebSocketApp`` is run on a daemon thread. Commands are sent as JSON
    objects tagged with an incrementing ``id``; the caller then polls until a
    message with the same ``id`` has been received. Incoming messages that
    carry a ``method`` key are dispatched to the callbacks registered for that
    method name.
    """

    # Debug log lines for sent/received messages are truncated to this length.
    _max_log_message_size = 9999

    def __init__(self, url, timeout, interval):
        """Validate the polling settings, start the socket thread and wait for it to open.

        Args:
            url: WebSocket URL to connect to.
            timeout: Total time, in seconds, to wait for the socket to open
                and for each command response.
            interval: Time, in seconds, to sleep between polls while waiting.

        Raises:
            WebDriverException: If ``timeout`` or ``interval`` is not an
                ``int``/``float`` or is negative.

        Note:
            The result of the initial wait is not checked, so construction
            returns even if the socket has not opened within ``timeout``.
        """
        if not isinstance(timeout, (int, float)) or timeout < 0:
            raise WebDriverException("timeout must be a positive number")
        # Validate ``interval`` itself (the range check previously re-tested
        # ``timeout``, letting negative intervals through).
        if not isinstance(interval, (int, float)) or interval < 0:
            raise WebDriverException("interval must be a positive number")

        self.url = url
        self.response_wait_timeout = timeout
        self.response_wait_interval = interval

        self.callbacks = {}
        self.session_id = None

        self._id = 0
        self._id_lock = threading.Lock()  # guards allocation of command ids
        self._messages = {}  # responses received so far, keyed by command id
        self._started = False

        self._start_ws()
        self._wait_until(lambda: self._started)

    def close(self):
        """Wait (up to the response timeout) for the socket thread, then close the socket.

        Calling ``close`` again after the socket has been released is a no-op
        apart from re-joining the already finished thread.
        """
        self._ws_thread.join(timeout=self.response_wait_timeout)
        # ``_ws`` is reset to None below, so guard against a repeated close().
        if self._ws is not None:
            self._ws.close()
        self._started = False
        self._ws = None

    def execute(self, command):
        """Send a command and return its deserialized result.

        Args:
            command: A generator. Its first yielded value is the payload dict
                to send; the response's ``result`` is sent back into it, and
                the generator's return value becomes the result of this call.

        Returns:
            The value returned by the ``command`` generator.

        Raises:
            WebDriverException: If no response with the command's id arrives
                in time, if the response contains an ``error`` key, or if the
                generator does not finish after receiving the result.
        """
        with self._id_lock:
            self._id += 1
            current_id = self._id

        payload = self._serialize_command(command)
        payload["id"] = current_id
        if self.session_id:
            payload["sessionId"] = self.session_id

        data = json.dumps(payload, cls=_BiDiEncoder)
        logger.debug(f"-> {data}"[: self._max_log_message_size])
        self._ws.send(data)

        self._wait_until(lambda: current_id in self._messages)
        # Re-check after the wait: the poll loop returns None on timeout.
        if current_id not in self._messages:
            raise WebDriverException(f"Timed out waiting for response to BiDi command {current_id}")

        response = self._messages.pop(current_id)
        if "error" in response:
            error = response["error"]
            if "message" in response:
                raise WebDriverException(f"{error}: {response['message']}")
            raise WebDriverException(error)
        return self._deserialize_result(response["result"], command)

    def add_callback(self, event, callback):
        """Register ``callback`` for ``event`` and return an id for later removal.

        Args:
            event: Object exposing ``event_class`` (the method name to listen
                for) and ``from_json`` (converts raw params before they are
                handed to ``callback``).
            callback: Callable invoked with ``event.from_json(params)``.

        Returns:
            ``id()`` of the internal wrapper; pass it to :meth:`remove_callback`.
        """
        event_name = event.event_class

        def _callback(params):
            callback(event.from_json(params))

        self.callbacks.setdefault(event_name, []).append(_callback)
        return id(_callback)

    # Alias for add_callback.
    on = add_callback

    def remove_callback(self, event, callback_id):
        """Remove the callback registered for ``event`` whose id is ``callback_id``.

        Does nothing if the event has no callbacks or no callback matches.
        """
        event_name = event.event_class
        if event_name in self.callbacks:
            for callback in self.callbacks[event_name]:
                if id(callback) == callback_id:
                    # Safe to mutate here because we return immediately.
                    self.callbacks[event_name].remove(callback)
                    return

    def _serialize_command(self, command):
        """Advance the command generator to obtain the payload to send."""
        return next(command)

    def _deserialize_result(self, result, command):
        """Feed ``result`` into the command generator and return the generator's return value.

        Raises:
            WebDriverException: If the generator yields again instead of returning.
        """
        try:
            command.send(result)
        except StopIteration as stop:
            return stop.value
        raise WebDriverException("The command's generator function did not exit when expected!")

    def _start_ws(self):
        """Create the ``WebSocketApp`` and run it on a daemon thread."""

        def on_open(ws):
            self._started = True

        def on_message(ws, message):
            self._process_message(message)

        def on_error(ws, error):
            logger.debug(f"error: {error}")
            ws.close()

        def run_socket():
            if self.url.startswith("wss://"):
                # Secure sockets are run with certificate verification disabled.
                self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True)
            else:
                self._ws.run_forever(suppress_origin=True)

        self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error)
        self._ws_thread = Thread(target=run_socket, daemon=True)
        self._ws_thread.start()

    def _process_message(self, message):
        """Decode one incoming JSON message and route it.

        Messages with an ``id`` are stored for :meth:`execute` to pick up;
        messages with a ``method`` are passed to each callback registered for
        that method, each on its own daemon thread.
        """
        message = json.loads(message)
        logger.debug(f"<- {message}"[: self._max_log_message_size])

        if "id" in message:
            self._messages[message["id"]] = message

        if "method" in message:
            params = message["params"]
            for callback in self.callbacks.get(message["method"], []):
                Thread(target=callback, args=(params,), daemon=True).start()

    def _wait_until(self, condition):
        """Poll ``condition`` until it returns a truthy value or the time budget runs out.

        Args:
            condition: Zero-argument callable to poll.

        Returns:
            The first truthy value returned by ``condition``, or ``None`` if
            the budget is exhausted first.

        Note:
            The budget is counted down by ``response_wait_interval`` per poll,
            so an interval of 0 never exhausts it.
        """
        remaining = self.response_wait_timeout
        interval = self.response_wait_interval

        while remaining > 0:
            result = condition()
            if result:
                return result
            remaining -= interval
            sleep(interval)
        return None