Using Data Science to Predict Viral Tweets

Can you use Machine Learning to predict which tweets will go viral?

Photo by Alexander Shatov on Unsplash

In a previous article, I wrote about building an XGBoost model to predict video popularity for a past data science competition. It covered everything from loading the data and libraries to building the model itself. Now, let’s pivot to viral tweets.

Bitgrit released a competition with $3000 💵 up for grabs:

Viral Tweets Prediction Challenge

If you’ve ever wondered why tweets go viral, this is the perfect opportunity for you to find the answer using Data Science!

The Goal

Develop a machine learning model to predict the virality level of each tweet based on attributes such as tweet content, media attached to the tweet, and date/time published.

What does the data look like?

📂 Tweets
├──test_tweets_vectorized_media.csv
├──test_tweets_vectorized_text.csv
├──test_tweets.csv
├──train_tweets_vectorized_media.csv
├──train_tweets_vectorized_text.csv
└──train_tweets.csv
📂 Users
├──user_vectorized_descriptions.csv
├──user_vectorized_profile_images.csv
└──users.csv

The tweets folder contains our train and test data. The main features are in the tweets.csv files, which hold information such as the date, the number of hashtags, whether the tweet has an attachment, and our target variable, virality. The vectorized CSV files are vector representations of each tweet's text and media (video or images).

The users folder contains information about each user, such as follower and following counts, tweet count, and whether they're verified. Their profile bios and images are also vectorized in separate CSV files.

Relationship between the data

The user tables are connected through user_id, and the tweet tables are connected through tweet_id (tweets link back to users via tweet_user_id).

Below is a visualization of the relationship.

Source (Created with drawsql)

More info about the data can be found in the guidelines section of the competition.

Now that you know the goal and have some information about the data, it's time to get to work on predicting tweet virality.

All the code can be found on Google Colab or Jovian.

Note: to keep this article light, not all of the code is included here, so please refer to the notebook for the complete code.

Load Libraries

As with any data science task, you start by equipping yourself with the libraries you need.

# essentials
import pandas as pd
import numpy as np

# misc libraries
import random
import timeit
import math 
import collections 

# suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Data Visualization
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(style='darkgrid', color_codes=True)
plt.style.use('fivethirtyeight')
%matplotlib inline

# model building
import lightgbm as lgb
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

Load Data

You can get the data by registering for the competition here by July 6, 2021.

# Note: this path will be different depending on where you store the dataset
tweets_path = '/content/drive/MyDrive/Colab Notebooks/viral tweets/Dataset/Tweets/'
users_path = '/content/drive/MyDrive/Colab Notebooks/viral tweets/Dataset/Users/'

# Load training datasets
train_tweets = pd.read_csv(tweets_path + 'train_tweets.csv')
train_tweets_vectorized_media = pd.read_csv(tweets_path + 'train_tweets_vectorized_media.csv')
train_tweets_vectorized_text = pd.read_csv(tweets_path + 'train_tweets_vectorized_text.csv')

# Load test dataset
test_tweets = pd.read_csv(tweets_path + 'test_tweets.csv')
test_tweets_vectorized_media = pd.read_csv(tweets_path + 'test_tweets_vectorized_media.csv')
test_tweets_vectorized_text = pd.read_csv(tweets_path + 'test_tweets_vectorized_text.csv')

# Load user dataset
users = pd.read_csv(users_path + 'users.csv')
user_vectorized_descriptions = pd.read_csv(users_path + 'user_vectorized_descriptions.csv')
user_vectorized_profile_images = pd.read_csv(users_path + 'user_vectorized_profile_images.csv')

Printing the shape of each of these datasets:
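
(The exact printing code isn't shown in the article; a minimal loop like the sketch below would produce this output.)

datasets = {
    'train tweets': train_tweets,
    'train tweets vectorized media': train_tweets_vectorized_media,
    'train tweets vectorized text': train_tweets_vectorized_text,
    'test tweets': test_tweets,
    'test tweets vectorized media': test_tweets_vectorized_media,
    'test tweets vectorized text': test_tweets_vectorized_text,
    'users': users,
    'user vectorized descriptions': user_vectorized_descriptions,
    'user vectorized profile images': user_vectorized_profile_images,
}

# print the (rows, columns) shape of every data frame we just loaded
for name, df in datasets.items():
    print(f'Dimension of {name} is {df.shape}')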

Dimension of train tweets is (29625, 14)
Dimension of train tweets vectorized media is (21010, 2050)
Dimension of train tweets vectorized text is (29625, 769)

Dimension of test tweets is (12697, 13)
Dimension of test tweets vectorized media is (8946, 2050)
Dimension of test tweets vectorized text is (12697, 769)

Dimension of users is (52, 11)
Dimension of user vectorized descriptions is (52, 769)
Dimension of user vectorized profile images is (52, 2049)

We can tell that our dataset has a total of 52 users, each with multiple tweets. Notice that the vectorized media has fewer rows than the tweets; this means not all tweets have media, and keep in mind that some tweets have multiple media. This will cause some issues when merging the data, but we'll worry about that later.

Notice also that the vectorized media/image files all have the same number of image features (2,048), and the vectorized text files all have 768 text features; the remaining columns are ids.

Now let’s move on to some exploratory data analysis to understand our data better.

Exploratory Data Analysis

EDA is important to discover trends in your data and determine what transformations and preprocessing are needed to prepare the data for modeling.

What the data looks like

train_tweets.head()
[Output: first five rows of train_tweets, with columns tweet_id, tweet_user_id, tweet_created_at_year, tweet_created_at_month, tweet_created_at_day, tweet_created_at_hour, tweet_hashtag_count, tweet_url_count, tweet_mention_count, tweet_has_attachment, tweet_attachment_class, tweet_language_id, tweet_topic_ids, and virality]

Looking at the data, it looks pretty standard. We have the two primary keys, the features, and the target variable — virality.

But notice how tweet_topic_ids contains arrays? We'll have to do some preprocessing later on to deal with that.

train_tweets_vectorized_media.head()
[Output: first five rows of train_tweets_vectorized_media, showing media_id, tweet_id, and img_feature_0 through img_feature_2047]

5 rows × 2050 columns

train_tweets_vectorized_text.head()
[Output: first five rows of train_tweets_vectorized_text, showing tweet_id and feature_0 through feature_767]

5 rows × 769 columns

Our vectorized data looks standard, where each column represents one coordinate in the numeric feature space.

Now let’s analyze the features in our tweets data.

Analysis of Tweet data features

train_tweets.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 29625 entries, 0 to 29624
Data columns (total 14 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   tweet_id                29625 non-null  int64  
 1   tweet_user_id           29625 non-null  int64  
 2   tweet_created_at_year   29625 non-null  int64  
 3   tweet_created_at_month  29625 non-null  int64  
 4   tweet_created_at_day    29625 non-null  int64  
 5   tweet_created_at_hour   29625 non-null  int64  
 6   tweet_hashtag_count     29625 non-null  float64
 7   tweet_url_count         29625 non-null  float64
 8   tweet_mention_count     29625 non-null  float64
 9   tweet_has_attachment    29625 non-null  bool   
 10  tweet_attachment_class  29625 non-null  object 
 11  tweet_language_id       29625 non-null  int64  
 12  tweet_topic_ids         25340 non-null  object 
 13  virality                29625 non-null  int64  
dtypes: bool(1), float64(3), int64(8), object(2)
memory usage: 3.0+ MB

There are a total of 11 features in our tweets data (excluding the two id columns and the target).

First, we have a look at virality.

train_tweets['virality'].describe()

count    29625.000000
mean         1.907274
std          1.078700
min          1.000000
25%          1.000000
50%          2.000000
75%          2.000000
max          5.000000
Name: virality, dtype: float64

sns.countplot(x = 'virality', data = train_tweets, palette="Set1");

We see that tweets have an average virality level of about 1.9, with level 1 being by far the most common.

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

sns.histplot(train_tweets, x = 'tweet_created_at_year', discrete = True, ax = axs[0,0]);
sns.histplot(train_tweets, x = 'tweet_created_at_day', discrete = True, ax = axs[0,1]);
sns.histplot(train_tweets, x = 'tweet_created_at_month', discrete = True, ax = axs[1,0]);
sns.histplot(train_tweets, x = 'tweet_created_at_hour', discrete = True, ax = axs[1,1]);

Looking at the time features, we see that the largest share of tweets were created in 2020, most commonly in December and on the 27th day of the month. They were also most often posted in the evening (the 17th hour).

fig, axs = plt.subplots(3, 1, figsize=(12, 10))

sns.histplot(x = 'tweet_hashtag_count', data = train_tweets, discrete = True, ax = axs[0]);
sns.histplot(x = 'tweet_url_count', data = train_tweets, discrete = True, ax = axs[1]);
sns.histplot(x = 'tweet_mention_count', data = train_tweets, discrete = True, ax = axs[2]);

Moving on to the hashtag, URL, and mention counts: most of the tweets have zero hashtags, one URL, and zero mentions.

fig, axs = plt.subplots(2, 1, figsize=(10, 7))

sns.countplot(x = 'tweet_attachment_class', data = train_tweets, palette="Set1", ax = axs[0]);
sns.countplot(x = 'tweet_language_id', data = train_tweets, ax = axs[1]);

Plotting the attachment class and language id, we see that class A is the most prevalent, while class B is rare. As for language id, the vast majority is id 0, which we can assume is English.

sns.countplot(x = 'tweet_has_attachment', data = train_tweets, palette="Set1");

Most of the tweets also have media attached.

Correlation of features

Now let’s look at how our features are correlated, or in other words, how strong the linear relationship between our features and virality is.

A good way to visualize it is with a heatmap.

corrmat = train_tweets.corr()[2:] 
sns.heatmap(corrmat, vmax=.8, square=True);

From our heatmap, we can tell that our features don't correlate strongly with virality, though some of the features do correlate with each other.

df_corr = train_tweets.corr()['virality'][2:-1]
top_features = df_corr.sort_values(ascending=False)
top_features

tweet_language_id         0.030416
tweet_created_at_day      0.017518
tweet_has_attachment      0.005401
tweet_created_at_hour    -0.028583
tweet_url_count          -0.047833
tweet_created_at_month   -0.063757
tweet_mention_count      -0.081958
tweet_hashtag_count      -0.083262
tweet_created_at_year    -0.096487
Name: virality, dtype: float64

Looking at the numerical correlations, we see that they are all fairly weak, and some are negative.

Nonetheless, this doesn't mean our features are useless; they can still have predictive power. It just means they don't relate to the virality level linearly.

Now let’s have a look at our users data.

Analysis of Users data

users.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 52 entries, 0 to 51
Data columns (total 11 columns):
 #   Column                 Non-Null Count  Dtype
---  ------                 --------------  -----
 0   user_id                52 non-null     int64
 1   user_like_count        52 non-null     int64
 2   user_followers_count   52 non-null     int64
 3   user_following_count   52 non-null     int64
 4   user_listed_on_count   52 non-null     int64
 5   user_has_location      52 non-null     bool 
 6   user_tweet_count       52 non-null     int64
 7   user_has_url           52 non-null     bool 
 8   user_verified          52 non-null     int64
 9   user_created_at_year   52 non-null     int64
 10  user_created_at_month  52 non-null     int64
dtypes: bool(2), int64(9)
memory usage: 3.9 KB

We have a total of 10 features in users.

fig, axs = plt.subplots(2, 2, figsize=(12, 8))

sns.histplot(users, x = 'user_like_count', ax = axs[0,0]);
sns.histplot(users, x = 'user_followers_count', ax = axs[0,1]);
sns.histplot(users, x = 'user_following_count', ax = axs[1,0]);
sns.histplot(users, x = 'user_listed_on_count', ax = axs[1,1]);

Visualizing the count data, we can observe that the like count is mostly in the range of 0 to 5,000, with outliers around 300k.

For user_followers_count, a large portion of the users are in the 0 to 100k range, and some users have up to 1.6m followers.

For user_following_count, most are in the 0 to 100k range as well, with a tiny portion following more than 1m accounts.

For user_listed_on_count, most are between 0 and 5,000, with some users being listed on as many as 400k lists.

fig, axs = plt.subplots(3, 1, figsize=(12, 10))

sns.histplot(users, x = 'user_tweet_count', ax = axs[0]);
sns.histplot(users, x = 'user_created_at_year', discrete = True, ax = axs[1]);
sns.histplot(users, x = 'user_created_at_month', discrete = True, ax = axs[2]);

Our users mostly have tweet counts of around 500k, and most of the accounts were created in 2011, most commonly in August.

fig, axs = plt.subplots(1, 3, figsize=(12, 8))

sns.countplot(x = 'user_has_location', data = users, ax = axs[0], palette="Set1");
sns.countplot(x = 'user_has_url', data = users, ax = axs[1], palette="Set1");
sns.countplot(x = 'user_verified', data = users, ax = axs[2], palette="Set1");

Moving on to the binary data: most of our users have a location listed on their account and a URL in their bio, and most of them aren't verified.

Now that we’ve done a bit of EDA on our data, let’s start prepping it for modeling.

Data Preprocessing

These are the preprocessing tasks we'll be doing in this particular example. Note that there are various ways to approach this, not to mention more complex preprocessing and feature engineering you could do, so treat the steps below as just one example and NOT the definitive approach.

  • dealing with missing values
  • one hot encoding
  • feature selection using lasso regression
  • match number of columns between train and test data (so that the trained model works on test data)
  • join all the features into one final data frame for both the train and test/public data

Dealing with missing data

missing_cols(users)

no missing values

missing_cols(train_tweets)

tweet_topic_ids => 4285 [14.46%]

plt.figure(figsize=(10, 6))
sns.heatmap(train_tweets.isnull(), yticklabels=False, cmap='viridis', cbar=False);

Using a helper function I've coded up (code in the notebook), we check the number and percentage of missing values in the users and tweets data.

There is some missing data in the topic ids column, so we'll be dealing with that. We also visualize the missing data with a heatmap, which is useful for spotting patterns; our missing data seems to be pretty spread out.
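
The helper itself lives in the notebook; a minimal sketch of what missing_cols might look like (an assumption on my part, matching the output format above) is:

def missing_cols(df):
    # print the count and percentage of missing values for each column
    total = 0
    for col in df.columns:
        null_count = df[col].isnull().sum()
        if null_count > 0:
            total += null_count
            percent = null_count / df.shape[0] * 100
            print(f'{col} => {null_count} [{percent:.2f}%]')
    if total == 0:
        print('no missing values')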

train_tweets.fillna({'tweet_topic_ids':"['0']"}, inplace=True)
missing_cols(train_tweets)

no missing values

To deal with the NA values, we fill them with an array containing zero, ['0'], to signify that the tweet has no topic ids. The reason I place it in an array is so that, later on when we do one-hot encoding, it can be represented as a topic_id of zero. (Again, this is only one approach; there are other ways you can handle this.)

Data cleaning is an essential skill. If you want to brush up on your data cleaning skills, check out our Data Cleaning using Python article.

I also noticed that the hashtag, URL, and mention counts were stored as floats. However, they have no decimal component and take only a handful of unique values, so they can be treated as integers (or even categories); I converted them as below.

# train_tweets.tweet_hashtag_count.value_counts()
# train_tweets.tweet_url_count.value_counts()
# train_tweets.tweet_mention_count.value_counts()

# convert floats to ints
cols = ['tweet_hashtag_count', 'tweet_url_count', 'tweet_mention_count']
train_tweets[cols] = train_tweets[cols].applymap(np.int64)
train_tweets[cols].head()

One hot encoding on tweets features

topic_ids = (
    train_tweets['tweet_topic_ids'].str.strip('[]').str.split('\s*,\s*').explode()
    .str.get_dummies().sum(level=0).add_prefix('topic_id_')
) 
topic_ids.rename(columns = lambda x: x.replace("'", ""), inplace=True)

year = pd.get_dummies(train_tweets.tweet_created_at_year, prefix='year')
month = pd.get_dummies(train_tweets.tweet_created_at_month , prefix='month')
day = pd.get_dummies(train_tweets.tweet_created_at_day, prefix='day')
attachment = pd.get_dummies(train_tweets.tweet_attachment_class, prefix='attatchment')
language = pd.get_dummies(train_tweets.tweet_language_id, prefix='language')

## Cyclical encoding
sin_hour = np.sin(2*np.pi*train_tweets['tweet_created_at_hour']/24.0)
sin_hour.name = 'sin_hour'
cos_hour = np.cos(2*np.pi*train_tweets['tweet_created_at_hour']/24.0)
cos_hour.name = 'cos_hour'

For tweet_topic_ids, each value is an array, so it's a bit trickier. What you can do is explode the arrays into individual values, turn them into dummies, and sum them back up per tweet.

For the other columns like year, month, attachment class, etc., encoding is easy using get_dummies() from the pandas library.

As for the hour data, since it's cyclical (hour 23 sits right next to hour 0), we encode it with sine and cosine transforms instead of dummies.

After we’ve encoded our features, we can drop the old ones, then join the new encoded features.

columns_drop = [
                "tweet_topic_ids",
                "tweet_created_at_year",
                "tweet_created_at_month",
                "tweet_created_at_day",
                "tweet_attachment_class",
                "tweet_language_id",
                "tweet_created_at_hour",
               ]

dfs = [topic_ids, year, month, day, attachment, language, 
       sin_hour, cos_hour]

train_tweets_final = train_tweets.drop(columns_drop, 1).join(dfs)

train_tweets_final.head()
[Output: first five rows of train_tweets_final, showing the id, count, and attachment columns plus the new topic_id_*, year_*, month_*, day_*, attatchment_*, language_*, sin_hour, and cos_hour columns]

Now our train tweets data has a whopping 151 columns!

One hot encoding on users features

year = pd.get_dummies(users.user_created_at_year, prefix='year')
month = pd.get_dummies(users.user_created_at_month , prefix='month')
user_verified = pd.get_dummies(users.user_verified, prefix='verified')

columns_drop = [
                "user_created_at_year",
                "user_created_at_month",
                "user_verified"
              ]

dfs = [
        year,
        month,
        user_verified
      ]

users_final = users.drop(columns_drop, 1).join(dfs)

users_final.head()
[Output: first five rows of users_final, showing the user count and boolean columns plus the new year_*, month_*, and verified_* dummy columns]

We also do one-hot encoding on our users features, namely the year, month, and user verified columns, and join them back just like we did with the train tweets data.

Feature selection on tweets data

The vectorized data has tons of features, and it’s best if we can “separate the wheat from the chaff”, or choose the best features that can help us predict virality. This can help us save size and speed up our model training time.

We'll start with our train_tweets_vectorized_media data. Since not all tweets have media, we have to append each tweet's virality onto the rows that do have media. This can be done with a right join on tweet_id, so that the virality of each tweet is matched to its image features.

Since we don't want any of the train_tweets columns besides virality in the merged data, we can use the .difference method to drop everything else.

# create new data frame that matches row number between train tweets and vectorized media
vectorized_media_df = pd.merge(train_tweets,train_tweets_vectorized_media, on ='tweet_id', how = 'right')
vectorized_media_df.drop(train_tweets.columns.difference(['virality']), axis=1, inplace=True)
vectorized_media_df.head()
[Output: first five rows of vectorized_media_df, showing virality, media_id, and img_feature_0 through img_feature_2047]

5 rows × 2050 columns

We can then take this dataset, and start doing feature selection.

# Set the target as well as dependent variables from image data.
y = vectorized_media_df['virality']
x = vectorized_media_df.loc[:, vectorized_media_df.columns.str.contains("img_")] 

# Run Lasso regression for feature selection.
sel_model = SelectFromModel(LogisticRegression(C=1, penalty='l1', solver='liblinear'))

# time the model fitting
start = timeit.default_timer()

# Fit the trained model on our data
sel_model.fit(x, y)

stop = timeit.default_timer()
print('Time: ', stop - start) 

# get index of good features
sel_index = sel_model.get_support()

# count the no of columns selected
counter = collections.Counter(sel_model.get_support())
counter

Time:  113.86132761099998

Counter({False: 2, True: 2046})

Our target is virality, and the features are the columns whose names contain "img_".

Using the collections library, we can count the number of features chosen by the model. In this case, only two features were deemed not "worthy", so the selection wasn't all that useful here.

Nonetheless, with the indexes from our selector, we can concatenate the selected features back with the media and tweet ids to form the final train_tweets_media data, with two fewer columns than before.

media_ind_df = pd.DataFrame(x[x.columns[(sel_index)]])
train_tweets_media_final = pd.concat([train_tweets_vectorized_media[['media_id', 'tweet_id']], media_ind_df], axis=1)
train_tweets_media_final.head()
[Output: first five rows of train_tweets_media_final, showing media_id, tweet_id, and the 2,046 selected img_feature_* columns]

For train_tweets_vectorized_text, it's the same process as above: do a right join to match each tweet's virality, then perform feature selection.
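
That code isn't reproduced here; a sketch of what it might look like, assuming it mirrors the media pipeline above, is:

# match virality to each vectorized text row (right join on tweet_id)
vectorized_text_df = pd.merge(train_tweets, train_tweets_vectorized_text, on='tweet_id', how='right')
vectorized_text_df.drop(train_tweets.columns.difference(['virality']), axis=1, inplace=True)

# run the same L1-penalised logistic regression to select text features
y = vectorized_text_df['virality']
x = vectorized_text_df.loc[:, vectorized_text_df.columns.str.contains('feature_')]

sel_model = SelectFromModel(LogisticRegression(C=1, penalty='l1', solver='liblinear'))
sel_model.fit(x, y)
sel_index = sel_model.get_support()

# keep the selected columns and re-attach tweet_id
text_ind_df = pd.DataFrame(x[x.columns[sel_index]])
train_tweets_text_final = pd.concat([train_tweets_vectorized_text[['tweet_id']], text_ind_df], axis=1)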

Feature selection on users data

To perform feature selection on the users data, we need to append a virality level to each user. But since each user has multiple tweets, one approach is to take the median virality of their tweets and then perform a right join to match that median virality level to each user. With this new data frame, we can then perform feature selection.

Using groupby and agg from pandas, we can find the median and then merge them.

# Find the median of virality for each user to reduce features for user vectorized
# description and profile
average_virality_df =train_tweets.groupby('tweet_user_id').agg(pd.Series.median)['virality']

descriptions_df = pd.merge(average_virality_df, user_vectorized_descriptions, left_on ='tweet_user_id', right_on = 'user_id', how = 'right')
profile_images_df = pd.merge(average_virality_df, user_vectorized_profile_images, left_on ='tweet_user_id', right_on = 'user_id', how = 'right')
descriptions_df.head()
[Output: first five rows of descriptions_df, showing the median virality, user_id, and feature_0 through feature_767]
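
The selection step for these two frames isn't shown in the article either; a sketch of the idea, assuming the notebook reuses the same L1-based selector to produce the user-level tables used later (such as user_profile_images_final), might be:

def select_user_features(df, source_df):
    # fit an L1-penalised logistic regression against the median virality and
    # keep only the selected columns, then re-attach user_id
    y = df['virality']
    x = df.drop(columns=['virality', 'user_id'])
    sel_model = SelectFromModel(LogisticRegression(C=1, penalty='l1', solver='liblinear'))
    sel_model.fit(x, y)
    selected = x[x.columns[sel_model.get_support()]]
    return pd.concat([source_df[['user_id']], selected], axis=1)

user_descriptions_final = select_user_features(descriptions_df, user_vectorized_descriptions)
user_profile_images_final = select_user_features(profile_images_df, user_vectorized_profile_images)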

Merging everything together

Now it’s time to merge everything into the final train data frame.

There are two main problems we face:

  1. Not all tweets have media, and some tweets have multiple media. How do we combine this with our train tweets data frame?
  2. User vectorized data and our tweets vectorized text data have similar column names, which will cause issues when merging (pandas will append _x and _y to the overlapping columns)

To handle the first issue, one approach is to average the media features per tweet id. We can do that using the groupby function.

# train_tweets_media_final doesn't have the same number of rows as train_tweets,
# so group by tweet_id (a single tweet can have multiple media ids) and average
# the features (a naive but simple approach) before joining with train_tweets
media_df = train_tweets_media_final.groupby('tweet_id').mean()

Then to solve the feature column names overlapping, we can use the rename function.

# tweets_vectorized_text and user_vectorized_profile_images has same column 
# names which will cause problems when merging

# rename columns in tweets_vectorized_text

cols = train_tweets_text_final.columns[train_tweets_text_final.columns.str.contains('feature_')]
train_tweets_text_final.rename(columns = dict(zip(cols, 'text_' + cols)), inplace=True)
train_tweets_text_final.head()
[Output: first five rows of train_tweets_text_final, showing tweet_id and the selected, renamed text_feature_* columns]

5 rows × 670 columns

Now we can merge all the data frames.

# Merge all tables based on the column 'user_id' for user data, and tweet_id
# for tweet data

# join tweets data
tweet_df = pd.merge(media_df, train_tweets_text_final, on = 'tweet_id', how = 'right')
tweet_df.fillna(0, inplace=True)

# join users data
user_df = pd.merge(users_final, user_profile_images_final, on='user_id')

# join tweets data on train_tweets
tweet_df_final = pd.merge(train_tweets_final, tweet_df, on = 'tweet_id')

# join that with the users data
final_df = pd.merge(tweet_df_final, user_df, left_on = 'tweet_user_id', right_on='user_id')

final_df.shape

(29625, 2946)

There are different ways you can approach this, but my way was basically to first merge the tweets data. There’ll be missing data since not all tweets have media, so fill those NA values with zero.

Next up is merging the users data. After that, I join train_tweets_final with the tweets data, and then join that with the users data.

The final data frame has 29,625 rows and 2,946 columns.

The preprocessing we do on the train data has to be replicated on the test data; otherwise, the model we train won't be usable on the test data. (The full test preprocessing code isn't shown here.)
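
One way to keep train and test consistent (a sketch, not the notebook's exact code) is to wrap the tweet-level encoding in a function and apply it to both frames; the vectorized features and user tables need the same treatment:

def encode_tweets(df):
    # same encoding steps as for the training tweets: fill missing topic ids,
    # explode and one-hot encode the topic ids, dummy-encode the categorical
    # columns, and cyclically encode the hour
    df = df.copy()
    df.fillna({'tweet_topic_ids': "['0']"}, inplace=True)

    topic_ids = (
        df['tweet_topic_ids'].str.strip('[]').str.split('\s*,\s*').explode()
        .str.get_dummies().sum(level=0).add_prefix('topic_id_')
    )
    topic_ids.rename(columns=lambda x: x.replace("'", ""), inplace=True)

    dummies = [
        pd.get_dummies(df.tweet_created_at_year, prefix='year'),
        pd.get_dummies(df.tweet_created_at_month, prefix='month'),
        pd.get_dummies(df.tweet_created_at_day, prefix='day'),
        pd.get_dummies(df.tweet_attachment_class, prefix='attatchment'),
        pd.get_dummies(df.tweet_language_id, prefix='language'),
    ]

    sin_hour = np.sin(2*np.pi*df['tweet_created_at_hour']/24.0).rename('sin_hour')
    cos_hour = np.cos(2*np.pi*df['tweet_created_at_hour']/24.0).rename('cos_hour')

    columns_drop = [
        'tweet_topic_ids', 'tweet_created_at_year', 'tweet_created_at_month',
        'tweet_created_at_day', 'tweet_attachment_class', 'tweet_language_id',
        'tweet_created_at_hour',
    ]
    return df.drop(columns_drop, 1).join([topic_ids] + dummies + [sin_hour, cos_hour])

test_tweets_final = encode_tweets(test_tweets)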

Match Number of Columns for Train and Test

After applying the same preprocessing to our test data, we face another problem: the number of columns in train and test doesn't match, because some features are present in one but missing from the other.

To solve this, we can turn the column names into sets and subtract them from each other to find the difference. Then we can add the missing columns to the data and set them to zero.

cols_test = set(test_tweets_final.columns) - set(train_tweets_final.columns)
cols_test # train is missing these 4 columns from test

{'language_26', 'topic_id_117', 'topic_id_123', 'topic_id_38'}

for col in cols_test:
  final_df[col] = 0

We also check for columns present in train but missing from test; after removing virality (which the test data doesn't have), there are none.

# columns missing in test from train
cols_train = set(train_tweets_final.columns) - set(test_tweets_final.columns)
cols_train.remove('virality') # remove virality from columns to add to test
len(cols_train)

0

Building the Light GBM model

Compared to XGBoost, LightGBM has a few benefits, mainly faster training speed and higher efficiency, along with lower memory usage and often comparable or better accuracy.

We will be building a very simple base model without any parameter tuning. If you want to experiment more with tuning the model, I linked some resources below.
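
For a sense of what tuning looks like in practice, it mostly means passing hyperparameters to the classifier; the values below are purely illustrative and not tuned for this dataset:

# illustrative hyperparameters only; use cross-validation or a tuning library to pick real values
tuned_clf = lgb.LGBMClassifier(
    n_estimators=500,
    learning_rate=0.05,
    num_leaves=63,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    random_state=0,
)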

train test split

X = final_df.drop(['virality', 'tweet_user_id', 'tweet_id', 'user_id'], axis=1)
y = final_df['virality']

# Train-Test Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 0)
print('Training set shape ', X_train.shape)
print('Test set shape ', X_test.shape)

Training set shape  (20737, 2946)
Test set shape  (8888, 2946)

The columns virality, tweet_user_id, tweet_id, and user_id aren't features, so we can drop them. Then, we set virality as our target variable.

Then we split our data into 70% training and 30% test data.

Fit the data to the model

A base LGBM Classifier model is created.

clf = lgb.LGBMClassifier()
clf.fit(X_train, y_train)

LGBMClassifier(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
               importance_type='split', learning_rate=0.1, max_depth=-1,
               min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0,
               n_estimators=100, n_jobs=-1, num_leaves=31, objective=None,
               random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
               subsample=1.0, subsample_for_bin=200000, subsample_freq=0)

Fitting it on X_train and y_train, we train our basic LightGBM model.

Then we can use it to make predictions on the test set.

We can easily print the accuracy score of our model using the accuracy_score function from sklearn.metrics.
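
Something along these lines (a sketch; the exact cell isn't shown here):

from sklearn.metrics import accuracy_score

# predict on the held-out 30% split and print the accuracy
y_pred = clf.predict(X_test)
print('Accuracy:', accuracy_score(y_test, y_pred))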

Our base model achieves an accuracy of 66.45%!

We can also plot the features that were important in our model.

# sorted(zip(clf.feature_importances_, X.columns), reverse=True)
feature_imp = pd.DataFrame(sorted(zip(clf.feature_importances_,X.columns)), columns=['Value','Feature'])

plt.figure(figsize=(20, 10))
sns.barplot(x="Value", y="Feature", data=feature_imp.sort_values(by="Value", ascending=False)[:10], color='blue');

Here we see user_followers_count is the most important feature in predicting virality, which makes a lot of sense, since the more followers you have, the more reach your tweets have.

We also see quite a few user profile image features in the top 10, along with user_has_location and user_like_count.

This accuracy can still be improved, but we'll stop here and use our model to predict on the public data.

Predict on the test/public data

X = p_final_df.drop(['tweet_user_id', 'tweet_id', 'user_id'], axis=1)

solution = clf.predict(X)
solution_df = pd.concat([p_final_df[['tweet_id']], pd.DataFrame(solution, columns = ['virality'])], axis=1)
solution_df.head()
solution_df.to_csv('solution.csv', index=False)

Voila! We have a simple model that predicts the virality level of tweets.

And with this, you're ready to generate the submission file and upload it to the Bitgrit competition!

Tips to increase accuracy

Winning a data science competition isn't easy; a working model isn't enough. You need a model with high accuracy to beat the other solutions, and it has to generalize well to unseen data too.

A few tips to increase accuracy:

  1. More data preprocessing — normalizing, scaling, etc.
  2. Feature Engineering
  3. Hyper-parameter tuning
  4. Stacking ensemble ML models

A great way to learn is through good notebooks on Kaggle. Here are a few that might be helpful.

That’s all for this article, thank you for reading and I hope you learned something new!


Follow bitgrit’s socials 📱 to stay updated on talks and upcoming competitions!