"""Determine the best possible F1 binarization threshold on the toy dataset.

Shared by David Zimmerer. Predictions and labels are volumes stored under
identical file names in two directories. The anomaly map is thresholded and
split into connected components. A component counts as a hit when its
(truncated) center of mass lies inside a ground-truth object whose voxel
count is roughly within a factor of two of the component's. The object-level
F1 is pooled over all volumes, and the threshold maximizing it is found by a
coarse grid search followed by interval-halving refinement.
"""
import os

import numpy as np
from scipy.ndimage import center_of_mass
from scipy.ndimage import label as find_label


def get_best_thres(
    toy_pred_dir=None,
    toy_label_dir=None,
):
    """Return the threshold that maximizes pooled object-level F1.

    Args:
        toy_pred_dir: directory of prediction volumes (loaded with nibabel).
        toy_label_dir: directory of label volumes; each label file must have
            the same file name as its prediction.

    Returns:
        The selected threshold. If the predictions take at most two distinct
        values, this is the mean of those values. Otherwise it is the best
        threshold found by a 20-point grid search, refined by ``find_best_val``.

    Raises:
        ValueError: if ``toy_pred_dir`` contains no files.
    """
    # Imported lazily so the metric helpers below can be used without the
    # NIfTI I/O and progress-bar dependencies installed.
    import nibabel as nib
    from tqdm import tqdm

    # Sorted for a deterministic evaluation order.
    file_names = sorted(os.listdir(toy_pred_dir))
    if not file_names:
        raise ValueError(f"No prediction files found in {toy_pred_dir!r}")

    toy_pred_list = []
    toy_label_list = []
    for fl_ in file_names:
        pred_file = os.path.join(toy_pred_dir, fl_)
        label_file = os.path.join(toy_label_dir, fl_)
        toy_pred_list.append(nib.load(pred_file).get_fdata(dtype=float))
        toy_label_list.append(nib.load(label_file).get_fdata(dtype=float))

    # Global statistics computed per volume, so volumes may differ in shape
    # and no stacked copy of the whole dataset is built.
    pred_min = min(pred.min() for pred in toy_pred_list)
    pred_max = max(pred.max() for pred in toy_pred_list)
    unique_preds = np.unique(np.concatenate([np.unique(pred) for pred in toy_pred_list]))

    if len(unique_preds) <= 2:
        # At most two distinct prediction values: threshold at their midpoint.
        return np.mean(unique_preds)

    n_grid = 20
    grid, grid_step = np.linspace(pred_min, pred_max, n_grid, retstep=True)

    # Coarse search. best_f1 starts below any attainable score, so best_thres
    # is always one of the evaluated thresholds, even if every F1 is 0.
    best_f1 = -1.0
    best_thres = grid[0]
    for bin_thres in tqdm(grid):
        f1s = get_f1_score_clean_list(toy_pred_list, toy_label_list, bin_thres)
        if f1s > best_f1:
            best_f1 = f1s
            best_thres = bin_thres
        print(bin_thres, f1s)

    # Refine between the neighbouring grid points. Seeding the search with the
    # coarse optimum guarantees the refinement never returns a worse threshold.
    _, reconst_thres = find_best_val(
        toy_pred_list,
        toy_label_list,
        get_f1_score_clean_list,
        max_steps=4,
        val_range=(best_thres - grid_step, best_thres + grid_step),
        max_val=best_f1,
        max_point=best_thres,
    )
    return reconst_thres


def get_f1_score_clean_list(anomaly_map_list, seg_objects_list, bin_thres, size_thres=600):
    """Object-level F1 pooled (micro-averaged) over several volumes.

    The TP/FP/FN counts from ``get_f1_score`` are summed over all
    (prediction, label) pairs before a single F1 is computed.

    Args:
        anomaly_map_list: prediction arrays.
        seg_objects_list: ground-truth arrays, paired with the predictions by
            position.
        bin_thres: voxels with ``anomaly_map > bin_thres`` are foreground.
        size_thres: predicted components with fewer voxels are ignored.

    Returns:
        The pooled F1, or 0 when there are neither objects nor predictions.

    Raises:
        ValueError: if the two inputs contain different numbers of volumes.
    """
    anomaly_map_list = list(anomaly_map_list)
    seg_objects_list = list(seg_objects_list)
    if len(anomaly_map_list) != len(seg_objects_list):
        # zip() would silently drop the unpaired volumes.
        raise ValueError(
            f"Got {len(anomaly_map_list)} predictions but {len(seg_objects_list)} labels"
        )

    tps = fps = fns = 0
    for pred, label in zip(anomaly_map_list, seg_objects_list):
        _, tp, fp, fn = get_f1_score(pred, label, bin_thres, size_thres=size_thres)
        tps += tp
        fps += fp
        fns += fn

    if tps + fps + fns == 0:
        return 0
    return 2 * tps / (2 * tps + fps + fns)


