rex
Main logical entry point for ReX.
Functions
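Function | Description
try_preprocess(args, model_shape, device) | Makes an attempt to preprocess input data as required for the model.
load_and_preprocess_data(model_shape, device, args) | Loads input data from filepath and does preprocessing.
validate_shape(data, model_shape) |
predict_target(data, args, prediction_func) | Predicts the classification of input data using the given prediction function.
calculate_responsibility(data, args, prediction_func, custom_height=None, custom_width=None) | Calculates ResponsibilityMaps for input data using the given args.
analyze(exp, data_mode) | Analyzes an Explanation.
_explanation(args, model_shape, prediction_func, device, db=None, path=None) | Takes a CausalArgs object and model information and returns an Explanation.
get_prediction_func_from_args(args) | Takes a CausalArgs object and gets the prediction function and model shape.
explanation(args, device, db=None) | Takes a CausalArgs object and returns an Explanation.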
Module Contents
- rex.try_preprocess(args, model_shape, device)
Makes an attempt to preprocess input data as required for the model.
Data preprocessing is based on the file extension and (possibly) a user-defined mode. Files with extensions in [".jpg", ".jpeg", ".png", ".tif", ".tiff"] are treated as images, ".npy" and ".mat" files are treated as NumPy arrays, and ".nii" files are treated as NIfTI files. For any other file extension, a Data object is created without preprocessing.
- Parameters:
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
model_shape (Tuple[int]) – shape of the input tensor of the model, as returned by get_prediction_func_from_args()
device (torch.device) – as returned by get_device()
- Returns:
the processed input data
- Return type:
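The extension-based dispatch that try_preprocess() performs can be sketched as follows. This is a minimal illustration, not ReX's code: the helper name classify_input and its string return values are hypothetical, it only reports which route a file would take rather than loading anything, and the user-defined mode is not modelled.

```python
from pathlib import Path

# Extension groups taken from the try_preprocess() description.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
ARRAY_EXTS = {".npy", ".mat"}
NIFTI_EXTS = {".nii"}


def classify_input(path: str) -> str:
    """Return the hypothetical preprocessing route for a file path."""
    ext = Path(path).suffix  # e.g. "scan.nii" -> ".nii"
    if ext in IMAGE_EXTS:
        return "image"
    if ext in ARRAY_EXTS:
        return "numpy"
    if ext in NIFTI_EXTS:
        return "nifti"
    # Any other extension: wrap the data without preprocessing.
    return "raw"


if __name__ == "__main__":
    for name in ("cat.jpg", "arr.npy", "brain.nii", "notes.txt"):
        print(name, "->", classify_input(name))
```

Using sets keeps each membership check constant-time and makes the extension groups easy to compare against the list in the description.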
- rex.load_and_preprocess_data(model_shape, device, args)
Loads input data from filepath and does preprocessing.
Uses a custom preprocessing function if one is defined in args.script.preprocess; otherwise uses try_preprocess().
- Parameters:
model_shape (Tuple[int]) – shape of the input tensor of the model, as returned by get_prediction_func_from_args()
device (torch.device) – as returned by get_device()
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
- Returns:
the processed input data
- Return type:
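A minimal sketch of the fallback logic in load_and_preprocess_data(), assuming args exposes path and script attributes the way a plain object would. The helpers default_preprocess and load_and_preprocess are hypothetical, and the argument order passed to the custom hook is an assumption, not ReX's documented hook signature.

```python
from types import SimpleNamespace


def default_preprocess(model_shape, device, args):
    # Hypothetical stand-in for try_preprocess(); just tags the input path.
    return ("default", args.path)


def load_and_preprocess(model_shape, device, args):
    """Use args.script.preprocess when defined, else fall back to the default."""
    # getattr(None, ..., None) also returns None, so a missing script is safe.
    custom = getattr(getattr(args, "script", None), "preprocess", None)
    if callable(custom):
        # Assumed call signature for the custom hook; ReX's may differ.
        return custom(args.path, model_shape, device)
    return default_preprocess(model_shape, device, args)


if __name__ == "__main__":
    plain = SimpleNamespace(path="cat.png", script=None)
    print(load_and_preprocess((1, 3, 224, 224), "cpu", plain))  # ('default', 'cat.png')
```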
- rex.validate_shape(data, model_shape)
- Parameters:
data (rex_xai.input.input_data.Data)
- Return type:
rex_xai.input.input_data.Data
- rex.predict_target(data, args, prediction_func)
Predicts classification of input data, using given prediction function.
Uses prediction_func to identify the classification of the input data and returns this as the target classification for ReX.
- Parameters:
data (rex_xai.input.input_data.Data) – processed input data object
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
prediction_func – prediction function for the model
- Returns:
the predicted target classification and confidence
- Return type:
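One common way to turn a classifier's output into a target and a confidence is to take the highest-scoring class. The sketch below assumes prediction_func returns a flat sequence of per-class scores for a single input; pick_target is a hypothetical helper, and ReX's prediction functions and returned target object may have a different form.

```python
from typing import Callable, Sequence, Tuple


def pick_target(data, prediction_func: Callable[[object], Sequence[float]]) -> Tuple[int, float]:
    """Return (class index, score) for the highest-scoring class.

    Assumes prediction_func maps one input to a sequence of per-class scores.
    """
    scores = list(prediction_func(data))
    if not scores:
        raise ValueError("prediction_func returned no scores")
    # Index of the maximum score; ties resolve to the lowest index.
    target = max(range(len(scores)), key=scores.__getitem__)
    return target, scores[target]
```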
- rex.calculate_responsibility(data, args, prediction_func, custom_height=None, custom_width=None)
Calculates ResponsibilityMaps for input data using given args.
Runs causal_explanation() for args.iters iterations and returns a ResponsibilityMaps object and a dictionary containing statistics about the calculation process. By default, the ResponsibilityMaps object only includes the responsibility map that matches the classification of the input data. Set keep_all_maps to True to retain all maps.
- Parameters:
data (rex_xai.input.input_data.Data) – processed input data object
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
prediction_func – prediction function for the model
keep_all_maps – whether to retain all ResponsibilityMaps, or just the one that matches the target classification.
- Returns:
tuple containing
ResponsibilityMaps: ResponsibilityMaps for the given data, prediction function, and args.
dict: statistics for the call of this function that generated the ResponsibilityMaps object
- Return type:
tuple[rex_xai.responsibility.resp_maps.ResponsibilityMaps, dict]
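The iterate-and-filter behaviour described above can be sketched with plain lists standing in for responsibility maps. Here run_once is a hypothetical stand-in for one causal_explanation() pass that returns a map per class; summing maps across iterations and the single "iterations" counter in the statistics dictionary are assumptions made for illustration.

```python
from typing import Callable, Dict, List, Tuple

Map = List[float]


def responsibility_sketch(
    run_once: Callable[[int], Dict[int, Map]],
    iters: int,
    target: int,
    keep_all_maps: bool = False,
) -> Tuple[Dict[int, Map], dict]:
    """Run run_once for iters iterations and sum the per-class maps (assumed rule)."""
    maps: Dict[int, Map] = {}
    stats = {"iterations": 0}
    for i in range(iters):
        for cls, new_map in run_once(i).items():
            if cls in maps:
                # Element-wise accumulation of this iteration's map.
                maps[cls] = [a + b for a, b in zip(maps[cls], new_map)]
            else:
                maps[cls] = list(new_map)
        stats["iterations"] += 1
    if not keep_all_maps:
        # By default, keep only the map for the target classification.
        maps = {cls: m for cls, m in maps.items() if cls == target}
    return maps, stats
```

Filtering after the loop, rather than discarding non-target maps on each pass, keeps the accumulation logic identical for both settings of keep_all_maps.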
- rex.analyze(exp, data_mode)
Analyzes an Explanation.
Computes the area ratio, the entropy difference, and the insertion and deletion curves for an Explanation object, prints them, and returns them.
- Parameters:
- Returns:
tuple containing
area (float)
entropy (float)
insertion_curve (float)
deletion_curve (float)
- Return type:
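Two of these measures can be sketched in a few lines under stated assumptions: the area ratio is taken as the fraction of input elements inside the explanation mask, and each insertion or deletion curve is summarised as one float by the trapezoidal area under confidence scores sampled at evenly spaced fractions of the input. Both definitions, and the helper names area_ratio and curve_auc, are assumptions rather than ReX's documented formulas; the entropy difference is not sketched.

```python
from typing import Sequence


def area_ratio(mask: Sequence[bool]) -> float:
    """Fraction of input elements included in the explanation mask (assumed definition)."""
    if not mask:
        raise ValueError("empty mask")
    return sum(1 for m in mask if m) / len(mask)


def curve_auc(scores: Sequence[float]) -> float:
    """Trapezoidal area under a curve sampled at evenly spaced points in [0, 1].

    For an insertion curve, scores[k] would be the model's confidence after
    revealing the k-th fraction of elements; for a deletion curve, after
    removing them.
    """
    if len(scores) < 2:
        raise ValueError("need at least two points")
    step = 1.0 / (len(scores) - 1)
    return sum((a + b) / 2 * step for a, b in zip(scores, scores[1:]))
```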
- rex._explanation(args, model_shape, prediction_func, device, db=None, path=None)
Takes a CausalArgs object and model information and returns an Explanation.
Takes a CausalArgs object, model shape and prediction function and returns an Explanation. Depending on the input args, optionally produces output plots, analyzes the output explanation, and/or writes results to a database.
- Parameters:
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
model_shape (Tuple[int]) – shape of the input tensor of the model, as returned by get_prediction_func_from_args()
prediction_func (Callable) – as returned by get_prediction_func_from_args()
device (torch.device) – as returned by get_device()
db (sqlalchemy.orm.Session | None) – None or as returned by initialise_rex_db()
path (str | None)
- Returns:
An Explanation object containing the causal responsibility explanation calculated using the given args.
- Return type:
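The optional post-processing in _explanation() amounts to a few conditional steps after the explanation is computed. In this sketch the function run_pipeline, the flags plot and analyse, and the use of a plain list in place of a SQLAlchemy session for db are all hypothetical; which CausalArgs fields actually switch these steps on is not specified in this reference.

```python
from typing import Callable, List, Optional


def run_pipeline(
    explain: Callable[[], dict],
    plot: bool = False,
    analyse: bool = False,
    db: Optional[list] = None,
) -> dict:
    """Compute an explanation, then record which optional steps would run."""
    exp = explain()
    steps: List[str] = []
    if plot:
        steps.append("plot")  # placeholder for rendering output plots
    if analyse:
        steps.append("analyse")  # placeholder for calling analyze()
    if db is not None:
        db.append(exp)  # stand-in for writing results through a DB session
        steps.append("db")
    exp["steps"] = steps
    return exp
```

Passing db=None to skip persistence mirrors the documented default of the real db parameter.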
- rex.get_prediction_func_from_args(args)
Takes a CausalArgs object and gets the prediction function and model shape.
If args.script specifies a prediction function and model shape, returns these. Otherwise, gets the prediction function and model shape from the provided model file.
- Parameters:
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
- Returns:
tuple containing
prediction_func
model_shape
- Raises:
RuntimeError – if an ONNX inference instance cannot be created from the provided model file.
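The precedence rule above (script-supplied values first, model file second) can be sketched like this. The function resolve_prediction_func, the attribute names script.prediction_function, script.model_shape and args.model, and the load_model callback that returns None on failure are hypothetical placeholders rather than ReX's actual names.

```python
from types import SimpleNamespace
from typing import Callable, Optional, Tuple

Loaded = Tuple[Callable, tuple]


def resolve_prediction_func(args, load_model: Callable[[str], Optional[Loaded]]) -> Loaded:
    """Return (prediction_func, model_shape), preferring values from args.script."""
    script = getattr(args, "script", None)
    # Hypothetical attribute names; getattr on None safely yields None.
    func = getattr(script, "prediction_function", None)
    shape = getattr(script, "model_shape", None)
    if func is not None and shape is not None:
        return func, shape
    # Fall back to the model file; load_model returns None on failure.
    loaded = load_model(args.model)
    if loaded is None:
        raise RuntimeError(f"cannot create an inference session for {args.model}")
    return loaded


if __name__ == "__main__":
    args = SimpleNamespace(model="model.onnx", script=None)
    func, shape = resolve_prediction_func(args, lambda path: (len, (1, 3, 224, 224)))
    print(shape)  # (1, 3, 224, 224)
```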
- rex.explanation(args, device, db=None)
Takes a CausalArgs object and returns an Explanation.
Takes a CausalArgs object and returns either an Explanation, or a list of Explanations if the input args.path is a directory rather than a path to a single file.
- Parameters:
args (rex_xai.input.config.CausalArgs) – configuration values for ReX
device (torch.device) – as returned by get_device()
db (sqlalchemy.orm.Session | None) – None or as returned by initialise_rex_db()
- Returns:
An Explanation object (or a list of Explanation objects if args.path is a directory) containing the causal responsibility explanation calculated using the given args.
- Return type:
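The file-versus-directory behaviour of explanation() can be sketched with the standard library. Here explain_path and explain_file are hypothetical, with explain_file standing in for the per-file work; the sketch assumes only the directory's top-level files are processed, in sorted order, which may not match ReX.

```python
import os
import tempfile
from typing import Callable, List, TypeVar, Union

T = TypeVar("T")


def explain_path(path: str, explain_file: Callable[[str], T]) -> Union[T, List[T]]:
    """Explain a single file, or each top-level file in a directory."""
    if os.path.isdir(path):
        # Assumed behaviour: sorted, non-recursive, files only.
        return [
            explain_file(os.path.join(path, name))
            for name in sorted(os.listdir(path))
            if os.path.isfile(os.path.join(path, name))
        ]
    return explain_file(path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("b.png", "a.png"):
            open(os.path.join(tmp, name), "w").close()
        print(explain_path(tmp, os.path.basename))  # ['a.png', 'b.png']
```

Returning a bare result for a single file and a list for a directory matches the documented return shape, at the cost of callers having to check which one they received.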