Como identificar e tratar adequadamente um conjunto de dados com classes desbalanceadas
Introdução
Caso você já tenha trabalhado com problemas de classificação, é possível que já tenha se deparado com um conjunto de dados desbalanceados, mesmo sem perceber num primeiro momento.
A situação mais comum é separar o seu conjunto de dados em treinamento e teste, e logo após treinar o modelo pela primeira vez você observa um valor de acurácia altíssimo, como 99,9%. Para um olhar pouco acostumado, nesse momento pode parecer que o seu modelo está com uma performance excelente, que nada mais deve ser feito e que já pode ser aplicado em produção.
No entanto, a verdade que se esconde por trás desse resultado não é tão agradável: é bem provável que você esteja lidando com um conjunto de dados com classes desbalanceadas.
O que são dados desbalanceados?
Um conjunto desbalanceado consiste em um agrupamento de dados em que as classes (variáveis dependentes) possuem números diferentes de entradas. Porém, com uma grande quantidade de entradas, as diferenças devem ser expressivas para que essa característica cause prejuízos ao modelo.
Como um parâmetro arbitrário, é possível dizer que a performance de um modelo é mais seriamente afetada em conjuntos com classes desbalanceadas além de 90/10, isto é, um conjunto onde a classe minoritária possui menos de 10% das entradas totais¹.
A imagem abaixo ilustra um caso onde o conjunto é fortemente desbalanceado: em um registro de transações de compras com cartão de crédito, as classes foram divididas em fraudes e transações reais.
Nesse exemplo, das 284807 entradas totais, apenas 492 eram fraudes, ou seja, a classe minoritária tinha somente 0.17% de representatividade.
Podemos calcular essa taxa facilmente com a biblioteca pandas:
import pandas as pd
df = pd.read_csv("credit_card_data.csv")
# DISTRIBUIÇÃO DE CLASSES
classes = df.Class.value_counts()
print(classes)
# 0: 284315
# 1: 492
# REPRESENTATIVIDADE DA CLASSE MINORITÁRIA
print(f"{classes[1]/classes.sum()*100:.2f}%")
# 0.17%
Como corrigir esse problema?
Na literatura, podemos encontrar diversos métodos para tratar o desbalanceamento. O primeiro método, apesar de parecer simplista, consiste em coletar mais dados, especialmente da classe minoritária.
O problema é que essa abordagem nem sempre poderá ser implementada, pois alguns conjuntos de dados são desbalanceados devido à sua natureza, como dados de fraudes em cartões de crédito, identificação de doenças raras, classificação de spam em e-mails, entre muitas outras.
Partindo do pressuposto que não temos mais dados disponíveis, surgem outras possibilidades, como seleção e adequação de modelos que lidam melhor com esse tipo de dados, ensembles, redução de dimensionalidade, penalização do modelo com diferentes funções de custo e métodos de resampling.
Embora o foco deste artigo seja no último tipo de tratamento, eu te encorajo a procurar mais sobre as outras técnicas.
Escolhendo métricas adequadas
O exemplo que eu citei na introdução é um caso da “falácia da taxa-base” (Base Rate Fallacy²). Quando avaliando a performance de um modelo em cima de dados desbalanceados, uma métrica como 99% de acurácia, na verdade, não indica que o nosso modelo classifica erroneamente 1% das vezes.
Como sabemos, a acurácia corresponde à quantidade de acertos dividida pela quantidade total de testes. Em um caso onde a classe minoritária tem, por exemplo, 0,5% das entradas, a elevada acurácia releva a distribuição original dos dados, portanto muitas vezes o modelo classificaria a classe minoritária erroneamente, mas esses erros seriam relevados por conta de sua pequena influência no contexto global.
A falácia ocorre, então, sempre que olhamos os valores como esse e esquecemos da distribuição dos nossos dados, a “taxa-base”.
Por sorte, existe uma série de outras métricas que podemos utilizar, algumas mais adequadas para conjuntos desbalanceados. Uma bastante comum, porém não extremamente efetiva, é a matriz de confusão.
Por meio dessa métrica, podemos avaliar o desempenho do nosso modelo segundo tipos específicos de classificação: positivos e negativos verdadeiros/falsos.
Ainda, utilizando as informações da matriz de confusão, podemos utilizar métricas secundárias, como precision, recall e f1-score.
Com a biblioteca sklearn é possível calcular os valores da matriz de confusão:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
conf_mat = [value for value in confusion_matrix(y_test,y_pred)] # y_pred é o resultado do modelo treinado
# PLOTAR A MATRIZ COMO HEATMAP
fig, ax = plt.subplots()
sns.heatmap(data=conf_mat, annot=True, ax=ax)
plt.tight_layout()
Na matriz de confusão, os valores apresentados são absolutos, ou seja, a quantidade de julgamentos dentro do total de entradas.
Quando tratamos com dados fortemente desbalanceados, uma abordagem que me parece mais interessante é normalizar essa matriz de confusão, para que os valores apresentados sejam relativos aos totais de cada classe.
from sklearn.metrics import confusion_matrix
matrix = confusion_matrix(y_test, y_pred)
# DIVIDIR OS VALORES PELA SOMA PARA A NORMALIZAÇÃO
norm_matrix = [value / value.sum() for value in matrix]
# COMO RESULTADO, OS QUADRANTES DA MATRIZ SOMARÃO 1 (100%) A CADA LINHA
As métricas derivadas da matriz de confusão também podem ser mostradas com a biblioteca sklearn:
from sklearn.metrics import classification_report
print(classification_report(y_test,y_pred)
# Mostra os valores de precision, recall e f1-score
Os diferentes problemas com dados
Cada problema de ciência de dados possui suas próprias características e, portanto, o nosso modelo deve se adaptar às características que esperamos encontrar no resultado final.
Por exemplo, enquanto construímos um modelo para classificar e-mails recebidos como spam ou não, pode ser mais adequado permitir que algumas mensagens indesejadas cheguem à sua caixa de entrada do que deixar um e-mail importante cair na caixa de spam (mais falsos negativos).
Por outro lado, digamos que o objetivo do nosso modelo é identificar casos de uma doença rara. Assim, pode ser mais desejável que a quantidade de falsos negativos seja mínima, evitando que um paciente com a doença fique sem o tratamento adequado.
No mundo real, não é possível e nem adequado que um modelo seja totalmente preciso, e o caminho mais comum é que avaliemos o custo de oportunidade das nossas previsões³.
Custo de Oportunidade
Existem duas métricas bastante relevantes para analisar os custos de oportunidade das decisões de um modelo⁴: a curva de Precision/Recall e a Receiver Operating Characteristic (ROC).
Ao pesquisar sobre essas duas curvas, podemos encontrar opiniões bastante divergentes quanto à sua aplicabilidade em conjuntos de dados desbalanceados: enquanto alguns preferem a primeira, outros optam pela segunda.
Porém, qual o significado de cada uma delas?
A ROC é uma curva que mede a qualidade da separabilidade de um modelo de classificação, ou seja, o quão bom o modelo é em distinguir valores de diferentes classes. Ao plotar a ROC, temos nos eixos a taxa de positivos verdadeiros (TPR) e a taxa de falsos positivos (FPR), para diferentes thresholds. Da ROC podemos extrair o valor da área embaixo da curva (AUC) que, quanto mais próxima de 1, melhor será a classificação do modelo.
Implementação com sklearn:
from sklearn.metrics import roc_curve, roc_auc_score
# OBTER OS VALORES DA ROC
fpr, tpr, _ = roc_curve(y_test, y_pred)
# PLOTAR A ROC COMO GRÁFICO
plt.plot(fpr,tpr)
plt.show()
# OBTER O VALOR DA AUC
print(f"ROC AUC Score: {roc_auc_score(y_test,y_pred):.2f}")
# >>> ROC AUC Score: 0.89
De forma similar à ROC, a curva de Precision/Recall (PRC), mostra os valores de Precision e Recall obtidos pelo modelo para diferentes thresholds. Assim como a métrica anterior, é possível calcular a AUC e tomar decisões baseadas nos valores percebidos.
Uma AUC próxima de 1 significa que o modelo possui tanto alta precisão quanto recall, ou seja, o modelo possui baixas taxas de ambos falsos positivos e falsos negativos.
Implementação com sklearn:
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve, auc
# PLOTAR A PRC NO GRÁFICO
display = PrecisionRecallDisplay.from_predictions(y_test, y_pred)
y_pred = y_pred[:, 1] # GUARDAR AS PREVISÕES APENAS DA CLASSE POSITIVA 1
precision, recall, _ = precision_recall_curve(y_test, y_pred) # GUARDAR OS VALORES DA PRC
AUC = auc(recall, precision) # CALCULAR A AUC
print(f"AUC: {AUC:.2f}")
# >>> AUC: 0.89
O grande argumento a favor de utilizar a PRC como métrica para avaliação de conjuntos desbalanceados é a sua baseline (“linha-base”), que se adequa à distribuição do conjunto⁵. Isso ocorre porque os valores TN (negativos verdadeiros), que fazem parte da classe majoritária, não são levados em consideração durante o cálculo da curva, de forma distinta ao que acontece com a ROC, onde os valores TN são necessários para o cálculo da TPR e FPR.
Tratando o desbalanceamento
Como mencionei, existe uma grande variedade de métodos para tratar dados desbalanceados, e a biblioteca imblearn foi criada com o intuito de abrigar muitas dessas técnicas. Porém, para limitar o escopo deste artigo, falaremos principalmente sobre métodos de resampling.
Resampling significa redistribuir as classes, de modo que o produto final é um conjunto de dados menos desbalanceado, ou quiçá até totalmente balanceado.
Para isso, existem duas grandes metodologias: undersampling e oversampling. Como você pode imaginar, no primeiro nós tiramos entradas da classe majoritária, enquanto no outro adicionamos entradas à classe minoritária.
Alguns dos métodos existentes são:
- Undersampling: Random Under Sampling, Tomek Links, Nearest Neighbour Cleaning.
- Oversampling: Random Over Sampling, ADASYN, SMOTE e variantes.
Para exemplificar, tratarei aqui sobre o Random Under Sampling e SMOTE.
Random Under Sampling
Este é um dos métodos mais simples, mas bastante efetivo. O tratamento consiste em excluir entradas da classe majoritária de forma aleatória, como o próprio nome sugere. Durante a aplicação do método, é possível selecionar qual será a proporção desejada após. Por exemplo, selecionando um valor de sampling_strategy de 0.5, teremos um número de entradas da classe minoritária igual a 50% da majoritária.
SMOTE
SMOTE, ou Synthetic Minority Oversampling Technique, é uma técnica muito conhecida de oversampling, que possui uma série de métodos derivados do conceito original. Sem entrar em detalhes técnicos (que podem ser encontrados no artigo original⁶), o método sintetiza novas entradas para a class minoritária, com características que imitam as originais, porém com pequenas variações.
Implementando as técnicas com a biblioteca imblearn:
from imblearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
over = SMOTE(sampling_strategy=0.1,random_state=42)
under = RandomUnderSampler(sampling_strategy=0.5,random_state=42)
steps = [('o',over),('u',under)]
pipeline = Pipeline(steps=steps) # CRIAR UMA PIPELINE DE RESAMPLING
# APLICAR A TRANSFORMAÇÃO NO CONJUNTO DE TREINO
X_train, y_train = pipeline.fit_resample(X_train,y_train)
Nesse exemplo, criamos uma Pipeline com duas etapas: over e undersampling, com os métodos descritos acima. Com os valores de sampling_strategy, o resultado final foi um conjunto com uma proporção de 1/3, ou seja, um terço das entradas eram da classe minoritária.
Um ponto importante a se mencionar é que apenas o conjunto de treino foi transformado pela Pipeline, e os dados de teste foram deixados de fora, para que a avaliação posterior do modelo seja realizada sobre dados que simulam a realidade, quando o modelo for aplicado para problemas reais.
Conclusão
Dados desbalanceados são uma realidade para uma grande variedade de problemas para o Cientista de Dados, e saber como lidar com um conjunto assim pode ser crucial para o bom desempenho do modelo, quando aplicado para situações reais. Com os tratamentos adequados, estima-se que é possível aumentar em até 30% a performance de um modelo que está treinando com dados desbalanceados¹.
Embora este artigo tenha falado de forma mais ampla sobre os conceitos que giram em torno do assunto, existem muitas pesquisas e publicações sobre formas cada vez mais inovadoras de tratar um conjunto desbalanceado. Um bom ponto de partida é procurar um pouco mais a fundo nas referências que usei como base para esse texto, que podem ser encontradas abaixo.
Agradeço a leitura!
Referências
- Class Imbalance Revisited: A new experimental setup to assess the performance of treatment methods. (Disponível aqui)
- Why do we rely on specific information over statistics? Base Rate Fallacy, explained. (Disponível aqui)
- Recognize Class Imbalance with Baselines and Better Metrics. (Disponível aqui)
- How to Use ROC Curves and Precision-Recall Curves for Classification in Python. (Disponível aqui)
- The Precision-Recall Plot Is More Informative than the ROC Plot When Evaluating Binary Classifiers on Imbalanced Datasets. (Disponível aqui)
- SMOTE: Synthetic Minority Over-sampling Technique. (Disponível aqui)