By Kirindi Choi, Ljubomir Buturovic, Roland Luethy, Inflammatix, Inc.
Introduction
Recently, generative artificial intelligence (AI) models for text, images, and video have made major progress and attracted worldwide attention among experts and the public, including initial applications in medicine [1]. The application of generative AI to problems in genomics (the study of genes and their functions) has, understandably, been less visible, but it nevertheless has important applications. In this blog, we assess several open-source and commercial tools that can be used to generate high-quality genomic data and discuss potential applications. We focus on transcriptomic applications (a subfield of genomics), with tabular data representing gene expression (measurements of the abundance of gene products in cells). Transcriptomics has significant applications in bioinformatics research and, increasingly, in clinical care as a new class of diagnostics and prognostics [2-5].
One of the main use cases for synthetic genomic and transcriptomic data is sharing data while preserving privacy. For example, an organization may wish to run a Kaggle competition for its transcriptomic problem using synthetic data derived from patient data, thereby preserving patient privacy; or an organization may send synthetic data to a software vendor to report and reproduce a bug, again without sharing sensitive data.
Another potential use case is to improve classification accuracy by adding synthetic tabular data to the training set for Machine Learning (ML) models. However, in the available literature, we have not seen convincing evidence of this approach being successful. Thus, we believe that this use case remains hypothetical.
In this blog, we compare the quality of the transcriptomic synthetic data created using different generative AI tools.
Methods
We selected and compared two open-source and two commercial synthetic data generators (all available through Python APIs).
We selected the following open-source Python tools:
CTGAN (conditional tabular generative adversarial network) [6] from SDV (Synthetic Data Vault) [7] and
Gaussian Copula (Gaussian Multivariate) [8] in SDV Copulas library.
We selected the following commercial solutions:
LSTM and ACTGAN (an alternate implementation of CTGAN) cloud-based APIs from Gretel.ai ([9], [12], [13], [14]).
To evaluate the quality of the synthesized data, we used two different metrics:
SDMetrics (synthetic data metrics) [10] from SDV.
Cross-validation AUROC (Area Under the Receiver Operating Characteristic curve) of a binary logistic regression classifier trained to discriminate real data from the corresponding synthetic data. The idea is that high-quality synthetic data should be difficult to distinguish from real data and should therefore yield an AUROC of approximately 0.5. This quality metric has a low false-negative rate: data that fail it (AUROC well above 0.5) are unlikely to be high quality. However, data with an AUROC close to 0.5 may still not be high quality.
Per [11], we also applied duplicate detection steps after the data were synthesized: we detected and discarded any duplicates within the synthesized dataset, as well as any synthesized samples that were replicates of real samples.
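A minimal sketch of this duplicate check using pandas is shown below. The file names and dataframe layout are hypothetical, and the check only catches exact matches; the production pipeline may differ.

```python
import pandas as pd

# Hypothetical files; rows are samples, columns are the 29 gene-expression features.
real = pd.read_csv("real_expression.csv")
synthetic = pd.read_csv("synthetic_expression.csv")

# Step 1: discard exact duplicates within the synthetic dataset.
synthetic = synthetic.drop_duplicates()

# Step 2: discard synthetic rows that exactly replicate a real sample.
# A left merge with an indicator column marks rows present in both frames.
merged = synthetic.merge(real.drop_duplicates(), how="left",
                         on=list(real.columns), indicator=True)
synthetic = merged.loc[merged["_merge"] == "left_only", list(real.columns)]

print(f"{len(synthetic)} synthetic samples remain after duplicate filtering")
```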
For the data synthesis, we used default values for the parameter settings of the data synthesizers except batch size and number of epochs (Table 1).
We used a real dataset of 9,654 patient samples, with the expression levels of 29 genes for each sample [5]. We synthesized 6 sets of data: 4 datasets using the SDV tools and 2 datasets using the Gretel tools. Each synthetic dataset had 1,000 samples. The SDV datasets were created on the same EC2 instance with four vCPUs, whereas Gretel.ai is a cloud-based service. Since ACTGAN is a variant of CTGAN, we also ran CTGAN with hyperparameter settings matching ACTGAN's defaults, which use a larger network than CTGAN's defaults.
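For reference, a minimal sketch of how the open-source synthesizers can be driven from Python is shown below, using the ctgan and copulas packages that underlie SDV; the exact SDV wrapper API differs between versions, and the file name, epoch count, and batch size are illustrative.

```python
import pandas as pd
from ctgan import CTGAN                            # SDV's CTGAN implementation
from copulas.multivariate import GaussianMultivariate

# Hypothetical input: 9,654 samples x 29 numeric gene-expression features.
real = pd.read_csv("real_expression.csv")

# CTGAN with illustrative settings (epochs and batch size similar to one of our runs).
ctgan = CTGAN(epochs=50, batch_size=100, verbose=False)
ctgan.fit(real)                                    # all columns continuous, no discrete_columns
ctgan_synth = ctgan.sample(1000)

# Gaussian Copula (GaussianMultivariate) from the Copulas library.
copula = GaussianMultivariate()
copula.fit(real)
copula_synth = copula.sample(1000)
```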
We computed the AUROC quality metric as follows. For each set of synthesized data, we combined the synthesized data and the real data into a training set (10,654 samples in total), in which synthetic samples were labeled as the positive class and real samples as the negative class. We then estimated the cross-validation AUROC on this training set using the scikit-learn logistic regression model for binary classification with an Optuna [15] hyperparameter search. The reported AUROC corresponds to the highest AUROC found by the search. Ideally, a classifier should not be able to distinguish synthetic data from real data, so the cross-validation AUROC should be as close as possible to 0.5.
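A simplified version of this discriminability metric is sketched below; the hyperparameter range and number of trials are illustrative, not the exact settings we used.

```python
import numpy as np
import optuna
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def discriminability_auroc(real, synthetic, n_trials=50):
    """Cross-validation AUROC of a logistic regression separating synthetic (1) from real (0)."""
    X = np.vstack([synthetic, real])
    y = np.concatenate([np.ones(len(synthetic)), np.zeros(len(real))])

    def objective(trial):
        clf = make_pipeline(
            StandardScaler(),
            LogisticRegression(C=trial.suggest_float("C", 1e-3, 1e3, log=True),
                               max_iter=1000),
        )
        return cross_val_score(clf, X, y, cv=5, scoring="roc_auc").mean()

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=n_trials)
    return study.best_value       # ~0.5 means the synthetic data is hard to tell apart

# Example usage with hypothetical arrays of shape [n_samples, 29]:
# auroc = discriminability_auroc(real.values, ctgan_synth.values)
```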
The second metric was computed using the SDMetrics package, which evaluates the marginal distributions of individual columns and the pairwise trends between columns. Its overall quality score is an average of the individual metric scores (KSComplement, TVComplement, CorrelationSimilarity, and CategoryCoverage) and ranges from 0 to 1, with 1 indicating the best quality.
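The corresponding SDMetrics call might look like the sketch below, continuing with the hypothetical real and ctgan_synth dataframes from the synthesis sketch above; the metadata format shown is for recent SDMetrics releases and older versions use a slightly different schema.

```python
from sdmetrics.reports.single_table import QualityReport

# Minimal metadata: every gene-expression column is numerical.
metadata = {"columns": {col: {"sdtype": "numerical"} for col in real.columns}}

report = QualityReport()
report.generate(real_data=real, synthetic_data=ctgan_synth, metadata=metadata)

print(report.get_score())                      # overall quality score in [0, 1]
print(report.get_details("Column Shapes"))     # per-column KSComplement / TVComplement
print(report.get_details("Column Pair Trends"))
```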
Results
We performed the duplication check on 10,000 synthetic samples and found no duplicates within the synthetic data and no copies of real data. Thus, this QC (quality control) step may be redundant.
Besides fine-grained and overall scores, SDMetrics offers convenient comparison visualizations, including per-column density plots of the synthesized data versus the real data.
Figure 1: Density plot for a feature synthesized with Gretel LSTM (#6 below).
Figure 2: Density plot for the same feature as in Fig. 1 synthesized with CTGAN (#3 below).
All synthetic datasets except Gretel LSTM were easily distinguishable from the real data, with AUROCs >= 0.87. For example, for CTGAN with 50 epochs and a batch size of 100, a linear classifier could distinguish real from synthetic data with high accuracy (AUROC = 0.963). Overall, Gretel LSTM performed best at generating data that mimics the real data, with an AUROC of 0.623 and the highest SDMetrics overall quality score of 0.949. A distant second was SDV's CTGAN with 100 epochs and larger network dimensions, which achieved an SDMetrics quality score of 0.927 but was still easily distinguishable from real data (AUROC = 0.87). Notably, of the two synthetic datasets generated using Gretel, the LSTM tool was substantially better than the ACTGAN tool.
The AUROC metric proved to be very useful. It was significantly more intuitive and familiar to users than the SDMetrics overall quality score, yet the rankings of the methods obtained with the two metrics were virtually identical.
Table 1: Parameters and performance metrics of synthetic data tools.
We observed that turning on verbose mode for CTGAN added an excessive amount of time to the data synthesis process, so we kept verbose mode off for the runs shown above.
As expected, we also observed that the three CTGAN runs with different hyperparameter values produced different performance results. This suggests that hyperparameter optimization of the data synthesizer can be beneficial for a given set of real data.
In the future, we plan to add the Transformer-based NeMo tabular data generator from NVIDIA [11] to our evaluations.
Conclusion
We found that the LSTM synthetic data generator from Gretel.ai was the best of the six solutions we compared, by a wide margin of about 0.25 AUROC points over the next-best tool. Interestingly, it is based on an LSTM, which to our knowledge has not been widely used for generating synthetic non-temporal tabular data.
Our findings are based on only one internal transcriptomic dataset with numeric features and may not generalize to other data. Nevertheless, we consider it an important data point because Gretel LSTM was substantially better than every other tool, suggesting that it may be inherently superior. We also found the AUROC to be an especially useful and intuitive quality metric for benchmarking the performance of these techniques.
References
Shah NH, Entwistle D, Pfeffer MA. Creation and Adoption of Large Language Models in Medicine. JAMA. 2023 Aug 7.
Sparano JA, Gray RJ, Makower DF, Pritchard KI, Albain KS, Hayes DF, Geyer Jr CE, Dees EC, Goetz MP, Olson Jr JA, Lively T. Adjuvant chemotherapy guided by a 21-gene expression assay in breast cancer. New England Journal of Medicine. 2018 Jul 12;379(2):111-21.
Alexander EK, Kennedy GC, Baloch ZW, Cibas ES, Chudova D, Diggans J, Friedman L, Kloos RT, LiVolsi VA, Mandel SJ, Raab SS. Preoperative diagnosis of benign thyroid nodules with indeterminate cytology. New England Journal of Medicine. 2012 Aug 23;367(8):705-15.
Pham MX, Teuteberg JJ, Kfoury AG, Starling RC, Deng MC, Cappola TP, Kao A, Anderson AS, Cotts WG, Ewald GA, Baran DA. Gene-expression profiling for rejection surveillance after cardiac transplantation. New England Journal of Medicine. 2010 May 20;362(20):1890-900.
Brakenridge SC, Chen UI, Loftus T, Ungaro R, Dirain M, Kerr A, Zhong L, Bacher R, Starostik P, Ghita G, Midic U. Evaluation of a multivalent transcriptomic metric for diagnosing surgical sepsis and estimating mortality among critically ill patients. JAMA Network Open. 2022 Jul 1;5(7):e2221520-.
By Roland Luethy and Ljubomir Buturovic, Inflammatix, Inc.
Introduction
Typically, a clinical classifier generates a score that corresponds to the likelihood of disease presence or of a future outcome. To facilitate decision-making, the score is sometimes converted to a discrete classification label using decision thresholds [1]. For binary classification, there is a single threshold, which can be chosen based on the trade-off between sensitivity and specificity using a receiver operating characteristic (ROC) curve or a similar tool. However, it is often desirable to partition the range of output scores into multiple bands corresponding to different likelihoods of the disease or outcome, which in turn requires multiple thresholds that cannot be determined by inspection of ROC curves. To our knowledge, no effective solution to this problem has been described.
Here, we developed a genetic optimization algorithm for determining decision thresholds for multiple output bands, called Genetic Algorithm Thresholds (GAT); the term "genetic" refers to the optimization method, not the genome. We applied this method to a three-class classifier that diagnoses the presence and type of infection in patients suspected of acute infection and/or sepsis. The classifier uses the gene expression profile of the patient's immune response as input features and produces scores for a patient sample indicating the probability of bacterial infection, the probability of viral infection, and the probability of no infection.
Methods
To improve interpretability and guide treatment decisions, each probability (score) range [0, 1] is partitioned into five disjoint likelihood bands (decision intervals; Fig. 1). Thus, in our application, each of the three probabilities of disease (bacterial, viral, and no infection) is divided into five bands. For example, if a patient's scores fall in the very likely band for viral infection (very high probability of viral infection) and the very unlikely band for bacterial infection (very low probability of bacterial infection), treatment with antibiotics may not be beneficial.
Figure 1: Partitioning of a classifier probability in five decision bands. Each of the three class probabilities computed by the classifier is partitioned into five such bands. The thresholds are specific to the probabilities and are computed independently for each class.
The decision thresholds, which define the bands, should be chosen using clinically meaningful criteria. For example, we could specify that both the confidence in and the number of patients assigned to the "extreme" bands (the lowest and highest probability bands) should be as high as possible, because those are the most clinically actionable bands. We represent the stringency of each band using diagnostic likelihood ratios (LRs) [2]. For example, for the bacterial and viral scores, clinical considerations (obtained through input from the clinician community) suggest that the LR for the lowest band should be at most 0.075 and the LR for the highest band at least 7.5. Furthermore, for the test to demonstrate utility at the population level, a meaningful percentage of patients should fall in the extreme bands and few patients should fall in the non-informative middle (indeterminate) band (e.g., at least 50% of patients in the extreme bands and at most 10% in the middle band). To balance these requirements and find thresholds that meet them, we developed a tool based on a genetic algorithm with a cost function encapsulating the desired criteria.
Dataset Overview
The training set for the classifier consists of 29 genes (input features) profiled in 3159 patients from 42 clinical studies, assayed on gene expression microarrays. The validation set comprises 741 samples from 9 independent clinical studies, with the same 29 input features measured on the NanoString nCounter® platform [3]. To ensure consistency and accuracy, both the training and validation sets were normalized using samples from healthy patients, with the NanoString platform serving as a reference. The classifier is an advanced version of a previously published one [4].
Algorithm Overview
We implemented the Genetic Algorithm Thresholds (GAT) algorithm, shown below, in Python using the DEAP library [5, 6]. We apply steps 1 through 4 independently to each output class (i.e., to the bacterial, viral, and non-infected probabilities) to optimize the corresponding decision thresholds; a simplified code sketch follows the steps. At the completion of the analysis, 12 thresholds are generated (4 for each class).
1. The initial population for the evolutionary (genetic) algorithm is randomly generated: a set of "chromosomes," each representing a candidate solution, is created (the term "chromosome" refers to a vector representing a set of thresholds, not to a DNA molecule in the cell nucleus). Each chromosome has four values, representing the 4 thresholds needed to split the probabilities into 5 bands.
2. The fitness of each chromosome in the population is evaluated using a fitness function. The function assigns a fitness score to each chromosome based on how well it meets the desired criteria for LR1, LR5, the percentage of patients in bands 1 and 5 combined (coverage), and the percentage of patients in band 3. Coverage is only considered if it is below the target value, so that coverage exceeding the target is not penalized.
3. A new generation is created by selecting parents according to their fitness. Offspring are created using crossover and mutation operations. The individuals with the top 20% of fitness values are always kept in the population.
4. Steps 2 and 3 are repeated until the best solution stops improving or a preset number of iterations is reached.
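Below is a minimal sketch of these steps for a single class using DEAP. The placeholder probabilities and labels, the penalty form of the fitness function, and the GA settings (population size, generation count, operator parameters) are illustrative assumptions, not our production configuration.

```python
import random
import numpy as np
from deap import base, creator, tools

# Placeholder inputs: classifier probabilities for one class and binary labels
# (1 = sample belongs to that class). In practice these come from the validation set.
rng = np.random.default_rng(0)
probs = rng.uniform(0, 1, 1000)
labels = rng.integers(0, 2, 1000)

# Illustrative targets (see the criteria above): LR1 <= 0.075, LR5 >= 7.5,
# coverage of bands 1 and 5 >= 50%, band 3 <= 10%.
TARGET_LR1, TARGET_LR5, TARGET_COVERAGE, TARGET_BAND3 = 0.075, 7.5, 50.0, 10.0

def band_lr(in_band, labels):
    """Diagnostic likelihood ratio of a band: P(band | positive) / P(band | negative)."""
    sens = in_band[labels == 1].mean()
    fpr = in_band[labels == 0].mean()
    return sens / fpr if fpr > 0 else np.inf

def fitness(individual):
    """Penalty-style cost: zero when all targets are met (lower is better)."""
    thresholds = np.sort(np.clip(individual, 0, 1))   # 4 thresholds -> 5 bands
    bands = np.searchsorted(thresholds, probs)        # band index 0..4
    lr1 = band_lr(bands == 0, labels)
    lr5 = band_lr(bands == 4, labels)
    coverage = 100 * np.isin(bands, [0, 4]).mean()
    band3 = 100 * (bands == 2).mean()
    cost = (max(0.0, lr1 - TARGET_LR1)                # LR1 above target is penalized
            + max(0.0, TARGET_LR5 - lr5)              # LR5 below target is penalized
            + max(0.0, TARGET_COVERAGE - coverage)    # coverage only penalized if too low
            + max(0.0, band3 - TARGET_BAND3))         # too many indeterminate calls penalized
    return (cost,)

creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", list, fitness=creator.FitnessMin)

toolbox = base.Toolbox()
toolbox.register("individual", tools.initIterate, creator.Individual,
                 lambda: sorted(random.random() for _ in range(4)))
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", fitness)
toolbox.register("mate", tools.cxBlend, alpha=0.5)
toolbox.register("mutate", tools.mutGaussian, mu=0.0, sigma=0.05, indpb=0.5)
toolbox.register("select", tools.selTournament, tournsize=3)

# Step 1: random initial population.
pop = toolbox.population(n=100)
for ind in pop:
    ind.fitness.values = toolbox.evaluate(ind)

for gen in range(50):                                  # fixed generation count for brevity
    # Step 3: elitism (top 20%) plus offspring created by crossover and mutation.
    elite = [toolbox.clone(ind) for ind in tools.selBest(pop, k=len(pop) // 5)]
    offspring = [toolbox.clone(ind) for ind in toolbox.select(pop, len(pop) - len(elite))]
    for child1, child2 in zip(offspring[::2], offspring[1::2]):
        toolbox.mate(child1, child2)
        toolbox.mutate(child1)
        toolbox.mutate(child2)
        del child1.fitness.values, child2.fitness.values
    # Step 2: re-evaluate the fitness of modified offspring.
    for ind in offspring:
        if not ind.fitness.valid:
            ind.fitness.values = toolbox.evaluate(ind)
    pop = elite + offspring

best = np.sort(np.clip(tools.selBest(pop, k=1)[0], 0, 1))
print("Best thresholds:", best)
```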
Results
We trained the classifier on the training set and applied the tuned classifier to the validation set. We then applied GAT to the validation-set probabilities. The results are summarized in Table 1, which shows the target values used by GAT and the values achieved with the best thresholds found by GAT for the bacterial and viral probabilities.
| | Target value | Achieved value (bacterial) | Achieved value (viral) |
|---|---|---|---|
| LR1 (lowest band) | 0.075 | 0.089 | 0.101 |
| LR5 (highest band) | 7.5 | 8.688 | 9.678 |
| % in band 1 and 5 | 50 | 69.8 | 69.1 |
| % in band 3 | 7.5 | 8.6 | 10 |
Table 1: Target values and achieved values for the three-class infectious disease classifier. GAT does not guarantee that all target values will be achieved. Nevertheless, overall performance was deemed adequate.
Figure 2 shows the bacterial and viral probabilities and the thresholds for the classifier, where blue and red dots represent patients with bacterial or viral infection, respectively. The green dots represent patients with inflammation that is not caused by a bacterial or viral infection. The dotted lines represent the thresholds determined using GAT. The thresholds let us assign each sample to one of five bacterial and five viral likelihood bands. Samples that fall in bacterial band 1 are very unlikely to be bacterial infections, whereas samples in bacterial band 5 are very likely bacterial infections. Similarly, samples that fall in viral band 1 are very unlikely to be viral infections, whereas samples in viral band 5 are very likely viral infections. Figure 3 shows that most patients with bacterial infections are in bacterial band 5 and viral band 1, and most patients with viral infections are in bacterial band 1 and viral band 5.
Figure 2: Assignment of bacterial and viral probabilities to likelihood bands. The dotted lines are the thresholds determined using GAT.
Figure 3: Frequency of patients with bacterial infections, viral infections, and no infections in each of the five bacterial and viral bands defined by GAT. “Coverage” is percent of patients in bands 1 and 5.
Conclusion
We found that GAT enables efficient optimization of decision thresholds with an arbitrary number of decision regions (bands) and an arbitrary fitness function. We intend to use this method to define decision thresholds for the TriVerity™ Acute Infection and Sepsis Test, currently in development at Inflammatix.
References
https://www.canassistbreast.com/sample-report.php
Hayden SR, Brown MD. Likelihood ratio: a powerful tool for incorporating the results of a diagnostic test into clinical decision making. Annals of emergency medicine. 1999 May 1;33(5):575-80.
Kulkarni MM. Digital multiplexed gene expression analysis using the NanoString nCounter system. Current protocols in molecular biology. 2011 Apr;94(1):25B-10.
Mayhew MB, Buturovic L, Luethy R, Midic U, Moore AR, Roque JA, Shaller BD, Asuni T, Rawling D, Remmel M, Choi K. A generalizable 29-mRNA neural-network classifier for acute bacterial and viral infections. Nature communications. 2020 Mar 4;11(1):1177.
Fortin FA, De Rainville FM, Gardner MA, Parizeau M, Gagné C. DEAP: Evolutionary algorithms made easy. The Journal of Machine Learning Research. 2012 Jul 1;13(1):2171-5.
Wirsansky E. Hands-on genetic algorithms with Python: applying genetic algorithms to solve real-world deep learning and artificial intelligence problems. Packt Publishing Ltd; 2020 Jan 31.
By Nandita Damaraju, Ljubomir Buturovic, Inflammatix, Inc.
Are neural networks better than other machine learning algorithms on small tabular data?
Introduction
Deep neural network (DNN) models outperform conventional machine learning algorithms on unstructured data modalities such as images, text, and audio. However, their application to modeling structured tabular data has not been as successful. Tabular datasets are used in a variety of domains, including medicine, finance, manufacturing, and climate science. Many applications (for example, in medicine) also use tabular datasets that are small (<10,000 samples) because samples are expensive to acquire. Given the wide use of such datasets, it would be beneficial to know which ML algorithms perform best when applied to small tabular data. Recent studies to that end (Gorishniy, 2021; Gorishniy, 2022) have been inconclusive.
In this study, we carried out an experimental comparison of neural network-based algorithms against non-neural-network machine learning algorithms on such tabular data.
Methodology
Dataset Overview
There are many standard datasets for comparing new deep learning architectures against existing baselines, such as MNIST, CIFAR, and ImageNet for image classification. In contrast, there are no established standard tabular datasets. This is partly addressed by the Penn Machine Learning Benchmark (PMLB; Olson, 2017), which provides the largest collection of diverse, public benchmark datasets for evaluating new machine learning methods. The database includes many modalities of data, including tabular datasets. In this study, we used a combination of PMLB datasets and internal Inflammatix datasets. We focused on numeric features and classification problems because we have the most expertise working with such data. All the datasets used for this analysis are outlined in Table 1.
We used 6 datasets from the PMLB database. The criteria were as follows:
fewer than 10,000 samples
either binary or three-class multiclass classification tasks
the number of samples corresponding to the least common class label was above 2% of the total dataset.
Only the numerical features corresponding to these datasets were used for further analyses.
In addition to the 6 datasets above, we also used 5 datasets generated internally at Inflammatix. These datasets are also tabular and use gene expression values as features. We use gene expression data to build classifiers to diagnose various infectious diseases. In this context, each input sample (vector) represents one person (patient), features correspond to genes, and the classes correspond to different diseases or disease states. Each feature value is a measurement of the abundance of the corresponding gene in a given tissue type (e.g., a blood sample), using a suitable measurement platform such as qPCR. The task is to classify the patient's disease accurately using the gene measurements as input features.
| dataset | #features | #classes | #samples | class 0 | class 1 | class 2 |
|---|---|---|---|---|---|---|
| ext_ann_thyroid | 21 | 3 | 7200 | 92.6% | 2.3% | 5.1% |
| ext_appendicitis | 7 | 2 | 106 | 80.2% | 19.8% | – |
| ext_clean2 | 166 | 2 | 6598 | 84.6% | 15.4% | – |
| ext_phoneme | 5 | 2 | 5404 | 70.7% | 29.3% | – |
| ext_schizo | 11 | 3 | 340 | 22.9% | 58.8% | 18.2% |
| ext_spectf | 44 | 2 | 349 | 27.2% | 72.8% | – |
| int_BVN | 29 | 3 | 3159 | 32.5% | 33.2% | 34.3% |
| int_COV | 6 | 2 | 705 | 96.7% | 3.3% | – |
| int_SEV | 29 | 2 | 2622 | 94.6% | 5.4% | – |
| int_TNF | 7 | 2 | 136 | 47.8% | 52.2% | – |
| int_VIB | 7 | 2 | 1123 | 40.4% | 59.6% | – |

Table 1: Overview of datasets used
Algorithm Overview
We compared the performance of 7 different learning algorithms on these datasets, as summarized in Table 2. We included three learning algorithms based on decision trees: XGBoost, LightGBM, and Random Forest. We focused on the comparison of these tree-based classifiers with the neural network classifiers MLP (Multi-Layer Perceptron) and TabNet (Arık, 2021). We also included Logistic Regression and Support Vector Machine algorithms as benchmarks.
| Algorithm | Overview | Abbr. |
|---|---|---|
| LightGBM | Distributed gradient boosting framework using decision trees | LGBM |
| Logistic Regression | Linear method for classification, often a good benchmark | LOGR |
| **Multi-Layer Perceptron** | – | MLP |
| Random Forest | – | RF |
| Support Vector Machine | – | RBF |
| **TabNet** | Deep learning framework based on sequential attention | TabNet |
| XGBoost | Gradient boosting framework that uses tree-based methods | XGB |

Table 2: Overview of learning algorithms used. Algorithms that use neural networks are highlighted in bold
Evaluating Models
We evaluated the methods using the Area Under the Receiver Operating Characteristic curve (AUROC) and balanced accuracy as the scoring metrics (Provost, 2003). AUROC measures the degree of separability of the classes achieved by a given classifier. While it is traditionally defined for binary classification, we used a macro-average of the one-vs-one AUROC values for multiclass problems (Hand and Till, 2000). Balanced accuracy is the arithmetic mean of sensitivity and specificity and is used for both binary and multiclass classification with imbalanced classes. Since our datasets have class imbalance, we preferred balanced accuracy over plain accuracy.
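Both metrics are available in scikit-learn; a small sketch with placeholder labels and probabilities is shown below.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, balanced_accuracy_score

# Placeholder three-class example: true labels and predicted class probabilities.
y_true = np.array([0, 1, 2, 1, 0, 2])
y_proba = np.array([[0.7, 0.2, 0.1],
                    [0.1, 0.8, 0.1],
                    [0.2, 0.2, 0.6],
                    [0.3, 0.5, 0.2],
                    [0.6, 0.3, 0.1],
                    [0.1, 0.3, 0.6]])

# Macro-averaged one-vs-one AUROC (Hand and Till), as used for our multiclass datasets.
auroc = roc_auc_score(y_true, y_proba, multi_class="ovo", average="macro")

# Balanced accuracy needs hard labels, so take the argmax of the probabilities.
bal_acc = balanced_accuracy_score(y_true, y_proba.argmax(axis=1))
print(auroc, bal_acc)
```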
For each ML (machine learning) method, we first scaled the features of each dataset and then performed a comprehensive hyperparameter search over the method's parameters using random 5-fold cross-validation to find the best parameters (according to AUROC) for each method on each dataset.
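A condensed version of this per-method procedure is sketched below using a scikit-learn pipeline and a randomized 5-fold search; the placeholder data, the MLP classifier shown, and the parameter grid are illustrative, as each method had its own search space.

```python
from scipy.stats import loguniform
from sklearn.datasets import make_classification
from sklearn.model_selection import KFold, RandomizedSearchCV
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Placeholder data standing in for one dataset (29 numeric features, 3 classes).
X, y = make_classification(n_samples=1000, n_features=29, n_informative=10,
                           n_classes=3, random_state=0)

pipe = Pipeline([
    ("scale", StandardScaler()),               # feature scaling
    ("clf", MLPClassifier(max_iter=2000)),     # swap in LGBM, XGB, RF, etc. for other methods
])

param_distributions = {                        # illustrative search space, not the one we used
    "clf__hidden_layer_sizes": [(32,), (64,), (64, 32)],
    "clf__alpha": loguniform(1e-5, 1e-1),
    "clf__learning_rate_init": loguniform(1e-4, 1e-2),
}

search = RandomizedSearchCV(
    pipe,
    param_distributions,
    n_iter=50,
    scoring="roc_auc_ovo",                     # macro one-vs-one AUROC; use "roc_auc" for binary tasks
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    n_jobs=-1,
)
search.fit(X, y)
print(search.best_score_, search.best_params_)
```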
Comparing model performance
Performance Overview
To compare the performance of the methods in Table 2, the AUROC and balanced accuracy were computed for each algorithm and tabulated in Tables 3 and 4, along with the corresponding 95% confidence intervals. We excluded TabNet from all further analysis because of its longer run time and inferior performance.
| Dataset | LGBM | LOGR | MLP | RBF | RF | XGB |
|---|---|---|---|---|---|---|
| ext_ann_thyroid | **1.0 (1.0,1.0)** | 0.994 (0.991,0.995) | 0.988 (0.982,0.994) | 0.986 (0.981,0.99) | 0.999 (0.998,1.0) | **1.0 (1.0,1.0)** |
| ext_appendicitis | 0.861 (0.744,0.963) | 0.861 (0.687,0.96) | **0.873 (0.774,0.956)** | 0.864 (0.709,0.963) | 0.833 (0.692,0.961) | 0.854 (0.698,0.974) |
| ext_clean2 | **1.0 (1.0,1.0)** | 0.979 (0.975,0.983) | 0.999 (0.999,1.0) | **1.0 (0.999,1.0)** | 0.995 (0.993,0.996) | 0.999 (0.999,1.0) |
| ext_phoneme | 0.949 (0.942,0.955) | 0.814 (0.802,0.829) | 0.93 (0.922,0.938) | 0.939 (0.93,0.947) | 0.929 (0.919,0.935) | **0.955 (0.948,0.962)** |
| ext_schizo | 0.629 (0.568,0.691) | **0.648 (0.591,0.706)** | **0.648 (0.596,0.719)** | 0.644 (0.598,0.693) | 0.591 (0.547,0.643) | 0.628 (0.565,0.684) |
| ext_spectf | 0.948 (0.919,0.975) | 0.895 (0.851,0.946) | **0.954 (0.927,0.982)** | 0.938 (0.887,0.982) | 0.951 (0.92,0.977) | 0.939 (0.9,0.966) |
| int_BVN | 0.945 (0.938,0.953) | 0.905 (0.893,0.914) | 0.95 (0.942,0.957) | **0.953 (0.944,0.96)** | 0.942 (0.932,0.952) | 0.948 (0.94,0.956) |
| int_coverity | 0.902 (0.81,0.956) | 0.903 (0.87,0.945) | 0.903 (0.822,0.954) | 0.871 (0.728,0.95) | 0.869 (0.804,0.936) | **0.905 (0.834,0.966)** |
| int_severity | 0.915 (0.888,0.933) | 0.922 (0.891,0.938) | **0.93 (0.911,0.951)** | 0.902 (0.876,0.929) | 0.905 (0.873,0.932) | 0.924 (0.901,0.94) |
| int_tnfa | **0.816 (0.716,0.904)** | 0.787 (0.686,0.877) | 0.812 (0.683,0.909) | 0.815 (0.736,0.901) | 0.786 (0.675,0.892) | 0.797 (0.688,0.868) |
| int_virabac | 0.942 (0.916,0.958) | 0.941 (0.915,0.954) | **0.949 (0.929,0.962)** | 0.945 (0.921,0.961) | 0.941 (0.919,0.956) | 0.943 (0.919,0.959) |

Table 3: Area Under the Curve (AUROC) performance. The best result for each dataset is highlighted in bold
A cursory analysis of the AUROC results (Table 3) indicates that MLP is the top performer. Since many AUROC values are high, we also examined the balanced accuracy scores, which reveal a similar picture: MLP still seems to outperform the other methods (Table 4).
| Dataset | LGBM | LOGR | MLP | RBF | RF | XGB |
|---|---|---|---|---|---|---|
| ext_ann_thyroid | **0.992 (0.978,0.999)** | 0.961 (0.946,0.971) | 0.887 (0.852,0.914) | 0.849 (0.827,0.886) | 0.99 (0.982,0.995) | **0.992 (0.983,0.998)** |
| ext_appendicitis | 0.762 (0.644,0.899) | 0.775 (0.592,0.911) | 0.715 (0.596,0.863) | **0.786 (0.639,0.932)** | 0.762 (0.644,0.899) | 0.762 (0.644,0.933) |
| ext_clean2 | 0.981 (0.974,0.987) | 0.923 (0.911,0.933) | 0.984 (0.978,0.989) | **0.991 (0.986,0.995)** | 0.948 (0.937,0.964) | 0.974 (0.967,0.982) |
| ext_phoneme | 0.868 (0.855,0.882) | 0.748 (0.735,0.76) | 0.83 (0.815,0.844) | 0.857 (0.843,0.872) | 0.819 (0.802,0.828) | **0.878 (0.866,0.895)** |
| ext_schizo | 0.333 (0.333,0.333) | 0.422 (0.35,0.493) | **0.457 (0.397,0.548)** | 0.387 (0.348,0.445) | 0.382 (0.342,0.431) | 0.337 (0.327,0.357) |
| ext_spectf | 0.852 (0.801,0.925) | 0.793 (0.709,0.867) | **0.888 (0.845,0.938)** | 0.847 (0.776,0.907) | 0.855 (0.777,0.915) | 0.845 (0.785,0.913) |
| int_BVN | 0.818 (0.797,0.831) | 0.764 (0.745,0.779) | 0.844 (0.828,0.859) | **0.848 (0.829,0.863)** | 0.818 (0.801,0.835) | 0.835 (0.818,0.846) |
| int_coverity | 0.515 (0.485,0.586) | **0.832 (0.767,0.923)** | 0.515 (0.488,0.595) | 0.499 (0.496,0.5) | 0.702 (0.594,0.884) | 0.5 (0.5,0.5) |
| int_severity | 0.714 (0.659,0.784) | **0.851 (0.805,0.884)** | 0.578 (0.538,0.614) | 0.509 (0.497,0.525) | 0.784 (0.733,0.84) | 0.577 (0.546,0.617) |
| int_tnfa | 0.737 (0.64,0.862) | 0.696 (0.586,0.772) | **0.766 (0.662,0.852)** | 0.741 (0.662,0.833) | 0.721 (0.581,0.824) | 0.726 (0.635,0.811) |
| int_virabac | 0.88 (0.848,0.901) | 0.869 (0.848,0.889) | **0.883 (0.857,0.904)** | 0.879 (0.853,0.903) | 0.876 (0.848,0.899) | 0.881 (0.854,0.905) |

Table 4: Balanced Accuracy performance. The best result for each dataset is highlighted in bold
Ranking Classifiers
Figure 1: Ranking of methods based on AUROC (left) and Balanced Accuracy (right)
Our main goal was to compare neural networks with non-NN algorithms. To facilitate this comparison, we visualized the relative performance of the learning algorithms by ranking them on each dataset and plotting the average rank for each method. This approach is inspired by Friedman's M statistic and is commonly used for algorithm comparison (Brazdil, 2000). The numerical values of the mean ranks for each learning algorithm are shown in the rounded boxes. The plots for both AUROC and balanced accuracy rank MLP highest, followed by LGBM.
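The mean ranks can be computed directly from a table like Table 3 with pandas; a small sketch using a few of the AUROC point estimates from Table 3 (full table omitted for brevity) is shown below.

```python
import pandas as pd

# Illustrative subset of the AUROC point estimates from Table 3 (rows = datasets).
auroc = pd.DataFrame(
    {"LGBM": [1.000, 0.861, 0.949],
     "LOGR": [0.994, 0.861, 0.814],
     "MLP":  [0.988, 0.873, 0.930],
     "RBF":  [0.986, 0.864, 0.939],
     "RF":   [0.999, 0.833, 0.929],
     "XGB":  [1.000, 0.854, 0.955]},
    index=["ext_ann_thyroid", "ext_appendicitis", "ext_phoneme"],
)

# Rank the methods within each dataset (1 = best) and average the ranks over datasets.
mean_ranks = auroc.rank(axis=1, ascending=False).mean().sort_values()
print(mean_ranks)
```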
Pairwise comparison of Classifiers
Figure 2: Pairwise comparison of performance based on AUROC (left) and balanced accuracy (right)
The plots in Figure 2 can also be used to determine whether one learning algorithm is better than another. For example, MLP (classifier A) outperforms XGB (classifier B) on about 63.6% of the datasets based on AUROC and on about 72.7% based on balanced accuracy. This is consistent with the ranks of MLP and XGB in Figure 1.
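Continuing with the auroc dataframe from the previous sketch, one simple way to compute such a pairwise win fraction is shown below (the exact tie handling used for Figure 2 is not specified, so this is an assumption).

```python
def win_fraction(df, a, b):
    """Fraction of datasets on which classifier `a` scores higher than classifier `b`."""
    return (df[a] > df[b]).mean()

# Over the full Table 3 (11 datasets), MLP beats XGB on 7 of 11 datasets (~63.6%);
# the three-row subset above gives a different value because it is only a subset.
print(win_fraction(auroc, "MLP", "XGB"))
```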
Discussion
Our key observation is that Multi-Layer-Perceptron outperformed gradient-boosted-tree-based learning algorithms (XGBoost and LightGBM), though the differences were small. Support Vector Machine, Logistic Regression, and Random Forest were inferior to other methods considered.
We are aware that XGBoost and other decision-tree-based algorithms have a significant following among ML practitioners who work with tabular data (Shwartz-Ziv, 2022), and we wondered why they seemed to underperform on the datasets we considered. To that end, we examined recent publications comparing these classifiers (Gorishniy, 2021; Gorishniy, 2022; Borisov, 2021) and noticed that the datasets they used fell into two distinct groups:
datasets with exclusively continuous features (we refer to them as homogeneous tabular datasets)
datasets with a combination of discrete and continuous features (we refer to them as heterogeneous tabular datasets).
Further, we noticed that in these studies neural networks outperformed tree-based algorithms on 12 out of 12 homogeneous tabular datasets, whereas XGBoost outperformed neural networks on 4 out of 5 heterogeneous tabular datasets. We also recognize that tree-based algorithms are, by design, better suited to handling discrete features than neural networks. Based on these considerations, we hypothesize that XGBoost may be superior for heterogeneous tabular datasets, whereas neural networks are best suited for homogeneous tabular datasets.
This analysis has limitations. Because we only examined classification tasks on small tabular datasets with numerical features, our conclusions might not extend to regression tasks or to heterogeneous or larger tabular datasets. It is possible that the hyperparameter tuning could be improved, potentially affecting the algorithms' rankings and the conclusions. Since we used 5-fold random cross-validation for practical reasons, our findings could be strengthened by introducing independent validation data, where available, and by evaluating other types of cross-validation.
In conclusion, we did not find convincing evidence that neural networks outperform non-neural-network algorithms on small tabular datasets, but we also did not observe superior performance by decision-tree-based methods (XGBoost, LGBM, Random Forest). We hypothesize that decision-tree-based learning algorithms may be best for datasets with a combination of continuous and categorical features, but not necessarily for datasets with exclusively numerical features. This hypothesis requires further research.
Identification of optimal hyperparameters is an integral part of building robust, accurate machine learning models. Hyperparameters control various aspects of a classification model, such as the learning rate, regularization, and model architecture. They influence the time required to train the model and, ultimately, its performance. Simpler methods such as logistic regression tend to have fewer hyperparameters to tune than advanced methods like deep neural networks. As models become more complex, the space of possible hyperparameter configurations (HCs) grows exponentially and evaluating hyperparameters becomes computationally intensive.
Traditional methods for sifting through the many possible HCs include grid search and random search. However, these approaches struggle when evaluating models with many hyperparameters, as the search space becomes exponentially larger. Methods such as Bayesian optimization and Hyperband require fewer computational resources to achieve comparable results.
A commonly used approach to identifying the best-performing HC is to estimate the performance of a model trained with that HC using k-fold cross-validation. The k-fold cross-validation procedure divides a limited dataset into k non-overlapping parts. Each of the k parts is used in turn as a held-out validation set, while the remaining parts are collectively used for training. A total of k models are fit and evaluated on the k holdout sets, and the average performance metric is reported. The same concept is used to assess HCs, by training models with different hyperparameters and evaluating them. However, k-fold cross-validation further increases the computational cost of the problem, since it requires a model to be trained and tested k times for each HC.
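A baseline version of this procedure, exhaustively scoring every HC with full k-fold cross-validation, looks roughly like the sketch below; the placeholder data, model, and grid are illustrative.

```python
from itertools import product

from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier

# Placeholder data standing in for one of our gene-expression training sets.
X, y = make_classification(n_samples=2000, n_features=29, n_informative=10, random_state=0)

# Illustrative hyperparameter configurations (HCs).
grid = {"hidden_layer_sizes": [(32,), (64,), (64, 32)],
        "alpha": [1e-4, 1e-3, 1e-2]}
hcs = [dict(zip(grid, values)) for values in product(*grid.values())]

# Full k-fold CV: every HC is trained and scored on all k folds.
results = []
for hc in hcs:
    clf = MLPClassifier(max_iter=1000, random_state=0, **hc)
    scores = cross_val_score(clf, X, y, cv=5, scoring="roc_auc")
    results.append((scores.mean(), hc))

best_score, best_hc = max(results, key=lambda r: r[0])
print(best_score, best_hc)
```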
In this blog, we aim to further reduce the time needed to identify a top-performing HC by truncating cross-validation runs that do not show promise. For some HCs, the performance metric is low for most folds in a cross-validation run, and such HCs are unlikely to be top contenders. For these HCs, it makes little sense to train models for later folds if the performance is poor in the first few folds. For example, suppose a 5-fold CV is used to iterate through HCs for a Multi-Layer Perceptron (MLP), and the first fold for a particular HC yields a low AUC of 0.6. The AUC values of the next 4 folds would not be able to compensate for that low first-fold value, so the HC is unlikely to remain a top contender, and computing the AUC values for the remaining 4 folds would be a waste of computational time.
If the computation of poor performing HCs is truncated, computational resources can be spent on more promising HCs and the time to iterate through the hyperparameter space will be reduced. This could be useful in cases where there is a high proportion of poor performing HCs or if the algorithm takes a significant amount of time to iterate through each HC.
Data
At Inflammatix, we use gene expression data to build classifiers to diagnose various infectious diseases. In this context, each input sample (vector) represents one person (patient), features correspond to genes, and the classes correspond to different diseases or disease states. Each feature value is a measurement of the abundance of the corresponding gene in a given tissue type (e.g., blood sample), using a suitable measurement platform such as qPCR. The task is to classify the patient’s disease accurately using the gene measurements as input features.
We used two datasets for this blog post (see table below). We used the samples in the training set with cross validation to identify top performing models. For both studies, we used the separate held-out validation set to evaluate these models and create the final model.
| Dataset | #Features | #Classes | #Training Samples | #Validation Samples |
|---|---|---|---|---|
| BVN | 29 | 3 | 3159 | 741 |
| Severity | 29 | 2 | 2622 | 1060 |
Approach
The modified cross-validation approach is as follows (see Figure 1):
1. Establish a performance threshold for the algorithm, for example an AUC of 0.7.
2. For each HC:
   - Compute the chosen metric for the first fold.
   - If the metric is above the threshold, proceed to the next fold.
   - If the metric is below the threshold, do not compute the metrics for the remaining folds and skip to the next HC.
3. Continue this process until either the value for a fold falls below the threshold or all folds are completed.
Figure 1: Flowchart to describe the adaptive CV approach
This procedure yields a single best-performing hyperparameter configuration for each method, based on cross-validation performance. The top-performing HC for each method is then evaluated on a separate held-out validation set.
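A minimal sketch of this truncated (adaptive) CV loop is shown below, reusing the placeholder X, y, and hcs from the earlier sketch; the threshold of 0.7 and the binary-AUC scoring are illustrative.

```python
from sklearn.base import clone
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
from sklearn.neural_network import MLPClassifier

THRESHOLD = 0.7                      # a per-fold AUC below this truncates the HC
kfold = KFold(n_splits=5, shuffle=True, random_state=0)

def adaptive_cv_score(estimator, X, y, threshold=THRESHOLD):
    """Mean fold AUC, or None if any fold falls below the threshold (HC truncated)."""
    fold_scores = []
    for train_idx, test_idx in kfold.split(X):
        model = clone(estimator).fit(X[train_idx], y[train_idx])
        auc = roc_auc_score(y[test_idx], model.predict_proba(X[test_idx])[:, 1])
        if auc < threshold:
            return None              # skip the remaining folds for this HC
        fold_scores.append(auc)
    return sum(fold_scores) / len(fold_scores)

# X, y, and hcs come from the full-CV sketch above.
results = []
for hc in hcs:
    score = adaptive_cv_score(MLPClassifier(max_iter=1000, random_state=0, **hc), X, y)
    if score is not None:
        results.append((score, hc))

if results:
    best_score, best_hc = max(results, key=lambda r: r[0])
    print(best_score, best_hc)
```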
Identifying good candidates
How do we identify scenarios where this approach could result in a significant reduction in training time?
We plotted the pooled cross-validation AUC values (x-axis) against the fraction of AUCs below those values (y-axis) for four algorithms: RBF, MLP, LOGR, and XGBoost (see Figure 2). An ideal candidate would have (1) a significant proportion of HCs below a cutoff threshold, and (2) a gradual increase in AUC values rather than a sharp rise, since a sharp rise makes it harder to select a robust threshold value.
MLPs had over 60% of HCs below an AUC of 0.75. XGBoost, on the other hand, showed a sharp rise and a narrow range of AUC values (see Figure 2), making it hard to determine a suitable cutoff and hence making XGB unsuitable for this approach. LOGR and RBF had only ~20% of their values below 0.75, making it hard to justify using this approach for less than a 20% reduction in time. We therefore chose MLPs to test the efficacy of this approach.
Figure 2: Fraction of HCs (y-axis) with AUC below the threshold values (x-axis)
Experimental Setup
We ran MLP experiments with both 5-fold and 10-fold cross-validation. We tuned hyperparameters using Hyperband with either AUC or accuracy as the performance metric. We compared the running time, cross-validation (CV) performance, and held-out validation (Validation) performance for various thresholds.
Results
The following plots depict four experiments. We evaluated the running time, cross-validation (CV) performance, and held-out validation (Validation) performance for the different threshold values shown on the x-axis.
The green bars represent the corresponding running time on the left y-axis. The distance between a green bar and the horizontal green baseline denotes the reduction in running time for a particular threshold: the lower the green bar, the greater the reduction in running time.
The right y-axis indicates performance, with validation performance in blue and CV performance in red. The green, blue, and red horizontal lines show the baseline (no threshold) running time, validation performance, and CV performance, respectively. Ideally, the dots should lie very close to their respective baseline horizontal lines, since we do not want to see a reduction in performance.
Experiment 1: 5-fold CV + mAUC as metric + BVN multi-class dataset
The first experiment used the multi-class BVN dataset with 5-fold CV and AUC as the performance metric. With no threshold, the baseline took 112 minutes to run, as shown by the green horizontal line. With a moderate threshold of about 0.75, there was a 9% reduction in running time and performance similar to the baseline. With a stricter threshold of 0.8, there was a 17% reduction in running time and, again, performance similar to the baseline. With a very aggressive threshold of 0.85, there was a 43% reduction in running time and a reduction in CV performance, but no reduction in validation performance.
Experiment 2: 10-fold CV + accuracy as metric + BVN multi-class dataset
For 10-fold CV, some folds did not contain all the classes. In those cases AUC is not defined, so we chose accuracy as the metric for 10-fold CV. The baseline with no threshold took 180 minutes to run. With a moderate threshold of about 0.7, there was a 66% reduction in running time and a small reduction in CV performance. With stricter thresholds of 0.8 and 0.85, there was a 70% reduction in running time but a larger reduction in the performance values.
For a binary classification dataset (S3), the run time without a threshold was 101 minutes. With a moderate threshold of 0.7 there was a 10% reduction in running time, and with an aggressive threshold of 0.875 there was a 31% reduction in running time with a modest drop in performance.
For the same binary classification dataset, with 10-fold CV and accuracy as the metric, the baseline took 158 minutes. With a moderate threshold of 0.7 there was a 5% reduction in running time, and with an aggressive threshold of 0.875 there was a 10% reduction in running time with no reduction in performance.
Discussion
With a modest fixed cutoff, only the 10-fold cross-validation experiment with accuracy as the metric on the multi-class BVN dataset showed a reduction of 66% in running time without a reduction in performance on both the training and validation sets. The adaptive CV approach did not prove beneficial in the remaining experiments, due either to a small reduction in running time (less than 30%) or to poorer performance on the cross-validation or validation data.
At the crux of this approach lies the ability to choose a threshold that quickly weeds out poor performers while retaining the promising HCs. The optimal threshold values tend to vary across experimental settings, so determining a good cutoff requires running these experiments multiple times, which negates the time gained from the approach. A potential future research direction is a variable threshold that updates itself based on the values of previous iterations; with such a threshold, perhaps more time could be saved without compromising performance.
Conclusion
We did not find an advantage in using the adaptive CV approach. The concept appears intuitively appealing at first, but our experiments to date do not support it.