from typing import Callable, Union
from episimmer.read_file import ReadVDConfiguration
from episimmer.vulnerability_detection.base import AgentVD, EventVD
from episimmer.world import World
class VD:
    """Driver for all types of Vulnerability Detection (VD) modules.

    Resolves the detection algorithm named in the VD configuration,
    instantiates it with the simulation world and the configured parameters,
    runs it and prints its default output.

    Args:
        vd_config_obj: ReadVDConfiguration object
        world_obj: World object of simulation
    """

    def __init__(self, vd_config_obj: ReadVDConfiguration, world_obj: World):
        self.vd_config_obj: ReadVDConfiguration = vd_config_obj
        self.world_obj: World = world_obj

    def get_class(self, name: str) -> Callable:
        """Return the object named by a dotted path such as 'pkg.module.Class'.

        The first component is imported as a module. Each further component
        is looked up as an attribute of the object resolved so far. If that
        object is a package that has not yet loaded the requested submodule,
        the submodule is imported explicitly, because importing a package
        does not import its submodules.

        Args:
            name: Dotted path given as a string

        Returns:
            The object the path refers to (the class of the vulnerability
            detection module when called from get_algorithm)

        Raises:
            ImportError: If the top-level module cannot be imported, or if a
                submodule exists but one of its own imports fails.
            AttributeError: If a component is neither an attribute nor an
                importable submodule of the object resolved before it.
        """
        import importlib

        components = name.split('.')
        obj = importlib.import_module(components[0])
        for index, comp in enumerate(components[1:], start=1):
            if hasattr(obj, comp):
                obj = getattr(obj, comp)
            elif hasattr(obj, '__path__'):
                # `obj` is a package; `comp` may be a submodule that has not
                # been imported yet, so import it by its full dotted name.
                sub_name = '.'.join(components[:index + 1])
                try:
                    obj = importlib.import_module(sub_name)
                except ModuleNotFoundError as exc:
                    if exc.name != sub_name:
                        # The submodule exists, but a dependency of it is
                        # missing: surface the real cause.
                        raise
                    raise AttributeError(
                        f'module {obj.__name__!r} has no attribute {comp!r}'
                    ) from None
            else:
                # Not a package: let getattr raise the usual AttributeError.
                obj = getattr(obj, comp)
        return obj

    def add_target_name(self) -> str:
        """Return the name of the python module for the configured target.

        The target from the VD configuration is matched case-insensitively.

        Returns:
            'agent_vd' for target 'agent', 'event_vd' for target 'event'

        Raises:
            ValueError: If the target is neither 'agent' nor 'event'.
        """
        target = self.vd_config_obj.target.lower()
        if target == 'agent':
            return 'agent_vd'
        if target == 'event':
            return 'event_vd'
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers keep working.
        raise ValueError('Input valid target')

    def get_algorithm(self) -> Callable:
        """Return the class of the configured vulnerability detection algorithm.

        The class is looked up as
        'vulnerability_detection.<target module>.<algorithm>', where the
        target module comes from add_target_name and the algorithm name
        from the VD configuration (vd_config.txt).

        Returns:
            Class of the vulnerability detection algorithm
        """
        # NOTE(review): this resolves the top-level name
        # `vulnerability_detection`, while this file imports from
        # `episimmer.vulnerability_detection`. It presumably relies on the
        # `episimmer` directory being on sys.path — confirm.
        class_path = '.'.join((
            'vulnerability_detection',
            self.add_target_name(),
            self.vd_config_obj.algorithm,
        ))
        return self.get_class(class_path)

    def run_vul_detection(self) -> None:
        """Run the configured vulnerability detection algorithm and report it.

        Instantiates the algorithm class with the world object and the
        configured parameter dict, calls its run_detection method, then
        passes the object to run_output.
        """
        algorithm_class = self.get_algorithm()
        algo_object = algorithm_class(self.world_obj,
                                      self.vd_config_obj.parameter_dict)
        algo_object.run_detection()
        self.run_output(algo_object)

    def run_output(self,
                   algo_object: Union[AgentVD, EventVD],
                   num_scores: int = 10) -> None:
        """Print the default output of a detection module if so configured.

        The default output is printed when the configured output mode is
        empty or 'default' (matched case-insensitively). By default this
        prints the 10 maximum and 10 minimum scores post detection.

        Args:
            algo_object: Object of the vulnerability detection module
            num_scores: Number of maximum and minimum scores to print;
                passed to print_default_output.
        """
        mode = self.vd_config_obj.output_mode
        # Non-string modes never matched the original string comparison;
        # the isinstance guard keeps them a no-op instead of crashing.
        if isinstance(mode, str) and mode.lower() in ('', 'default'):
            algo_object.print_default_output(num_scores)

    def run_preprocess(self) -> None:
        """Hook for functionality to be run pre detection.

        Does nothing here; note that run_vul_detection does not call it.
        """

    def run_postprocess(self) -> None:
        """Hook for functionality to be run post detection.

        Does nothing here; note that run_vul_detection does not call it.
        """