Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pytm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
]

import sys
from types import ModuleType
from typing import Any, Dict


from .json import load, loads
from .pytm import (
Expand Down Expand Up @@ -51,16 +54,16 @@
)


def pdoc_overrides():
result = {"pytm": False, "json": False, "template_engine": False}
def pdoc_overrides() -> Dict[str, Any]:
result: Dict[str, Any] = {"pytm": False, "json": False, "template_engine": False}
mod = sys.modules[__name__]
for name, klass in mod.__dict__.items():
if not isinstance(klass, type):
continue
for i in dir(klass):
if i in ("check", "dfd", "seq"):
result[f"{name}.{i}"] = False
attr = getattr(klass, i, {})
attr: Any = getattr(klass, i, {})
if isinstance(attr, var) and attr.doc != "":
result[f"{name}.{i}"] = attr.doc
return result
Expand Down
24 changes: 13 additions & 11 deletions pytm/json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import sys
from typing import Any, TextIO, Dict, Union, List


from .pytm import (
TM,
Expand All @@ -18,23 +20,23 @@
)


def loads(s):
def loads(s: str) -> "TM":
"""Load a TM object from a JSON string *s*."""
result = json.loads(s, object_hook=decode)
result: Any = json.loads(s, object_hook=decode)
if not isinstance(result, TM):
raise ValueError("Failed to decode JSON input as TM")
return result


def load(fp):
def load(fp: TextIO) -> "TM":
"""Load a TM object from an open file containing JSON."""
result = json.load(fp, object_hook=decode)
result: Any = json.load(fp, object_hook=decode)
if not isinstance(result, TM):
raise ValueError("Failed to decode JSON input as TM")
return result


def decode(data):
def decode(data: Dict[str, Any]) -> Union[Dict[str, Any], TM]:
if "elements" not in data and "flows" not in data and "boundaries" not in data:
return data

Expand All @@ -49,9 +51,9 @@ def decode(data):
return TM(data.pop("name"), **data)


def decode_boundaries(flat):
boundaries = {}
refs = {}
def decode_boundaries(flat: List[Dict[str, Any]]) -> Dict[str, Boundary]:
boundaries: Dict[str, Boundary] = {}
refs: Dict[str, str] = {}
for i, e in enumerate(flat):
name = e.pop("name", None)
if name is None:
Expand All @@ -70,8 +72,8 @@ def decode_boundaries(flat):
return boundaries


def decode_elements(flat, boundaries):
elements = {}
def decode_elements(flat: List[Dict[str, Any]], boundaries: Dict[str, Boundary]) -> Dict[str, Any]:
elements: Dict[str, Any] = {}
for i, e in enumerate(flat):
klass = getattr(sys.modules[__name__], e.pop("__class__", "Asset"))
name = e.pop("name", None)
Expand All @@ -89,7 +91,7 @@ def decode_elements(flat, boundaries):
return elements


def decode_flows(flat, elements):
def decode_flows(flat: List[Dict[str, Any]], elements: Dict[str, Any]) -> None:
for i, e in enumerate(flat):
name = e.pop("name", None)
if name is None:
Expand Down
9 changes: 5 additions & 4 deletions pytm/report_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, List, Union

class ReportUtils:
@staticmethod
def getParentName(element):
def getParentName(element: Any) -> str:
from pytm import Boundary
if (isinstance(element, Boundary)):
parent = element.inBoundary
Expand All @@ -14,7 +15,7 @@ def getParentName(element):


@staticmethod
def getNamesOfParents(element):
def getNamesOfParents(element: Any) -> Union[List[str], str]:
from pytm import Boundary
if (isinstance(element, Boundary)):
parents = [p.name for p in element.parents()]
Expand All @@ -23,15 +24,15 @@ def getNamesOfParents(element):
return "ERROR: getNamesOfParents method is not valid for " + element.__class__.__name__

@staticmethod
def getFindingCount(element):
def getFindingCount(element: Any) -> str:
from pytm import Element
if (isinstance(element, Element)):
return str(len(list(element.findings)))
else:
return "ERROR: getFindingCount method is not valid for " + element.__class__.__name__

@staticmethod
def getElementType(element):
def getElementType(element: Any) -> str:
from pytm import Element
if (isinstance(element, Element)):
return str(element.__class__.__name__)
Expand Down