본문 바로가기
Lab & Research/Artificial intelligence

TabNet

by jaeaemin 2024. 8. 4.

 

 

Introduction 


TabNet은 2019년에 Google Research에서 발표한 딥러닝 기반의 트리 구조 모델로, 특히 Tabular Data 형태의 데이터에 최적화된 모델. TabNet은 Gradient-based Learning을 통해 최적화되며, 주요 특징으로 해석 가능성과 효율성을 들 수 있음

 

이름
생년
국어 점수
영어 점수
수학 점수
홍길동
1992년
7월
17일
100점
90점
70점
희동이
1992년
4월
3일
90점
100점
100점

Tabular Data 예시

 

 

Tabular Data과 딥러닝은 다른 데이터 셋 ( Raw형 데이터 셋, 이미지, 영상 .. ) 과 비교하여 효과적인 성능을 보여주지 못하는 추세로, 비정형 데이터를 활용한 딥러닝을 사용하거나 feature engineering에 집중하여 머신 러닝을 사용하는 추세임

 

 

Tabular Data를 활용한 딥러닝 기술이 가져오는 이점 

  • Incremental Learning : streaming data에서 지속적인 학습이 가능함
  • pretraining : 사전 학습으로 학습 시간 단축 및 적은 데이터 사용으로 성능 향상이 가능함
  • Capacity : 복잡하고 다양한 패턴을 수, 데이터가 누적됨에 따라 지속적인 성능 향상을 기대할 수 있음

 

 

 

TabNet 

 

Tree 기반 모델의 변수선택 특징을 네트워크 구조에 반영한 딥러닝 모델 

  • 가공하지 않은 Raw Data에서 Gradient를 기반한 최적화를 사용함으로써 End-To-End 학습을 실현
  • Sequential Attention Mechanism을 사용하고, 각 decision step에서 어떤 feature를 사용할지 선택하여 성능과 해석력을 향상 
  • 특징 선택은 인스턴스마다 다르게 적용되며(instance-wise feature selection), 다른 인스턴스마다 다른 특징이 선택되므로 모델이 데이터의 다양한 특성을 잘 반영할 수 있음
  • 특징의 중요도와 결합 방식을 시각화한 local 해석성, 학습된 모델에 대한 각 특징의 기여도를 정량화한 global 해석성 두 종류의 해석성을 제공함
  • 정형 데이터셋에서 처음으로 unsupervised pre-training을 통해서 성능을 크게 향상시킬 수 있음

 

 

 

(1)TabNet : feature selection

 

 > (관련연구) Conventional DNN 블록으로 Tree와 유사한 형태의 결정 경계를 생성

 

  • sparse instance-wse feature selection learned from data
  • constructs a sequential multi-step architecture, where each step contributes to a portion of te decistion based on the selected features
  • improves the learning capacity via nonlinear processing of the selected features

 

Conventional DNN 블록(왼쪽)과 Tree의 결정 경계(오른쪽) 비교

 

 

 

Feature Selection

  • Sequential attention을 사용하여 각 decision step에서 feature selection을 하며 feedback을 주고 학습하는 구조
  • Input feature에 대해 훈련 가능한 마스크 (Trainable Mask ) 로 Sparse Feature Selection을 수행
  • 특정 Feature 들만 선택하는 것이 아닌, Linear Regreatuon 처럼 각 featue에 가중치를 부여

 

Tree-based Learning

  • Conventional DNN 블록을 사용하여 의사결정 나무와 같이 유사한 형태의 결정 경계를 생성
  • Input feature에 대해 훈련 가능한 마스크 (Trainable Mask)로 Sparse Feature selection을 수

 

 

 

(2) TabNet : Encoding Arch

 

 

 

이전 단계 학습 결과가 다음 단계 Mask 학습에 영향을 주는 연결 구조 

 

  • Sequential Approach : 모델을 반복 연결하여 잔차를 보완하는 gradient boosting이 연상되는 구조
  • Feature Selection: feature transformer와 attentive transformer 블록을 통과하여 최적 mask를 학습함
  • 입력 부분과 Step 1~N으로 구분, 각 단계마다 Feature Transformer와 Attentive transofromer, feature masking으로 구성
  • Feature를 Selection하는 Mask block은 각 Step에서 Feature가 작동하는 것에 대한 insight를 제공
  • Aggregate Block을 통해 궁극적으로는 어떤 feature가 중요하게 작용하는지 확인이 가능함

 

 

Feature Transformer : 선택된 feature로 정확히 예측하기 위한 embedding 기능 ( feature processing )

  • FC -> BN -> GLU로 구성됨
  • 블럭들은 순차적으로 통과하는 구조를 쌓고, 블럭간 residual skop connection을 적용
  • 전체 구조에서 앞 2개의 네트워크 묶음은 모든 dicision step에서 공유되고, 뒤 2개의  block 묶음은 해당 decision step에서만 사용

 

 

- Ghost Batch Norm (BN) 

Batch를 분할한 nano batch 사용으로 잡음 추가 -> 지역 최적화 예방 -> large batch size로 학습 속도 향상

 

 

 

 

 

 

Gated Linerar Unit (GLU)

이전 Layer에서 전달되는 정보의 크기를 제어하는 역할 

 

 

 

 

Attentive Transformer : 변수 선택 기능

  • FC -> BN -> Sparsemax를 거치면서 Mask를 생성하며 Mask는 어떤 feature를 주로 사용할지에 대한 정보를 내포함
  • Attentive Transformer는 현재 의사결정 단계에서 각 변수들이 얼마나 영향을 미쳤는지 사전 정보량(Prior scales)으로 집계
  • Feture transformer 에서 인코딩된 Decision 정보는 Attentive transformer 블록을 거쳐 traninable mask로 변환

 

 

(1) Prior Scales : i번쨰 단계에서 변수의 중복 반영 여부를 경정하는 factor로써, 선택된 변수의 반영률이 점차 낮아지는 특성

(2) Sparsemax : softmax 수 대비 sparsity가 높은 함수로 attention layer 등에서 효과적임

 

(1),(2)를 활용하여 변수의 중복 사용을 제한함으로써, 변수 마다 중요도를 학습할 수 있도록 고안

 

Attentive Transformer

 

 

(3) Entmax : 인자 값에 따라 sparsity를 조절하는것이 가능해 활용도가 높음

 

 

 

 

 

(3) TabNet : Decoding Arch

  • 각 step마다 feature transformer에 fc layer가 연결된 구조
  • 각 step FC Layer 출력값의 aggregaton 값이 입력 features를 복원
  • 일반 학습에서는 Decoder를 사용하지 않지만 Self-supervised 학습 진행시 기존 결측값 보완 및 표현 학습

 

Semi-supervised Learning

  • 앞서 소개한 encoder에 deoder를 연결하면 autoencoder와 같은 자기 학습 구조를 생성할 수 있다.
  • 특정한 영역이 masking된 인코딩 데이터를 원본대로 복원할 수 있도록 학습
  • 사전 학습을 통한 예측 성능 향상, 학습 시간 단축 및 결측치에 대한 보간 효과

 

 

 

 

 

TabNet : interPretation

 

Attentive Transformer의 Mask값을 활용한 변수 중요도 시각화

  • M[i]는 모든 검증 데이터에 대해 각 attentive transformer 단계에서 mask 적용 후 활성화 비율을 표현하며, 지역적인 특성을 확인할 수 있음
  • M_agg는 모든 attentive transformer의 단계의 활성화 비율을 결합한 것으로 글로벌 특성을 확인할 수 있음 

반응형