def get_f1_score(anomaly_map, seg_objects, bin_thres, size_thres):
    """Object-level F1 and counts for a single volume.

    Both the thresholded prediction and ``seg_objects`` are split into
    connected components using full connectivity (every neighbour, diagonals
    included). Each predicted component with at least ``size_thres`` voxels
    is then classified as follows:

    * its center of mass, truncated to integer indices, falls outside every
      ground-truth object -> false positive;
    * it falls inside an object whose voxel count ``gt_sum`` satisfies
      ``gt_sum // 2 < pred_sum < 2 * gt_sum`` -> that object is matched;
    * otherwise (size mismatch) -> false positive.

    Each matched object counts as one true positive, however many components
    hit it. Unmatched objects are false negatives.

    Args:
        anomaly_map: prediction array.
        seg_objects: ground-truth array (non-zero voxels are object voxels),
            same shape as ``anomaly_map``.
        bin_thres: voxels with ``anomaly_map > bin_thres`` are foreground.
        size_thres: minimum component size (in voxels) to be considered.

    Returns:
        Tuple ``(f1, tp, fp, fn)``; ``f1`` is 0 when ``tp + fp + fn == 0``.
    """
    pred_thres = anomaly_map > bin_thres
    pred_labeled, n_labels = find_label(pred_thres, np.ones((3,) * pred_thres.ndim))
    seg_labeled, seg_labels = find_label(seg_objects, np.ones((3,) * np.ndim(seg_objects)))

    # Voxel count of every component, computed once per volume.
    pred_counts = np.bincount(pred_labeled.ravel())
    seg_counts = np.bincount(seg_labeled.ravel())

    kept = [idx for idx in range(1, n_labels + 1) if pred_counts[idx] >= size_thres]
    # One labelled pass yields the center of mass of every kept component.
    centers = center_of_mass(pred_thres, pred_labeled, kept) if kept else []

    matched = set()
    fp = 0
    for lbl_idx, com in zip(kept, centers):
        # Truncation (not rounding) of the center of mass is intentional.
        gt_id = seg_labeled[tuple(int(c) for c in com)]
        if gt_id == 0:
            fp += 1
            continue
        gt_sum = seg_counts[gt_id]
        pred_sum = pred_counts[lbl_idx]
        if gt_sum // 2 < pred_sum < gt_sum * 2:
            matched.add(gt_id)
        else:
            fp += 1

    tp = len(matched)
    fn = seg_labels - tp

    if tp + fp + fn != 0:
        f1_score = 2 * tp / (2 * tp + fp + fn)
    else:
        f1_score = 0
    return f1_score, tp, fp, fn


def find_best_val(x, y, val_fn, val_range=(0, 1), max_steps=4, step=0, max_val=0, max_point=0):
    """Maximize ``val_fn(x, y, t)`` over ``t`` by repeated interval halving.

    At each step the objective is evaluated at the 25% and 75% points of the
    current interval. The search recurses into the half containing the better
    point (the upper half on ties), and the best value seen so far is tracked.

    Args:
        x, y: passed through unchanged to ``val_fn``.
        val_fn: objective, called as ``val_fn(x, y, t)``.
        val_range: initial ``(bottom, top)`` interval. A degenerate interval
            ``(a, a)`` is widened to ``(a, 1)``.
        max_steps: number of halving steps to perform.
        step: current step (used by the recursion).
        max_val, max_point: best value and location known so far. Seed these
            with a prior optimum so that it is only replaced by a strictly
            better point.

    Returns:
        Tuple ``(max_val, max_point)``.
    """
    print(step, max_val, max_point)
    # ">=" so an out-of-range starting step cannot recurse forever.
    if step >= max_steps:
        return max_val, max_point

    if val_range[0] == val_range[1]:
        val_range = (val_range[0], 1)
    bottom, top = val_range
    center = bottom + (top - bottom) * 0.5
    q_bottom = bottom + (top - bottom) * 0.25
    q_top = bottom + (top - bottom) * 0.75

    val_bottom = val_fn(x, y, q_bottom)
    val_top = val_fn(x, y, q_top)

    if val_bottom > val_top:
        cand_val, cand_point, next_range = val_bottom, q_bottom, (bottom, center)
    else:
        # Record the upper quartile point that was actually evaluated as best.
        cand_val, cand_point, next_range = val_top, q_top, (center, top)

    if cand_val > max_val:
        max_val, max_point = cand_val, cand_point

    return find_best_val(
        x,
        y,
        val_fn,
        val_range=next_range,
        step=step + 1,
        max_steps=max_steps,
        max_val=max_val,
        max_point=max_point,
    )

Created by David Zimmerer (d.zimmerer)

Code to determine the best threshold