186 lines
7.0 KiB
Python
186 lines
7.0 KiB
Python
import re
|
|
import typing as t
|
|
from dataclasses import dataclass
|
|
from dataclasses import field
|
|
|
|
from .converters import ValidationError
|
|
from .exceptions import NoMatch
|
|
from .exceptions import RequestAliasRedirect
|
|
from .exceptions import RequestPath
|
|
from .rules import Rule
|
|
from .rules import RulePart
|
|
|
|
|
|
class SlashRequired(Exception):
|
|
pass
|
|
|
|
|
|
@dataclass
|
|
class State:
|
|
"""A representation of a rule state.
|
|
|
|
This includes the *rules* that correspond to the state and the
|
|
possible *static* and *dynamic* transitions to the next state.
|
|
"""
|
|
|
|
dynamic: t.List[t.Tuple[RulePart, "State"]] = field(default_factory=list)
|
|
rules: t.List[Rule] = field(default_factory=list)
|
|
static: t.Dict[str, "State"] = field(default_factory=dict)
|
|
|
|
|
|
class StateMachineMatcher:
|
|
def __init__(self, merge_slashes: bool) -> None:
|
|
self._root = State()
|
|
self.merge_slashes = merge_slashes
|
|
|
|
def add(self, rule: Rule) -> None:
|
|
state = self._root
|
|
for part in rule._parts:
|
|
if part.static:
|
|
state.static.setdefault(part.content, State())
|
|
state = state.static[part.content]
|
|
else:
|
|
for test_part, new_state in state.dynamic:
|
|
if test_part == part:
|
|
state = new_state
|
|
break
|
|
else:
|
|
new_state = State()
|
|
state.dynamic.append((part, new_state))
|
|
state = new_state
|
|
state.rules.append(rule)
|
|
|
|
def update(self) -> None:
|
|
# For every state the dynamic transitions should be sorted by
|
|
# the weight of the transition
|
|
state = self._root
|
|
|
|
def _update_state(state: State) -> None:
|
|
state.dynamic.sort(key=lambda entry: entry[0].weight)
|
|
for new_state in state.static.values():
|
|
_update_state(new_state)
|
|
for _, new_state in state.dynamic:
|
|
_update_state(new_state)
|
|
|
|
_update_state(state)
|
|
|
|
def match(
|
|
self, domain: str, path: str, method: str, websocket: bool
|
|
) -> t.Tuple[Rule, t.MutableMapping[str, t.Any]]:
|
|
# To match to a rule we need to start at the root state and
|
|
# try to follow the transitions until we find a match, or find
|
|
# there is no transition to follow.
|
|
|
|
have_match_for = set()
|
|
websocket_mismatch = False
|
|
|
|
def _match(
|
|
state: State, parts: t.List[str], values: t.List[str]
|
|
) -> t.Optional[t.Tuple[Rule, t.List[str]]]:
|
|
# This function is meant to be called recursively, and will attempt
|
|
# to match the head part to the state's transitions.
|
|
nonlocal have_match_for, websocket_mismatch
|
|
|
|
# The base case is when all parts have been matched via
|
|
# transitions. Hence if there is a rule with methods &
|
|
# websocket that work return it and the dynamic values
|
|
# extracted.
|
|
if parts == []:
|
|
for rule in state.rules:
|
|
if rule.methods is not None and method not in rule.methods:
|
|
have_match_for.update(rule.methods)
|
|
elif rule.websocket != websocket:
|
|
websocket_mismatch = True
|
|
else:
|
|
return rule, values
|
|
|
|
# Test if there is a match with this path with a
|
|
# trailing slash, if so raise an exception to report
|
|
# that matching is possible with an additional slash
|
|
if "" in state.static:
|
|
for rule in state.static[""].rules:
|
|
if websocket == rule.websocket and (
|
|
rule.methods is None or method in rule.methods
|
|
):
|
|
if rule.strict_slashes:
|
|
raise SlashRequired()
|
|
else:
|
|
return rule, values
|
|
return None
|
|
|
|
part = parts[0]
|
|
# To match this part try the static transitions first
|
|
if part in state.static:
|
|
rv = _match(state.static[part], parts[1:], values)
|
|
if rv is not None:
|
|
return rv
|
|
# No match via the static transitions, so try the dynamic
|
|
# ones.
|
|
for test_part, new_state in state.dynamic:
|
|
target = part
|
|
remaining = parts[1:]
|
|
# A final part indicates a transition that always
|
|
# consumes the remaining parts i.e. transitions to a
|
|
# final state.
|
|
if test_part.final:
|
|
target = "/".join(parts)
|
|
remaining = []
|
|
match = re.compile(test_part.content).match(target)
|
|
if match is not None:
|
|
rv = _match(new_state, remaining, values + list(match.groups()))
|
|
if rv is not None:
|
|
return rv
|
|
|
|
# If there is no match and the only part left is a
|
|
# trailing slash ("") consider rules that aren't
|
|
# strict-slashes as these should match if there is a final
|
|
# slash part.
|
|
if parts == [""]:
|
|
for rule in state.rules:
|
|
if rule.strict_slashes:
|
|
continue
|
|
if rule.methods is not None and method not in rule.methods:
|
|
have_match_for.update(rule.methods)
|
|
elif rule.websocket != websocket:
|
|
websocket_mismatch = True
|
|
else:
|
|
return rule, values
|
|
|
|
return None
|
|
|
|
try:
|
|
rv = _match(self._root, [domain, *path.split("/")], [])
|
|
except SlashRequired:
|
|
raise RequestPath(f"{path}/") from None
|
|
|
|
if self.merge_slashes and rv is None:
|
|
# Try to match again, but with slashes merged
|
|
path = re.sub("/{2,}?", "/", path)
|
|
try:
|
|
rv = _match(self._root, [domain, *path.split("/")], [])
|
|
except SlashRequired:
|
|
raise RequestPath(f"{path}/") from None
|
|
if rv is None:
|
|
raise NoMatch(have_match_for, websocket_mismatch)
|
|
else:
|
|
raise RequestPath(f"{path}")
|
|
elif rv is not None:
|
|
rule, values = rv
|
|
|
|
result = {}
|
|
for name, value in zip(rule._converters.keys(), values):
|
|
try:
|
|
value = rule._converters[name].to_python(value)
|
|
except ValidationError:
|
|
raise NoMatch(have_match_for, websocket_mismatch) from None
|
|
result[str(name)] = value
|
|
if rule.defaults:
|
|
result.update(rule.defaults)
|
|
|
|
if rule.alias and rule.map.redirect_defaults:
|
|
raise RequestAliasRedirect(result, rule.endpoint)
|
|
|
|
return rule, result
|
|
|
|
raise NoMatch(have_match_for, websocket_mismatch)
|