Data Science
Using Data Science to Predict Viral Tweets
Can you use Machine Learning to predict which tweets will go viral?
In a previous article, I wrote about building an XGBoost model to predict video popularity for a past data science competition. It covered everything from loading the data and libraries to building the model itself. Now, let’s pivot to viral tweets.
Bitgrit released a competition with $3000 💵 up for grabs:
⭐ Viral Tweets Prediction Challenge ⭐
If you’ve ever wondered why tweets go viral, this is the perfect opportunity for you to find the answer using Data Science!
The Goal
Develop a machine learning model to predict the virality level of each tweet based on attributes such as tweet content, media attached to the tweet, and date/time published.
What does the data look like?
📂 Tweets
├── test_tweets_vectorized_media.csv
├── test_tweets_vectorized_text.csv
├── test_tweets.csv
├── train_tweets_vectorized_media.csv
├── train_tweets_vectorized_text.csv
└── train_tweets.csv
📂 Users
├── user_vectorized_descriptions.csv
├── user_vectorized_profile_images.csv
└── users.csv
The tweets folder contains our test and train data, and the main features are all in tweets.csv, which has information such as date, number of hashtags, whether it has attachments, and our target variable, virality. The vectorized csv files are vectorized formats of the tweet's text and media (video or images).
The users folder contains information about the user’s follower and following count, tweet count, whether they’re verified or not, etc. And their profile bio and image are also vectorized in individual csv files.
Relationship between the data
Users data are related through user_id, whereas tweet data are connected through tweet_id.
Below is a visualization of the relationship.
More info about the data can be found in the guidelines section of the competition.
Now that you have the goal and some information about the data given to you, it’s time to start doing data science tasks to achieve the goal of predicting tweet virality.
All the code can be found on Google Colab or on Jovian.
Note that, to keep this article lightweight, not all the code is included here, so please refer to the notebook for the complete code.
Load Libraries
As with all data science tasks, you start with equipping yourself with the libraries you need for various tasks.
# essentials
import pandas as pd
import numpy as np
# misc libraries
import random
import timeit
import math
import collections
# suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Data Visualization
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style='darkgrid', color_codes=True)
plt.style.use('fivethirtyeight')
%matplotlib inline
# model building
import lightgbm as lgb
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
Load Data
You can get the data by registering for the competition here by July 6, 2021.
# Note: this path will be different depending on where you store the dataset
tweets_path = '/content/drive/MyDrive/Colab Notebooks/viral tweets/Dataset/Tweets/'
users_path = '/content/drive/MyDrive/Colab Notebooks/viral tweets/Dataset/Users/'
# Load training datasets
train_tweets = pd.read_csv(tweets_path + 'train_tweets.csv')
train_tweets_vectorized_media = pd.read_csv(tweets_path + 'train_tweets_vectorized_media.csv')
train_tweets_vectorized_text = pd.read_csv(tweets_path + 'train_tweets_vectorized_text.csv')
# Load test dataset
test_tweets = pd.read_csv(tweets_path + 'test_tweets.csv')
test_tweets_vectorized_media = pd.read_csv(tweets_path + 'test_tweets_vectorized_media.csv')
test_tweets_vectorized_text = pd.read_csv(tweets_path + 'test_tweets_vectorized_text.csv')
# Load user dataset
users = pd.read_csv(users_path + 'users.csv')
user_vectorized_descriptions = pd.read_csv(users_path + 'user_vectorized_descriptions.csv')
user_vectorized_profile_images = pd.read_csv(users_path + 'user_vectorized_profile_images.csv')
Printing the shape of each of these datasets:
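As a quick illustration, here is a minimal sketch of how those shapes can be printed (the notebook may do this slightly differently):
# Print the dimensions of every data frame we just loaded
datasets = {
    'train tweets': train_tweets,
    'train tweets vectorized media': train_tweets_vectorized_media,
    'train tweets vectorized text': train_tweets_vectorized_text,
    'test tweets': test_tweets,
    'test tweets vectorized media': test_tweets_vectorized_media,
    'test tweets vectorized text': test_tweets_vectorized_text,
    'users': users,
    'user vectorized descriptions': user_vectorized_descriptions,
    'user vectorized profile images': user_vectorized_profile_images,
}
for name, df in datasets.items():
    print(f'Dimension of {name} is {df.shape}')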
Dimension of train tweets is (29625, 14)
Dimension of train tweets vectorized media is (21010, 2050)
Dimension of train tweets vectorized text is (29625, 769)
Dimension of test tweets is (12697, 13)
Dimension of test tweets vectorized media is (8946, 2050)
Dimension of test tweets vectorized text is (12697, 769)
Dimension of users is (52, 11)
Dimension of user vectorized descriptions is (52, 769)
Dimension of user vectorized profile images is (52, 2049)
We can tell that there are a total of 52 users in our dataset, each with multiple tweets. Notice that the vectorized media has fewer rows than the tweets data: not all tweets have media, and keep in mind that some tweets have multiple media. This will cause some issues when merging the data, but we'll worry about that later.
Notice also how the vectorized media/images all have the same number of feature columns (2,048 image features), while the vectorized text and descriptions each have 768 text features.
Now let’s move on to some exploratory data analysis to understand our data better.
Exploratory Data Analysis
EDA is important to discover trends in your data and determine what transformations and preprocessing are needed to prepare the data for modeling.
What the data looks like
train_tweets.head()
tweet_id | tweet_user_id | tweet_created_at_year | tweet_created_at_month | tweet_created_at_day | tweet_created_at_hour | tweet_hashtag_count | tweet_url_count | tweet_mention_count | tweet_has_attachment | tweet_attachment_class | tweet_language_id | tweet_topic_ids | virality |
34698 | 10 | 2015 | 12 | 5 | 3 | 2 | 1 | 0 | FALSE | C | 0 | [’36’, ’36’, ’36’, ’36’, ’36’, ’36’, ’37’, ’37… | 3 |
24644 | 4 | 2020 | 6 | 19 | 0 | 0 | 1 | 0 | FALSE | C | 0 | [’43’, ’78’, ’79’, ’80’, ’80’, ’89’, ’98’, ’99… | 3 |
36321 | 54 | 2019 | 6 | 2 | 15 | 2 | 3 | 0 | TRUE | A | 0 | [’79’, ’80’, ’98’, ’98’, ’98’, ’99’, ’99’, ’10… | 1 |
2629 | 42 | 2020 | 9 | 6 | 17 | 0 | 1 | 1 | TRUE | A | 0 | [’43’, ’79’, ’80’, ’98’, ’99’, ’99’, ’79’, ’80’] | 2 |
28169 | 32 | 2020 | 11 | 4 | 17 | 2 | 1 | 0 | TRUE | A | 0 | [’79’, ’80’, ’98’, ’99’, ’43’, ’89’] | 2 |
Looking at the data, it seems pretty standard. We have the two primary keys, the features, and the target variable, virality.
But notice how tweet_topic_ids contains arrays? We'll have to do some preprocessing later on to deal with that.
train_tweets_vectorized_media.head()
media_id | tweet_id | img_feature_0 | img_feature_1 | img_feature_2 | img_feature_3 | img_feature_4 | img_feature_5 | img_feature_6 | img_feature_7 | img_feature_8 | img_feature_9 | img_feature_10 | img_feature_11 | img_feature_12 | img_feature_13 | img_feature_14 | img_feature_15 | img_feature_16 | img_feature_17 | img_feature_18 | img_feature_19 | img_feature_20 | img_feature_21 | img_feature_22 | img_feature_23 | img_feature_24 | img_feature_25 | img_feature_26 | img_feature_27 | img_feature_28 | img_feature_29 | img_feature_30 | img_feature_31 | img_feature_32 | img_feature_33 | img_feature_34 | img_feature_35 | img_feature_36 | img_feature_37 | … | img_feature_2008 | img_feature_2009 | img_feature_2010 | img_feature_2011 | img_feature_2012 | img_feature_2013 | img_feature_2014 | img_feature_2015 | img_feature_2016 | img_feature_2017 | img_feature_2018 | img_feature_2019 | img_feature_2020 | img_feature_2021 | img_feature_2022 | img_feature_2023 | img_feature_2024 | img_feature_2025 | img_feature_2026 | img_feature_2027 | img_feature_2028 | img_feature_2029 | img_feature_2030 | img_feature_2031 | img_feature_2032 | img_feature_2033 | img_feature_2034 | img_feature_2035 | img_feature_2036 | img_feature_2037 | img_feature_2038 | img_feature_2039 | img_feature_2040 | img_feature_2041 | img_feature_2042 | img_feature_2043 | img_feature_2044 | img_feature_2045 | img_feature_2046 | img_feature_2047 |
00001_00000 | 1 | 0.290614 | 0.150803 | 0.008313 | 0.040887 | 0 | 0 | 0.214209 | 0.000792 | 0.00027 | 0.424207 | 0 | 0.372124 | 0.031332 | 0.041069 | 0 | 0.209875 | 0.232068 | 0.003567 | 0.185193 | 0.096734 | 0.127461 | 0.005552 | 0.000435 | 0.429719 | 0.024739 | 0.000051 | 0.000842 | 0.115226 | 0 | 0.026856 | 0 | 0.031253 | 0.032772 | 0.051137 | 0.026097 | 0.962891 | 0.132004 | 0.158875 | … | 0.00048 | 0 | 0.018373 | 0.327092 | 0.079089 | 0.360097 | 0.002562 | 1.116611 | 0.054391 | 0.086378 | 0.045496 | 0.030632 | 0 | 0 | 0 | 0.16522 | 0 | 0.05931 | 0.111803 | 0 | 0.10911 | 0.025834 | 0.235375 | 0.078341 | 0.131708 | 0.013988 | 0 | 0.02596 | 0.01576 | 0.266088 | 0 | 0.24924 | 0.040368 | 0.101314 | 0 | 0.069272 | 0.167507 | 0.044617 | 0.383093 | 0.097627 |
00004_00003 | 4 | 0.038251 | 0.036437 | 0 | 0.015076 | 0 | 0.046953 | 0.64817 | 0.026476 | 0 | 0.191951 | 0.003372 | 0 | 0.009363 | 0 | 0 | 0.012317 | 0.016527 | 0.130308 | 0 | 0 | 0.303246 | 0 | 0.307988 | 0.011478 | 0.044807 | 0.20841 | 0.043399 | 0.118079 | 0.000222 | 0.083115 | 0 | 0.634716 | 0 | 0.01404 | 0.090266 | 0 | 0.127964 | 0.189758 | … | 0 | 0.001796 | 0 | 0.251383 | 0.021052 | 0.802314 | 0.027913 | 0.335493 | 0.017326 | 0 | 0.026515 | 0.056399 | 0.030597 | 0.082174 | 0.003829 | 0.083139 | 0.003266 | 0.249968 | 0.304901 | 0.004793 | 0.028569 | 0 | 0.06998 | 0.007251 | 0.194076 | 0 | 0.124188 | 0 | 0.691953 | 0.009337 | 0.024564 | 0.035555 | 0.369353 | 0 | 0.133307 | 0 | 0 | 0.017894 | 0.816972 | 0.058774 |
00005_00004 | 5 | 0.506981 | 0.305467 | 0.03615 | 0.114539 | 0 | 0.146888 | 0.584753 | 0.157468 | 0 | 0.040884 | 0.009688 | 0 | 0 | 0 | 0.074692 | 0.018211 | 0.233618 | 0.039807 | 0 | 0 | 0.16983 | 0 | 0.045755 | 0.028518 | 0 | 0.259838 | 0.117726 | 0.017544 | 0.018106 | 0 | 0.019969 | 0.257562 | 0 | 0.006338 | 0 | 0 | 0.090963 | 0.746206 | … | 0 | 0 | 0.032314 | 0 | 0 | 0.143607 | 0.371495 | 0.19986 | 0.192786 | 0 | 0.401853 | 0.021595 | 0.033472 | 0.164238 | 0.085964 | 0.293521 | 0 | 0.08259 | 0.020819 | 0.001283 | 0.009182 | 0 | 0.295073 | 0 | 0.300424 | 0.228981 | 0.109332 | 0.032641 | 0.263165 | 0.000992 | 0 | 0.131493 | 0.268107 | 0 | 0.036761 | 0.0086 | 0.018883 | 0.024825 | 0.123289 | 0 |
00008_00007 | 8 | 0 | 0.242857 | 0 | 0.068217 | 0 | 0.117847 | 0 | 0 | 0 | 0.040679 | 0.027965 | 0 | 0.001766 | 0.014541 | 0 | 0.010728 | 0.147126 | 0.574918 | 0 | 0.126482 | 0.096826 | 0.04034 | 0.005732 | 0.000296 | 0 | 0.081848 | 0 | 0.001031 | 0.114108 | 0 | 0 | 0.70256 | 0 | 0.307037 | 0.290887 | 0.003196 | 0.093841 | 0.257387 | … | 0.214097 | 0 | 0.0047 | 0.005008 | 0 | 0.540658 | 0.035425 | 0.069549 | 0.197432 | 0.023377 | 0.01051 | 0 | 0.023421 | 0 | 0.029902 | 0.067589 | 0.080281 | 0.005246 | 0 | 0 | 0.007966 | 0 | 0.064383 | 0.25595 | 0.330105 | 0 | 0.027568 | 0.076803 | 0.126329 | 0.053939 | 0.095629 | 0.221957 | 0.133745 | 0.023491 | 0 | 0 | 0 | 0.065544 | 1.030737 | 0.01037 |
00009_00008 | 9 | 0 | 0.141986 | 0 | 0.000983 | 0 | 0.013148 | 0.066999 | 0.008579 | 0 | 0.132708 | 0.000013 | 0 | 0 | 0 | 0 | 0.001043 | 0.004207 | 0.455128 | 0 | 0.200548 | 0.062265 | 0 | 0.031603 | 0 | 0 | 0.176387 | 0.002721 | 0.014317 | 0.032541 | 0 | 0 | 1.168682 | 0.033444 | 0.214281 | 0.109845 | 0 | 0.069903 | 0.009621 | … | 0.104411 | 0 | 0.039039 | 0.008832 | 0 | 0.524103 | 0.043137 | 0.155169 | 0.004774 | 0 | 0.003791 | 0.048902 | 0 | 0 | 0.049338 | 0.007434 | 0 | 0.000472 | 0 | 0 | 0.020621 | 0 | 0.000286 | 0.318853 | 0.383723 | 0 | 0 | 0.033699 | 0.193585 | 0.071854 | 0 | 0.480097 | 0.361314 | 0.026121 | 0 | 0 | 0 | 0.001085 | 0.653569 | 0.007591 |
5 rows × 2050 columns
train_tweets_vectorized_text.head()
tweet_id | feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | feature_10 | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | feature_20 | feature_21 | feature_22 | feature_23 | feature_24 | feature_25 | feature_26 | feature_27 | feature_28 | feature_29 | feature_30 | feature_31 | feature_32 | feature_33 | feature_34 | feature_35 | feature_36 | feature_37 | feature_38 | … | feature_728 | feature_729 | feature_730 | feature_731 | feature_732 | feature_733 | feature_734 | feature_735 | feature_736 | feature_737 | feature_738 | feature_739 | feature_740 | feature_741 | feature_742 | feature_743 | feature_744 | feature_745 | feature_746 | feature_747 | feature_748 | feature_749 | feature_750 | feature_751 | feature_752 | feature_753 | feature_754 | feature_755 | feature_756 | feature_757 | feature_758 | feature_759 | feature_760 | feature_761 | feature_762 | feature_763 | feature_764 | feature_765 | feature_766 | feature_767 |
0 | 0.125605 | -0.136067 | -0.121691 | -0.160296 | -0.074407 | 0.119014 | -0.343523 | -0.28979 | -0.037007 | 0.120231 | -0.245443 | 0.199461 | -0.154236 | -0.200109 | -0.206436 | 0.270252 | -0.142692 | -0.102078 | 0.157226 | -0.334515 | -0.264958 | -0.112983 | -0.293211 | -0.253694 | -0.104198 | 0.056506 | -0.231244 | 0.152571 | 0.206752 | -0.150545 | 0.112063 | -0.129411 | -0.22415 | -0.17533 | -0.165828 | -0.066047 | -0.159027 | 0.009872 | 0.019299 | … | 0.039782 | -0.174679 | 0.148821 | -0.192575 | -0.114211 | 0.496451 | 0.040274 | -0.14268 | 0.169754 | -0.075535 | -0.117306 | 0.261488 | 0.240786 | -0.15038 | -0.080656 | 0.310319 | 0.042854 | 0.048131 | -0.17271 | 0.135926 | -0.04339 | -0.208796 | 0.040137 | -0.190645 | -0.096934 | -0.009036 | 0.284776 | 0.338148 | -0.440536 | -0.090837 | 0.215511 | -0.330016 | -0.143669 | -0.017097 | 0.211852 | 0.009358 | 0.205395 | -0.100113 | 0.013015 | 0.053247 |
1 | 0.064982 | -0.11685 | 0.034871 | -0.090357 | -0.067459 | 0.030954 | -0.361263 | -0.294617 | -0.077854 | 0.135007 | -0.192705 | 0.252616 | -0.135662 | -0.201412 | -0.183382 | 0.17364 | -0.103182 | -0.074723 | 0.234004 | -0.28356 | -0.120644 | -0.063076 | -0.248546 | -0.224326 | -0.176795 | 0.0614 | -0.243843 | 0.226394 | 0.101096 | -0.077593 | 0.057844 | -0.086949 | -0.23986 | -0.303655 | -0.223538 | -0.041548 | -0.162694 | 0.005842 | 0.053615 | … | 0.151876 | -0.153876 | 0.272216 | -0.265888 | -0.124845 | 0.500886 | -0.053478 | -0.159796 | 0.102271 | 0.032116 | -0.034348 | 0.292187 | 0.236578 | -0.00666 | -0.113676 | 0.249192 | 0.048188 | -0.055551 | -0.037698 | 0.148909 | 0.064823 | -0.27023 | 0.003926 | -0.20708 | -0.062248 | -0.056531 | 0.188629 | 0.366379 | -0.51171 | -0.025049 | 0.193301 | -0.391395 | -0.120417 | -0.072493 | 0.188275 | -0.084694 | 0.152518 | -0.109684 | 0.034304 | 0.018237 |
4 | 0.05116 | -0.076732 | 0.005174 | -0.071699 | -0.204004 | 0.034764 | -0.320014 | -0.231828 | -0.121784 | 0.101362 | -0.238145 | 0.173951 | -0.102029 | -0.181864 | -0.214877 | 0.18611 | -0.032114 | -0.14362 | 0.175421 | -0.260034 | -0.103828 | -0.122353 | -0.31697 | -0.289015 | -0.215771 | 0.027695 | -0.254362 | 0.160985 | 0.040491 | -0.019251 | 0.156431 | -0.089619 | -0.20453 | -0.15548 | -0.207329 | -0.044228 | -0.094432 | -0.054102 | -0.06946 | … | 0.013487 | -0.145385 | 0.251336 | -0.243751 | -0.218401 | 0.55765 | -0.032927 | -0.233999 | 0.107764 | -0.054706 | -0.151404 | 0.268172 | 0.148024 | 0.066601 | -0.126532 | 0.235418 | 0.013908 | 0.107383 | -0.114999 | 0.242328 | 0.06241 | -0.122995 | 0.026454 | -0.118704 | -0.025266 | 0.015129 | 0.252958 | 0.273657 | -0.522295 | -0.049114 | 0.163904 | -0.299526 | -0.099811 | -0.049208 | 0.170104 | -0.125188 | 0.111381 | -0.180564 | -0.061082 | 0.14265 |
5 | 0.127061 | -0.063152 | 0.01052 | 0.000385 | -0.146983 | -0.099981 | -0.382142 | -0.287832 | -0.129653 | 0.056506 | -0.180725 | 0.183492 | -0.055121 | -0.205312 | -0.267817 | 0.152828 | -0.026461 | -0.150457 | 0.253863 | -0.289086 | -0.194721 | -0.073793 | -0.313012 | -0.311311 | -0.254014 | -0.055075 | -0.147885 | 0.179036 | 0.120235 | -0.005531 | 0.080192 | -0.229052 | -0.13706 | -0.281633 | -0.225555 | -0.010621 | -0.10548 | -0.135987 | -0.034342 | … | 0.058002 | -0.146499 | 0.203021 | -0.242434 | -0.169852 | 0.488979 | 0.018102 | -0.17811 | 0.155078 | -0.020591 | -0.14549 | 0.28435 | 0.186599 | 0.037712 | -0.143745 | 0.180907 | -0.023579 | 0.016043 | -0.097705 | 0.168376 | 0.072525 | -0.184797 | 0.048695 | -0.136387 | -0.102923 | -0.042237 | 0.23534 | 0.317434 | -0.532257 | -0.03628 | 0.168049 | -0.355778 | -0.150411 | -0.067176 | 0.216 | -0.076183 | 0.158889 | -0.057773 | -0.079182 | 0.057414 |
8 | 0.144889 | -0.084671 | 0.147057 | -0.06876 | -0.024226 | 0.081921 | -0.362943 | -0.288691 | -0.121206 | 0.145029 | -0.226151 | 0.219446 | -0.06556 | -0.249132 | -0.253098 | 0.139695 | -0.006103 | -0.133933 | 0.246038 | -0.293615 | -0.202569 | -0.105134 | -0.260601 | -0.281008 | -0.16858 | -0.036372 | -0.212415 | 0.052084 | 0.090292 | -0.025743 | 0.087359 | -0.068358 | -0.228364 | -0.236147 | -0.155169 | -0.103995 | -0.17848 | -0.073 | -0.06195 | … | 0.05929 | -0.132187 | 0.255592 | -0.292721 | -0.240281 | 0.562846 | 0.02217 | -0.174992 | 0.161351 | -0.039306 | -0.113506 | 0.25667 | 0.236067 | 0.03883 | -0.130292 | 0.260965 | 0.022115 | 0.017341 | -0.05545 | 0.205454 | -0.037376 | -0.141493 | 0.091427 | -0.116183 | -0.065763 | -0.044039 | 0.178388 | 0.402246 | -0.518461 | -0.054657 | 0.212959 | -0.389882 | -0.157168 | -0.051536 | 0.129986 | -0.021489 | 0.184418 | -0.058603 | -0.023088 | 0.084837 |
5 rows × 769 columns
Our vectorized data looks standard, where each column represents one coordinate in the numeric feature space.
Now let’s analyze the features in our tweets data.
Analysis of Tweet data features
train_tweets.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 29625 entries, 0 to 29624
Data columns (total 14 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 tweet_id 29625 non-null int64
1 tweet_user_id 29625 non-null int64
2 tweet_created_at_year 29625 non-null int64
3 tweet_created_at_month 29625 non-null int64
4 tweet_created_at_day 29625 non-null int64
5 tweet_created_at_hour 29625 non-null int64
6 tweet_hashtag_count 29625 non-null float64
7 tweet_url_count 29625 non-null float64
8 tweet_mention_count 29625 non-null float64
9 tweet_has_attachment 29625 non-null bool
10 tweet_attachment_class 29625 non-null object
11 tweet_language_id 29625 non-null int64
12 tweet_topic_ids 25340 non-null object
13 virality 29625 non-null int64
dtypes: bool(1), float64(3), int64(8), object(2)
memory usage: 3.0+ MB
There are a total of 11 features in our tweets data (excluding the two IDs and the target).
First, we have a look at virality.
train_tweets['virality'].describe()
count 29625.000000
mean 1.907274
std 1.078700
min 1.000000
25% 1.000000
50% 2.000000
75% 2.000000
max 5.000000
Name: virality, dtype: float64
sns.countplot(x = 'virality', data = train_tweets, palette="Set1");
We see that the average virality level is about 2 (mean ≈ 1.91), and level 1 is by far the most common.
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
sns.histplot(train_tweets, x = 'tweet_created_at_year', discrete = True, ax = axs[0,0]);
sns.histplot(train_tweets, x = 'tweet_created_at_day', discrete = True, ax = axs[0,1]);
sns.histplot(train_tweets, x = 'tweet_created_at_month', discrete = True, ax = axs[1,0]);
sns.histplot(train_tweets, x = 'tweet_created_at_hour', discrete = True, ax = axs[1,1]);
Looking at the time features, we see that the tweets in our data are most frequently from 2020, with December and the 27th day of the month being the most common. They were also mostly tweeted in the evening (hour 17, i.e., 5 PM).
fig, axs = plt.subplots(3, 1, figsize=(12, 10))
sns.histplot(x = 'tweet_hashtag_count', data = train_tweets, discrete = True, ax = axs[0]);
sns.histplot(x = 'tweet_url_count', data = train_tweets, discrete = True, ax = axs[1]);
sns.histplot(x = 'tweet_mention_count', data = train_tweets, discrete = True, ax = axs[2]);
Moving on to the hashtag, URL, and mention counts: most of the tweets have zero hashtags, one URL, and zero mentions.
fig, axs = plt.subplots(2, 1, figsize=(10, 7))
sns.countplot(x = 'tweet_attachment_class', data = train_tweets, palette="Set1", ax = axs[0]);
sns.countplot(x = 'tweet_language_id', data = train_tweets, ax = axs[1]);
Plotting the attachment class and language id, we see that class A is the most prevalent, while class B is very rare. As for the language id, the large majority is 0, which we can assume is English.
sns.countplot(x = 'tweet_has_attachment', data = train_tweets, palette="Set1");
Most of the tweets also have a media attachment.
Correlation of features
Now let’s look at how our features are correlated, or in other words, how strong the linear relationship between our features and virality is.
A good way to visualize it is with a heatmap.
corrmat = train_tweets.corr()[2:]
sns.heatmap(corrmat, vmax=.8, square=True);
From our heatmap, we can tell that our features barely correlate with virality, although some of the features do correlate with each other.
df_corr = train_tweets.corr()['virality'][2:-1]
top_features = df_corr.sort_values(ascending=False)
top_features
tweet_language_id 0.030416
tweet_created_at_day 0.017518
tweet_has_attachment 0.005401
tweet_created_at_hour -0.028583
tweet_url_count -0.047833
tweet_created_at_month -0.063757
tweet_mention_count -0.081958
tweet_hashtag_count -0.083262
tweet_created_at_year -0.096487
Name: virality, dtype: float64
Taking the numerical representation of the correlation, we see that the correlations are fairly low, with some of them being negative.
Nonetheless, this doesn't mean our features are useless; they may still have predictive power. It only means they don't relate "linearly" to the virality level.
Now let’s have a look at our users data.
Analysis of Users data
users.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 52 entries, 0 to 51
Data columns (total 11 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 user_id 52 non-null int64
1 user_like_count 52 non-null int64
2 user_followers_count 52 non-null int64
3 user_following_count 52 non-null int64
4 user_listed_on_count 52 non-null int64
5 user_has_location 52 non-null bool
6 user_tweet_count 52 non-null int64
7 user_has_url 52 non-null bool
8 user_verified 52 non-null int64
9 user_created_at_year 52 non-null int64
10 user_created_at_month 52 non-null int64
dtypes: bool(2), int64(9)
memory usage: 3.9 KB
We have a total of 10 features in users.
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
sns.histplot(users, x = 'user_like_count', ax = axs[0,0]);
sns.histplot(users, x = 'user_followers_count', ax = axs[0,1]);
sns.histplot(users, x = 'user_following_count', ax = axs[1,0]);
sns.histplot(users, x = 'user_listed_on_count', ax = axs[1,1]);
Visualizing the counts data, we can observe that the like count is mostly in the range of 0 to 5,000 likes, with outliers around 300k.
For user_followers_count, a large portion of the users are in the 0 to 100k range, and some users have up to 1.6m followers.
For user_following_count, most of them are in the 0 to 100k range as well, but a tiny portion follow more than 1m accounts.
For user_listed_on_count, most of them are between 0 and 5,000, with some users being listed on as many as 400k lists.
fig, axs = plt.subplots(3, 1, figsize=(12, 10))
sns.histplot(users, x = 'user_tweet_count', ax = axs[0]);
sns.histplot(users, x = 'user_created_at_year', discrete = True, ax = axs[1]);
sns.histplot(users, x = 'user_created_at_month', discrete = True, ax = axs[2]);
Our users mostly have tweet counts of around 500k; most of the accounts were created in 2011, with August being the most common creation month.
fig, axs = plt.subplots(1, 3, figsize=(12, 8))
sns.countplot(x = 'user_has_location', data = users, ax = axs[0], palette="Set1");
sns.countplot(x = 'user_has_url', data = users, ax = axs[1], palette="Set1");
sns.countplot(x = 'user_verified', data = users, ax = axs[2], palette="Set1");
Moving on to the binary data: most of our users have a location listed on their account and a URL in their bio, and most of them aren't verified.
Now that we’ve done a bit of EDA on our data, let’s start prepping it for modeling.
Data Preprocessing
These are the preprocessing tasks we'll be doing in this particular example. Note that there are many possible approaches, and more complex preprocessing and feature engineering are certainly possible, so treat the approach below only as an example and NOT as the definitive approach.
- dealing with missing values
- one-hot encoding
- feature selection using L1-regularized (lasso-style) logistic regression
- matching the number of columns between the train and test data (so that the trained model works on the test data)
- joining all the features into one final data frame for both the train and test/public data
Dealing with missing data
missing_cols(users)
no missing values
missing_cols(train_tweets)
tweet_topic_ids => 4285 [14.46%]
plt.figure(figsize=(10, 6))
sns.heatmap(train_tweets.isnull(), yticklabels=False, cmap='viridis', cbar=False);
Using a helper function I've coded up (code in the notebook), we check the number of missing values and their percentage for each column in the users and tweets data.
There is some missing data in the topic ids column, so we'll be dealing with that. We also visualize the missing data with a heatmap, which is useful for spotting patterns; our missing data seems to be pretty spread out.
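The helper itself lives in the notebook; a minimal sketch of what such a function might look like (the actual implementation may differ) is:
def missing_cols(df):
    """Print the count and percentage of missing values for each column."""
    total_missing = df.isnull().sum()
    if total_missing.sum() == 0:
        print('no missing values')
        return
    for col, n_missing in total_missing[total_missing > 0].items():
        print(f'{col} => {n_missing} [{100 * n_missing / len(df):.2f}%]')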
train_tweets.fillna({'tweet_topic_ids':"['0']"}, inplace=True)
missing_cols(train_tweets)
no missing values
To deal with the NA values, we fill them with an array containing zero, ['0'], to signify that the particular tweet has no topic ids. The reason I place it in an array is so that later on, when we do one-hot encoding, it can be represented as a topic_id of zero. (Again, this is only one approach; there are other ways you can handle this.)
Data cleaning is an essential skill. If you want to brush up on your data cleaning skills, check out our Data Cleaning using Python article.
I also noticed that the counts for hashtags, URLs, and mentions were stored as floats. However, they have no decimal places and take only a handful of distinct values, so they can be treated as categorical; I converted them to integers as shown below.
# train_tweets.tweet_hashtag_count.value_counts()
# train_tweets.tweet_url_count.value_counts()
# train_tweets.tweet_mention_count.value_counts()
# convert floats to ints
cols = ['tweet_hashtag_count', 'tweet_url_count', 'tweet_mention_count']
train_tweets[cols] = train_tweets[cols].applymap(np.int64)
train_tweets[cols].head()
One hot encoding on tweets features
topic_ids = (
    train_tweets['tweet_topic_ids'].str.strip('[]').str.split(r'\s*,\s*').explode()
    .str.get_dummies().sum(level=0).add_prefix('topic_id_')
)
topic_ids.rename(columns = lambda x: x.replace("'", ""), inplace=True)
year = pd.get_dummies(train_tweets.tweet_created_at_year, prefix='year')
month = pd.get_dummies(train_tweets.tweet_created_at_month , prefix='month')
day = pd.get_dummies(train_tweets.tweet_created_at_day, prefix='day')
attachment = pd.get_dummies(train_tweets.tweet_attachment_class, prefix='attatchment')
language = pd.get_dummies(train_tweets.tweet_language_id, prefix='language')
## Cyclical encoding
sin_hour = np.sin(2*np.pi*train_tweets['tweet_created_at_hour']/24.0)
sin_hour.name = 'sin_hour'
cos_hour = np.cos(2*np.pi*train_tweets['tweet_created_at_hour']/24.0)
cos_hour.name = 'cos_hour'
For topic_ids, the data is an array, so it's a bit tricky. What you can do is explode the arrays into individual values and then turn them into dummies.
For the other columns like year, month, attachment, etc., encoding is straightforward using get_dummies() from the pandas library.
As for the hour data, hours are cyclical (hour 23 sits right next to hour 0), so we encode them with sine and cosine transformations, i.e., cyclical encoding.
After we’ve encoded our features, we can drop the old ones, then join the new encoded features.
columns_drop = [
"tweet_topic_ids",
"tweet_created_at_year",
"tweet_created_at_month",
"tweet_created_at_day",
"tweet_attachment_class",
"tweet_language_id",
"tweet_created_at_hour",
]
dfs = [topic_ids, year, month, day, attachment, language,
sin_hour, cos_hour]
train_tweets_final = train_tweets.drop(columns=columns_drop).join(dfs)
train_tweets_final.head()
tweet_id | tweet_user_id | tweet_hashtag_count | tweet_url_count | tweet_mention_count | tweet_has_attachment | virality | topic_id_0 | topic_id_100 | topic_id_101 | topic_id_104 | topic_id_111 | topic_id_112 | topic_id_118 | topic_id_119 | topic_id_120 | topic_id_121 | topic_id_122 | topic_id_125 | topic_id_126 | topic_id_127 | topic_id_147 | topic_id_148 | topic_id_149 | topic_id_150 | topic_id_151 | topic_id_152 | topic_id_153 | topic_id_155 | topic_id_156 | topic_id_163 | topic_id_165 | topic_id_169 | topic_id_170 | topic_id_171 | topic_id_172 | topic_id_36 | topic_id_37 | topic_id_39 | topic_id_43 | … | day_27 | day_28 | day_29 | day_30 | day_31 | attatchment_A | attatchment_B | attatchment_C | language_0 | language_1 | language_2 | language_3 | language_4 | language_5 | language_6 | language_7 | language_8 | language_9 | language_10 | language_11 | language_12 | language_13 | language_14 | language_15 | language_16 | language_17 | language_18 | language_19 | language_20 | language_21 | language_22 | language_23 | language_24 | language_25 | language_27 | language_28 | language_29 | language_30 | sin_hour | cos_hour |
34698 | 10 | 2 | 1 | 0 | FALSE | 3 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 6 | 6 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0.707107 | 0.707107 |
24644 | 4 | 0 | 1 | 0 | FALSE | 3 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | … | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
36321 | 54 | 2 | 3 | 0 | TRUE | 1 | 0 | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | … | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.707107 | -0.707107 |
2629 | 42 | 0 | 1 | 1 | TRUE | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | … | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.965926 | -0.258819 |
28169 | 32 | 2 | 1 | 0 | TRUE | 2 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | … | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.965926 | -0.258819 |
Now our train tweets data has a whopping 151 columns!
One hot encoding on users features
year = pd.get_dummies(users.user_created_at_year, prefix='year')
month = pd.get_dummies(users.user_created_at_month , prefix='month')
user_verified = pd.get_dummies(users.user_verified, prefix='verified')
columns_drop = [
"user_created_at_year",
"user_created_at_month",
"user_verified"
]
dfs = [
year,
month,
user_verified
]
users_final = users.drop(columns=columns_drop).join(dfs)
users_final.head()
user_id | user_like_count | user_followers_count | user_following_count | user_listed_on_count | user_has_location | user_tweet_count | user_has_url | year_2008 | year_2009 | year_2010 | year_2011 | year_2012 | year_2013 | year_2014 | month_1 | month_2 | month_4 | month_5 | month_6 | month_7 | month_8 | month_9 | month_10 | month_11 | month_12 | verified_0 | verified_1 |
0 | 1164 | 48720 | 70469 | 5956 | TRUE | 14122 | TRUE | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
1 | 3914 | 85361 | 2171 | 5943 | FALSE | 6957 | FALSE | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 1 | 0 |
2 | 8292 | 200944 | 1416 | 8379 | TRUE | 83485 | TRUE | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
3 | 1770 | 15385 | 4572 | 1866 | TRUE | 12265 | TRUE | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 1 | 0 |
4 | 15311 | 459083 | 1021 | 7368 | FALSE | 121193 | FALSE | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 |
We'll also do one-hot encoding on our users features, mainly the year, month, and user-verified columns, and join them back just like we did with the train tweets data.
Feature selection on tweets data
The vectorized data has tons of features, and it's best if we can "separate the wheat from the chaff", i.e., keep the features that best help us predict virality. This reduces the data size and speeds up model training.
We'll start with our train_tweets_vectorized_media data. Since not all tweets have media, we'll have to append the virality of each tweet onto the ones with media. This can be done with a right join on tweet_id; that way, the virality of each tweet is matched with its image features.
Since we don't want the columns in train_tweets besides virality to be merged into our vectorized data, we can use the .difference function to keep only virality.
# create new data frame that matches row number between train tweets and vectorized media
vectorized_media_df = pd.merge(train_tweets, train_tweets_vectorized_media, on='tweet_id', how='right')
vectorized_media_df.drop(train_tweets.columns.difference(['virality']), axis=1, inplace=True)
vectorized_media_df.head()
virality | media_id | img_feature_0 | img_feature_1 | img_feature_2 | img_feature_3 | img_feature_4 | img_feature_5 | img_feature_6 | img_feature_7 | img_feature_8 | img_feature_9 | img_feature_10 | img_feature_11 | img_feature_12 | img_feature_13 | img_feature_14 | img_feature_15 | img_feature_16 | img_feature_17 | img_feature_18 | img_feature_19 | img_feature_20 | img_feature_21 | img_feature_22 | img_feature_23 | img_feature_24 | img_feature_25 | img_feature_26 | img_feature_27 | img_feature_28 | img_feature_29 | img_feature_30 | img_feature_31 | img_feature_32 | img_feature_33 | img_feature_34 | img_feature_35 | img_feature_36 | img_feature_37 | … | img_feature_2008 | img_feature_2009 | img_feature_2010 | img_feature_2011 | img_feature_2012 | img_feature_2013 | img_feature_2014 | img_feature_2015 | img_feature_2016 | img_feature_2017 | img_feature_2018 | img_feature_2019 | img_feature_2020 | img_feature_2021 | img_feature_2022 | img_feature_2023 | img_feature_2024 | img_feature_2025 | img_feature_2026 | img_feature_2027 | img_feature_2028 | img_feature_2029 | img_feature_2030 | img_feature_2031 | img_feature_2032 | img_feature_2033 | img_feature_2034 | img_feature_2035 | img_feature_2036 | img_feature_2037 | img_feature_2038 | img_feature_2039 | img_feature_2040 | img_feature_2041 | img_feature_2042 | img_feature_2043 | img_feature_2044 | img_feature_2045 | img_feature_2046 | img_feature_2047 |
1 | 00001_00000 | 0.290614 | 0.150803 | 0.008313 | 0.040887 | 0 | 0 | 0.214209 | 0.000792 | 0.00027 | 0.424207 | 0 | 0.372124 | 0.031332 | 0.041069 | 0 | 0.209875 | 0.232068 | 0.003567 | 0.185193 | 0.096734 | 0.127461 | 0.005552 | 0.000435 | 0.429719 | 0.024739 | 0.000051 | 0.000842 | 0.115226 | 0 | 0.026856 | 0 | 0.031253 | 0.032772 | 0.051137 | 0.026097 | 0.962891 | 0.132004 | 0.158875 | … | 0.00048 | 0 | 0.018373 | 0.327092 | 0.079089 | 0.360097 | 0.002562 | 1.116611 | 0.054391 | 0.086378 | 0.045496 | 0.030632 | 0 | 0 | 0 | 0.16522 | 0 | 0.05931 | 0.111803 | 0 | 0.10911 | 0.025834 | 0.235375 | 0.078341 | 0.131708 | 0.013988 | 0 | 0.02596 | 0.01576 | 0.266088 | 0 | 0.24924 | 0.040368 | 0.101314 | 0 | 0.069272 | 0.167507 | 0.044617 | 0.383093 | 0.097627 |
2 | 00004_00003 | 0.038251 | 0.036437 | 0 | 0.015076 | 0 | 0.046953 | 0.64817 | 0.026476 | 0 | 0.191951 | 0.003372 | 0 | 0.009363 | 0 | 0 | 0.012317 | 0.016527 | 0.130308 | 0 | 0 | 0.303246 | 0 | 0.307988 | 0.011478 | 0.044807 | 0.20841 | 0.043399 | 0.118079 | 0.000222 | 0.083115 | 0 | 0.634716 | 0 | 0.01404 | 0.090266 | 0 | 0.127964 | 0.189758 | … | 0 | 0.001796 | 0 | 0.251383 | 0.021052 | 0.802314 | 0.027913 | 0.335493 | 0.017326 | 0 | 0.026515 | 0.056399 | 0.030597 | 0.082174 | 0.003829 | 0.083139 | 0.003266 | 0.249968 | 0.304901 | 0.004793 | 0.028569 | 0 | 0.06998 | 0.007251 | 0.194076 | 0 | 0.124188 | 0 | 0.691953 | 0.009337 | 0.024564 | 0.035555 | 0.369353 | 0 | 0.133307 | 0 | 0 | 0.017894 | 0.816972 | 0.058774 |
1 | 00005_00004 | 0.506981 | 0.305467 | 0.03615 | 0.114539 | 0 | 0.146888 | 0.584753 | 0.157468 | 0 | 0.040884 | 0.009688 | 0 | 0 | 0 | 0.074692 | 0.018211 | 0.233618 | 0.039807 | 0 | 0 | 0.16983 | 0 | 0.045755 | 0.028518 | 0 | 0.259838 | 0.117726 | 0.017544 | 0.018106 | 0 | 0.019969 | 0.257562 | 0 | 0.006338 | 0 | 0 | 0.090963 | 0.746206 | … | 0 | 0 | 0.032314 | 0 | 0 | 0.143607 | 0.371495 | 0.19986 | 0.192786 | 0 | 0.401853 | 0.021595 | 0.033472 | 0.164238 | 0.085964 | 0.293521 | 0 | 0.08259 | 0.020819 | 0.001283 | 0.009182 | 0 | 0.295073 | 0 | 0.300424 | 0.228981 | 0.109332 | 0.032641 | 0.263165 | 0.000992 | 0 | 0.131493 | 0.268107 | 0 | 0.036761 | 0.0086 | 0.018883 | 0.024825 | 0.123289 | 0 |
1 | 00008_00007 | 0 | 0.242857 | 0 | 0.068217 | 0 | 0.117847 | 0 | 0 | 0 | 0.040679 | 0.027965 | 0 | 0.001766 | 0.014541 | 0 | 0.010728 | 0.147126 | 0.574918 | 0 | 0.126482 | 0.096826 | 0.04034 | 0.005732 | 0.000296 | 0 | 0.081848 | 0 | 0.001031 | 0.114108 | 0 | 0 | 0.70256 | 0 | 0.307037 | 0.290887 | 0.003196 | 0.093841 | 0.257387 | … | 0.214097 | 0 | 0.0047 | 0.005008 | 0 | 0.540658 | 0.035425 | 0.069549 | 0.197432 | 0.023377 | 0.01051 | 0 | 0.023421 | 0 | 0.029902 | 0.067589 | 0.080281 | 0.005246 | 0 | 0 | 0.007966 | 0 | 0.064383 | 0.25595 | 0.330105 | 0 | 0.027568 | 0.076803 | 0.126329 | 0.053939 | 0.095629 | 0.221957 | 0.133745 | 0.023491 | 0 | 0 | 0 | 0.065544 | 1.030737 | 0.01037 |
1 | 00009_00008 | 0 | 0.141986 | 0 | 0.000983 | 0 | 0.013148 | 0.066999 | 0.008579 | 0 | 0.132708 | 0.000013 | 0 | 0 | 0 | 0 | 0.001043 | 0.004207 | 0.455128 | 0 | 0.200548 | 0.062265 | 0 | 0.031603 | 0 | 0 | 0.176387 | 0.002721 | 0.014317 | 0.032541 | 0 | 0 | 1.168682 | 0.033444 | 0.214281 | 0.109845 | 0 | 0.069903 | 0.009621 | … | 0.104411 | 0 | 0.039039 | 0.008832 | 0 | 0.524103 | 0.043137 | 0.155169 | 0.004774 | 0 | 0.003791 | 0.048902 | 0 | 0 | 0.049338 | 0.007434 | 0 | 0.000472 | 0 | 0 | 0.020621 | 0 | 0.000286 | 0.318853 | 0.383723 | 0 | 0 | 0.033699 | 0.193585 | 0.071854 | 0 | 0.480097 | 0.361314 | 0.026121 | 0 | 0 | 0 | 0.001085 | 0.653569 | 0.007591 |
5 rows × 2050 columns
We can then take this dataset, and start doing feature selection.
# Set the target as well as dependent variables from image data.
y = vectorized_media_df['virality']
x = vectorized_media_df.loc[:, vectorized_media_df.columns.str.contains("img_")]
# Run L1-regularized (lasso-style) logistic regression for feature selection.
sel_model = SelectFromModel(LogisticRegression(C=1, penalty='l1', solver='liblinear'))
# time the model fitting
start = timeit.default_timer()
# Fit the trained model on our data
sel_model.fit(x, y)
stop = timeit.default_timer()
print('Time: ', stop - start)
# get index of good features
sel_index = sel_model.get_support()
# count the no of columns selected
counter = collections.Counter(sel_model.get_support())
counter
Time: 113.86132761099998
Counter({False: 2, True: 2046})
Our target is virality, and the features are the columns whose names contain "img_".
Using the collections library, we can count the number of features that were chosen by the model. In this case, only two features were deemed not "worthy", so the selection wasn't all that useful here.
Nonetheless, using the indexes from our model, we can concatenate the selected features back with the media and tweet ids to form the final train_tweets_media data, which has two fewer columns than before.
media_ind_df = pd.DataFrame(x[x.columns[(sel_index)]])
train_tweets_media_final = pd.concat([train_tweets_vectorized_media[['media_id', 'tweet_id']], media_ind_df], axis=1)
train_tweets_media_final.head()
media_id | tweet_id | img_feature_0 | img_feature_1 | img_feature_2 | img_feature_3 | img_feature_4 | img_feature_5 | img_feature_6 | img_feature_7 | img_feature_8 | img_feature_9 | img_feature_10 | img_feature_11 | img_feature_12 | img_feature_13 | img_feature_14 | img_feature_15 | img_feature_16 | img_feature_17 | img_feature_18 | img_feature_19 | img_feature_20 | img_feature_21 | img_feature_22 | img_feature_23 | img_feature_24 | img_feature_25 | img_feature_26 | img_feature_27 | img_feature_28 | img_feature_29 | img_feature_30 | img_feature_31 | img_feature_32 | img_feature_33 | img_feature_34 | img_feature_35 | img_feature_36 | img_feature_37 | … | img_feature_2008 | img_feature_2009 | img_feature_2010 | img_feature_2011 | img_feature_2012 | img_feature_2013 | img_feature_2014 | img_feature_2015 | img_feature_2016 | img_feature_2017 | img_feature_2018 | img_feature_2019 | img_feature_2020 | img_feature_2021 | img_feature_2022 | img_feature_2023 | img_feature_2024 | img_feature_2025 | img_feature_2026 | img_feature_2027 | img_feature_2028 | img_feature_2029 | img_feature_2030 | img_feature_2031 | img_feature_2032 | img_feature_2033 | img_feature_2034 | img_feature_2035 | img_feature_2036 | img_feature_2037 | img_feature_2038 | img_feature_2039 | img_feature_2040 | img_feature_2041 | img_feature_2042 | img_feature_2043 | img_feature_2044 | img_feature_2045 | img_feature_2046 | img_feature_2047 |
00001_00000 | 1 | 0.290614 | 0.150803 | 0.008313 | 0.040887 | 0 | 0 | 0.214209 | 0.000792 | 0.00027 | 0.424207 | 0 | 0.372124 | 0.031332 | 0.041069 | 0 | 0.209875 | 0.232068 | 0.003567 | 0.185193 | 0.096734 | 0.127461 | 0.005552 | 0.000435 | 0.429719 | 0.024739 | 0.000051 | 0.000842 | 0.115226 | 0 | 0.026856 | 0 | 0.031253 | 0.032772 | 0.051137 | 0.026097 | 0.962891 | 0.132004 | 0.158875 | … | 0.00048 | 0 | 0.018373 | 0.327092 | 0.079089 | 0.360097 | 0.002562 | 1.116611 | 0.054391 | 0.086378 | 0.045496 | 0.030632 | 0 | 0 | 0 | 0.16522 | 0 | 0.05931 | 0.111803 | 0 | 0.10911 | 0.025834 | 0.235375 | 0.078341 | 0.131708 | 0.013988 | 0 | 0.02596 | 0.01576 | 0.266088 | 0 | 0.24924 | 0.040368 | 0.101314 | 0 | 0.069272 | 0.167507 | 0.044617 | 0.383093 | 0.097627 |
00004_00003 | 4 | 0.038251 | 0.036437 | 0 | 0.015076 | 0 | 0.046953 | 0.64817 | 0.026476 | 0 | 0.191951 | 0.003372 | 0 | 0.009363 | 0 | 0 | 0.012317 | 0.016527 | 0.130308 | 0 | 0 | 0.303246 | 0 | 0.307988 | 0.011478 | 0.044807 | 0.20841 | 0.043399 | 0.118079 | 0.000222 | 0.083115 | 0 | 0.634716 | 0 | 0.01404 | 0.090266 | 0 | 0.127964 | 0.189758 | … | 0 | 0.001796 | 0 | 0.251383 | 0.021052 | 0.802314 | 0.027913 | 0.335493 | 0.017326 | 0 | 0.026515 | 0.056399 | 0.030597 | 0.082174 | 0.003829 | 0.083139 | 0.003266 | 0.249968 | 0.304901 | 0.004793 | 0.028569 | 0 | 0.06998 | 0.007251 | 0.194076 | 0 | 0.124188 | 0 | 0.691953 | 0.009337 | 0.024564 | 0.035555 | 0.369353 | 0 | 0.133307 | 0 | 0 | 0.017894 | 0.816972 | 0.058774 |
00005_00004 | 5 | 0.506981 | 0.305467 | 0.03615 | 0.114539 | 0 | 0.146888 | 0.584753 | 0.157468 | 0 | 0.040884 | 0.009688 | 0 | 0 | 0 | 0.074692 | 0.018211 | 0.233618 | 0.039807 | 0 | 0 | 0.16983 | 0 | 0.045755 | 0.028518 | 0 | 0.259838 | 0.117726 | 0.017544 | 0.018106 | 0 | 0.019969 | 0.257562 | 0 | 0.006338 | 0 | 0 | 0.090963 | 0.746206 | … | 0 | 0 | 0.032314 | 0 | 0 | 0.143607 | 0.371495 | 0.19986 | 0.192786 | 0 | 0.401853 | 0.021595 | 0.033472 | 0.164238 | 0.085964 | 0.293521 | 0 | 0.08259 | 0.020819 | 0.001283 | 0.009182 | 0 | 0.295073 | 0 | 0.300424 | 0.228981 | 0.109332 | 0.032641 | 0.263165 | 0.000992 | 0 | 0.131493 | 0.268107 | 0 | 0.036761 | 0.0086 | 0.018883 | 0.024825 | 0.123289 | 0 |
00008_00007 | 8 | 0 | 0.242857 | 0 | 0.068217 | 0 | 0.117847 | 0 | 0 | 0 | 0.040679 | 0.027965 | 0 | 0.001766 | 0.014541 | 0 | 0.010728 | 0.147126 | 0.574918 | 0 | 0.126482 | 0.096826 | 0.04034 | 0.005732 | 0.000296 | 0 | 0.081848 | 0 | 0.001031 | 0.114108 | 0 | 0 | 0.70256 | 0 | 0.307037 | 0.290887 | 0.003196 | 0.093841 | 0.257387 | … | 0.214097 | 0 | 0.0047 | 0.005008 | 0 | 0.540658 | 0.035425 | 0.069549 | 0.197432 | 0.023377 | 0.01051 | 0 | 0.023421 | 0 | 0.029902 | 0.067589 | 0.080281 | 0.005246 | 0 | 0 | 0.007966 | 0 | 0.064383 | 0.25595 | 0.330105 | 0 | 0.027568 | 0.076803 | 0.126329 | 0.053939 | 0.095629 | 0.221957 | 0.133745 | 0.023491 | 0 | 0 | 0 | 0.065544 | 1.030737 | 0.01037 |
00009_00008 | 9 | 0 | 0.141986 | 0 | 0.000983 | 0 | 0.013148 | 0.066999 | 0.008579 | 0 | 0.132708 | 0.000013 | 0 | 0 | 0 | 0 | 0.001043 | 0.004207 | 0.455128 | 0 | 0.200548 | 0.062265 | 0 | 0.031603 | 0 | 0 | 0.176387 | 0.002721 | 0.014317 | 0.032541 | 0 | 0 | 1.168682 | 0.033444 | 0.214281 | 0.109845 | 0 | 0.069903 | 0.009621 | … | 0.104411 | 0 | 0.039039 | 0.008832 | 0 | 0.524103 | 0.043137 | 0.155169 | 0.004774 | 0 | 0.003791 | 0.048902 | 0 | 0 | 0.049338 | 0.007434 | 0 | 0.000472 | 0 | 0 | 0.020621 | 0 | 0.000286 | 0.318853 | 0.383723 | 0 | 0 | 0.033699 | 0.193585 | 0.071854 | 0 | 0.480097 | 0.361314 | 0.026121 | 0 | 0 | 0 | 0.001085 | 0.653569 | 0.007591 |
For train_tweets_vectorized_text, it's the same case as above: we'll have to do a right join to match the virality of each tweet and then perform feature selection.
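The notebook carries this out in full; below is a condensed sketch of the same right-join-plus-selection steps for the text features, assuming the result is stored as train_tweets_text_final (the name used later on).
# Match virality to each vectorized text row via a right join on tweet_id
vectorized_text_df = pd.merge(train_tweets, train_tweets_vectorized_text, on='tweet_id', how='right')
vectorized_text_df.drop(train_tweets.columns.difference(['virality']), axis=1, inplace=True)
# L1-regularized logistic regression picks out the useful text features
y = vectorized_text_df['virality']
x = vectorized_text_df.loc[:, vectorized_text_df.columns.str.contains('feature_')]
sel_model = SelectFromModel(LogisticRegression(C=1, penalty='l1', solver='liblinear'))
sel_model.fit(x, y)
sel_index = sel_model.get_support()
# Keep only the selected feature columns and re-attach tweet_id
text_ind_df = pd.DataFrame(x[x.columns[sel_index]])
train_tweets_text_final = pd.concat([train_tweets_vectorized_text[['tweet_id']], text_ind_df], axis=1)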
Feature selection on users data
To perform feature selection on users data, we would need to append virality level to each user. But since each user has multiple tweets, one way to approach this is to get the median of their virality, and then perform a right join to match the median virality level to each user. Then with this new data frame, we can perform feature selection.
Using groupby and agg from pandas, we can find the median and then merge them.
# Find the median of virality for each user to reduce features for user vectorized
# description and profile
average_virality_df = train_tweets.groupby('tweet_user_id').agg(pd.Series.median)['virality']
descriptions_df = pd.merge(average_virality_df, user_vectorized_descriptions, left_on='tweet_user_id', right_on='user_id', how='right')
profile_images_df = pd.merge(average_virality_df, user_vectorized_profile_images, left_on='tweet_user_id', right_on='user_id', how='right')
descriptions_df.head()
virality | user_id | feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | feature_10 | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | feature_20 | feature_21 | feature_22 | feature_23 | feature_24 | feature_25 | feature_26 | feature_27 | feature_28 | feature_29 | feature_30 | feature_31 | feature_32 | feature_33 | feature_34 | feature_35 | feature_36 | feature_37 | … | feature_728 | feature_729 | feature_730 | feature_731 | feature_732 | feature_733 | feature_734 | feature_735 | feature_736 | feature_737 | feature_738 | feature_739 | feature_740 | feature_741 | feature_742 | feature_743 | feature_744 | feature_745 | feature_746 | feature_747 | feature_748 | feature_749 | feature_750 | feature_751 | feature_752 | feature_753 | feature_754 | feature_755 | feature_756 | feature_757 | feature_758 | feature_759 | feature_760 | feature_761 | feature_762 | feature_763 | feature_764 | feature_765 | feature_766 | feature_767 |
1 | 0 | 0.132536 | -0.137393 | -0.064037 | -0.118342 | -0.130279 | 0.048067 | -0.421301 | -0.313038 | 0.047779 | 0.041972 | -0.2115 | 0.157389 | -0.119609 | -0.167288 | -0.183701 | 0.1626 | -0.118144 | -0.160549 | 0.20617 | -0.349808 | -0.180516 | -0.075424 | -0.228215 | -0.227588 | -0.20613 | 0.097065 | -0.20194 | 0.113164 | 0.115008 | -0.025116 | 0.0634 | -0.129166 | -0.154574 | -0.219841 | -0.18545 | -0.099904 | -0.084291 | -0.044961 | … | 0.083985 | -0.157461 | 0.285158 | -0.197924 | -0.163785 | 0.535255 | 0.027747 | -0.155363 | 0.146396 | -0.090979 | -0.170317 | 0.254166 | 0.260563 | -0.071186 | -0.140582 | 0.310176 | 0.083907 | -0.034472 | -0.17274 | 0.126395 | -0.004203 | -0.177539 | 0.038244 | -0.18842 | -0.080583 | 0.065391 | 0.265358 | 0.307018 | -0.494297 | -0.14292 | 0.238264 | -0.315408 | -0.159851 | -0.00384 | 0.213492 | 0.002498 | 0.177574 | -0.136515 | -0.012882 | 0.017399 |
2 | 1 | 0.107849 | -0.168418 | 0.027251 | -0.075079 | -0.084762 | 0.076149 | -0.390708 | -0.271934 | 0.007423 | 0.030401 | -0.216736 | 0.183259 | -0.069264 | -0.236452 | -0.209206 | 0.174043 | -0.121529 | -0.150529 | 0.228872 | -0.336505 | -0.204807 | -0.152244 | -0.307261 | -0.216196 | -0.265559 | 0.077822 | -0.34644 | 0.154961 | 0.165459 | -0.000246 | 0.065532 | -0.173314 | -0.191337 | -0.143802 | -0.223451 | -0.06728 | -0.124719 | -0.16018 | … | 0.036054 | -0.140715 | 0.224058 | -0.174127 | -0.15951 | 0.531637 | -0.003619 | -0.117995 | 0.093102 | -0.086952 | -0.189147 | 0.209478 | 0.246669 | -0.04345 | -0.158822 | 0.295335 | 0.058998 | -0.008168 | -0.144616 | 0.219429 | 0.049639 | -0.211484 | 0.026302 | -0.199768 | -0.131321 | 0.020595 | 0.30496 | 0.283139 | -0.525245 | -0.187449 | 0.232922 | -0.314534 | -0.177011 | -0.04171 | 0.209785 | -0.023427 | 0.158203 | -0.143221 | 0.030484 | 0.081693 |
1 | 2 | 0.122312 | -0.159376 | -0.073417 | -0.149442 | -0.122684 | -0.005277 | -0.351233 | -0.297342 | -0.00601 | 0.083945 | -0.243968 | 0.184267 | -0.045257 | -0.191175 | -0.168322 | 0.190007 | -0.150225 | -0.191811 | 0.260278 | -0.32333 | -0.226146 | -0.106863 | -0.163877 | -0.207189 | -0.153667 | 0.09043 | -0.265063 | 0.103507 | 0.147642 | -0.003167 | 0.083048 | -0.220785 | -0.242494 | -0.238759 | -0.19413 | -0.034603 | 0.002399 | -0.173476 | … | 0.003421 | -0.165013 | 0.254066 | -0.213777 | -0.134803 | 0.554688 | -0.02458 | -0.159201 | 0.116502 | -0.111342 | -0.140976 | 0.216088 | 0.219368 | -0.052936 | -0.126136 | 0.385574 | 0.03982 | -0.023451 | -0.120135 | 0.191185 | 0.016503 | -0.223201 | 0.051937 | -0.162366 | -0.111131 | 0.047493 | 0.281597 | 0.339442 | -0.440569 | -0.059945 | 0.173621 | -0.292476 | -0.185078 | -0.026784 | 0.184902 | 0.009539 | 0.217004 | -0.091951 | 0.025304 | 0.058501 |
1 | 3 | 0.160509 | -0.137915 | -0.002524 | -0.034696 | 0.028126 | 0.056299 | -0.365196 | -0.259523 | -0.037929 | 0.104135 | -0.206807 | 0.194023 | -0.105497 | -0.277824 | -0.154094 | 0.185838 | -0.147508 | -0.18359 | 0.282249 | -0.251785 | -0.132236 | -0.15296 | -0.293629 | -0.165441 | -0.207462 | 0.033447 | -0.275356 | 0.135713 | 0.106392 | -0.023706 | 0.049851 | -0.122355 | -0.158445 | -0.189165 | -0.210765 | 0.043706 | -0.079914 | -0.075443 | … | 0.034914 | -0.239188 | 0.287563 | -0.316668 | -0.133856 | 0.503008 | -0.039165 | -0.090734 | 0.055525 | -0.082432 | -0.095903 | 0.214028 | 0.232646 | -0.098938 | -0.117663 | 0.290847 | -0.0232 | -0.031113 | -0.156534 | 0.187953 | 0.039809 | -0.193605 | 0.044424 | -0.119147 | 0.001465 | -0.09017 | 0.228475 | 0.299477 | -0.412852 | -0.191728 | 0.205752 | -0.300688 | -0.133753 | 0.002206 | 0.245214 | -0.056659 | 0.152064 | -0.180211 | 0.022327 | 0.014688 |
3 | 4 | 0.099192 | -0.140809 | -0.012423 | -0.150097 | -0.120169 | 0.054078 | -0.384291 | -0.26965 | -0.046161 | 0.130959 | -0.215248 | 0.158257 | -0.134075 | -0.146719 | -0.225974 | 0.168122 | -0.114289 | -0.130104 | 0.162119 | -0.241793 | -0.211575 | -0.148881 | -0.279008 | -0.237968 | -0.134566 | 0.136716 | -0.251553 | 0.089406 | 0.093279 | 0.038696 | 0.039987 | -0.16418 | -0.115864 | -0.231034 | -0.202994 | -0.039647 | -0.112555 | -0.08796 | … | -0.025008 | -0.133031 | 0.230745 | -0.290248 | -0.105768 | 0.516136 | 0.049697 | -0.117967 | 0.11048 | -0.155596 | -0.143313 | 0.259111 | 0.123462 | -0.100905 | -0.082929 | 0.292229 | 0.064536 | 0.055363 | -0.12801 | 0.258008 | -0.00858 | -0.203917 | 0.010632 | -0.200032 | -0.014674 | 0.069364 | 0.295395 | 0.38916 | -0.432069 | -0.119117 | 0.179307 | -0.367725 | -0.217667 | -0.064391 | 0.163382 | -0.020638 | 0.181554 | -0.200262 | -0.074513 | 0.037301 |
Merging everything together
Now it’s time to merge everything into the final train data frame.
There are two main problems we face:
- Not all tweets have media, and some tweets have multiple media. How do we combine this with our train tweets data frame?
- User vectorized data and our tweets vectorized text data have similar column names, which will cause issues when merging (pandas will append _x and _y to the columns)
One approach to handle the first issue is to take the mean of the features grouped by tweet id. We can do that using the groupby function.
# media final doesn't have an equal number of rows, so I group by tweet_id (since
# there can be multiple media ids for a single tweet) and average the features
# (a naive approach, but this is just an example), after which it can be joined with train_tweets
media_df = train_tweets_media_final.groupby('tweet_id').mean()
Then, to solve the overlapping feature column names, we can use the rename function.
# tweets_vectorized_text and user_vectorized_profile_images has same column
# names which will cause problems when merging
# rename columns in tweets_vectorized_text
cols = train_tweets_text_final.columns[train_tweets_text_final.columns.str.contains('feature_')]
train_tweets_text_final.rename(columns = dict(zip(cols, 'text_' + cols)), inplace=True)
train_tweets_text_final.head()
(Data frame preview truncated for readability: 5 rows × 670 columns, i.e. tweet_id plus 669 of the text_feature_* columns of the vectorized tweet text.)
Now we can merge all the data frames.
# Merge the tables: user data on 'user_id', tweet data on 'tweet_id'
# join tweets data (a right join keeps tweets that have no media)
tweet_df = pd.merge(media_df, train_tweets_text_final, on='tweet_id', how='right')
tweet_df.fillna(0, inplace=True)
# join users data
user_df = pd.merge(users_final, user_profile_images_final, on='user_id')
# join tweets data on train_tweets
tweet_df_final = pd.merge(train_tweets_final, tweet_df, on='tweet_id')
# join that with the users data
final_df = pd.merge(tweet_df_final, user_df, left_on='tweet_user_id', right_on='user_id')
final_df.shape
(29625, 2946)
There are different ways you could approach this, but my approach was to merge the tweets data first. Since not all tweets have media, the merge leaves missing values, so we fill those NAs with zero.
Next, we merge the users data. After that, I join train_tweets with the merged tweets data, and then join that result with the users data.
The final data frame has 29,625 rows and 2,946 columns.
Whatever preprocessing we do on the training data has to be replicated on the test data; otherwise, the model trained on our training data won't be usable on the test data. (The test-data preprocessing code isn't shown here; please refer to the notebook.)
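For orientation only, here's a minimal sketch of what mirroring those merges on the test set might look like. It assumes the test media and text frames were reduced the same way as the training ones; test_media_df and test_text_df are hypothetical names, while test_tweets_final and p_final_df are the names used later in this article.
# Hypothetical mirror of the training merges, applied to the test set
test_tweet_df = pd.merge(test_media_df, test_text_df, on='tweet_id', how='right')
test_tweet_df.fillna(0, inplace=True)
test_tweet_df_final = pd.merge(test_tweets_final, test_tweet_df, on='tweet_id')
# users.csv is shared between train and test, so user_df from above can be reused
p_final_df = pd.merge(test_tweet_df_final, user_df, left_on='tweet_user_id', right_on='user_id')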
Match Number of Columns for Train and Test
After applying the same preprocessing to our test data, we face another problem: the number of columns in train and test doesn't match, because some feature columns appear in only one of the two sets.
To solve this, we can turn the column names into Python sets and take the set difference in both directions to find the mismatches. Then we add the missing columns to the data and fill them with zero.
# columns present in test but missing from train
cols_test = set(test_tweets_final.columns) - set(train_tweets_final.columns)
cols_test
{'language_26', 'topic_id_117', 'topic_id_123', 'topic_id_38'}
# add those columns to the training frame, filled with zero
for col in cols_test:
    final_df[col] = 0
We also check for columns in train that are missing from test; aside from the target virality (which of course isn't in the test set), there are none.
# columns missing in test from train
cols_train = set(train_tweets_final.columns) - set(test_tweets_final.columns)
cols_train.remove('virality') # remove virality from columns to add to test
len(cols_train)
0
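One extra safeguard that isn't in the original notebook, but is cheap to add: after padding the missing columns, reindex the test frame so its columns exactly match the training frame's (minus the target), which guarantees the same columns in the same order at prediction time. A minimal sketch, assuming p_final_df is the fully merged test frame used later in this article:
# Hypothetical safeguard (not in the original notebook): align test columns with train, order included
feature_cols = [c for c in final_df.columns if c != 'virality']
p_final_df = p_final_df.reindex(columns=feature_cols, fill_value=0)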
Building the LightGBM model
Compared to XGBoost, LightGBM has a few advantages: faster training speed, lower memory usage, and often comparable or better accuracy.
We will be building a very simple base model without any parameter tuning. If you want to experiment more with tuning the model, I linked some resources below.
Train test split
X = final_df.drop(['virality', 'tweet_user_id', 'tweet_id', 'user_id'], axis=1)
y = final_df['virality']
# Train-Test Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 0)
print('Training set shape ', X_train.shape)
print('Test set shape ', X_test.shape)
Training set shape (20737, 2946)
Test set shape (8888, 2946)
The columns virality, tweet_user_id, tweet_id, and user_id aren't features, so we drop them, and we set virality as our target variable.
Then we split the data into 70% training and 30% test data.
Fit the data to the model
A baseline LGBMClassifier model is created with default parameters.
clf = lgb.LGBMClassifier()
clf.fit(X_train, y_train)
LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
importance_type='split', learning_rate=0.1, max_depth=-1,
min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0,
n_estimators=100, n_jobs=-1, num_leaves=31, objective=None,
random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
subsample=1.0, subsample_for_bin=200000, subsample_freq=0)
Fitting on X_train and y_train trains our basic LightGBM model.
We can then use the trained model to predict on the held-out test set, and print the accuracy of those predictions with the accuracy_score function from sklearn.metrics.
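That step isn't shown in the snippets above, so here is a minimal sketch of it:
# predict on the held-out 30% split and measure accuracy
from sklearn.metrics import accuracy_score
y_pred = clf.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred))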
Our base model achieves an accuracy of 66.45%!
We can also plot the features that were important in our model.
feature_imp = pd.DataFrame(
    sorted(zip(clf.feature_importances_, X.columns)),
    columns=['Value', 'Feature']
)
plt.figure(figsize=(20, 10))
sns.barplot(x='Value', y='Feature',
            data=feature_imp.sort_values(by='Value', ascending=False)[:10],
            color='blue');
Here we see that user_follower_count is the most important feature for predicting virality, which makes a lot of sense, since the more followers you have, the more reach your tweets get.
We also see quite a few user_vectorized_profile_images features in the top 10, along with user_has_location and user_like_count.
The accuracy can still be improved, but we'll stop here and apply our model to the public test data.
Apply the model to the test/public data
X = p_final_df.drop(['tweet_user_id', 'tweet_id', 'user_id'], axis=1)
solution = clf.predict(X)
solution_df = pd.concat([p_final_df[['tweet_id']], pd.DataFrame(solution, columns = ['virality'])], axis=1)
solution_df.head()
solution_df.to_csv('solution.csv', index=False)
Voila! We have a simple model that predicts the virality level of tweets.
And with this, you're ready to upload the submission file to the Bitgrit competition!
Tips to increase accuracy
Winning a data science competition isn't easy; a working model isn't enough. You need a model accurate enough to beat the other solutions, and one that also generalizes well to unseen data.
A few tips to increase accuracy:
- More data preprocessing — normalizing, scaling, etc.
- Feature Engineering
- Hyper-parameter tuning (see the sketch after this list)
- Stacking ensemble ML models
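To make the hyper-parameter tuning point concrete, here is a minimal sketch (not from the original notebook) of a randomized search over a few common LightGBM parameters; the value ranges are illustrative assumptions, not tuned settings.
# Hypothetical randomized search over a few common LightGBM parameters
from sklearn.model_selection import RandomizedSearchCV

param_dist = {
    'num_leaves': [31, 63, 127],
    'learning_rate': [0.01, 0.05, 0.1],
    'n_estimators': [100, 300, 500],
    'min_child_samples': [10, 20, 50],
    'colsample_bytree': [0.7, 0.8, 1.0],
}
search = RandomizedSearchCV(
    lgb.LGBMClassifier(random_state=0),
    param_distributions=param_dist,
    n_iter=10,
    cv=3,
    scoring='accuracy',
    random_state=0,
)
search.fit(X_train, y_train)
print(search.best_params_, search.best_score_)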
A great way to learn is through good notebooks on Kaggle. Here are a few that might be helpful.
- A Data Science Framework: To Achieve 99% Accuracy
- Full Preprocessing Tutorial
- Introduction to Ensembling/Stacking in Python
- Simple Bayesian Optimization for LightGBM
- LGB + Bayesian parameters finding + Rank average
That’s all for this article, thank you for reading and I hope you learned something new!
Follow bitgrit’s socials 📱 to stay updated on talks and upcoming competitions!