{ "cells": [ { "cell_type": "markdown", "id": "faab08b8", "metadata": {}, "source": [ "# NSFW text" ] }, { "cell_type": "markdown", "id": "b919917f", "metadata": {}, "source": [ "Builds a TensorFlow model that predicts the probability of a text to be NSFW." ] }, { "cell_type": "code", "execution_count": null, "id": "7ee52623", "metadata": {}, "outputs": [], "source": [ "from datetime import datetime\n", "from functools import reduce\n", "import os\n", "import pandas as pd\n", "import re\n", "from sklearn.metrics import (\n", " average_precision_score,\n", " classification_report, \n", " precision_recall_curve,\n", " PrecisionRecallDisplay,\n", ")\n", "from sklearn.model_selection import train_test_split\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt\n", "import re\n", "\n", "from twitter.cuad.representation.models.optimization import create_optimizer #TODO: provide instruction for install\n", "from twitter.cuad.representation.models.text_encoder import TextEncoder #TODO: provide instruction for install\n", "\n", "pd.set_option('display.max_colwidth', None)\n", "pd.set_option('display.expand_frame_repr', False)\n", "\n", "print(tf.__version__)\n", "print(tf.config.list_physical_devices())" ] }, { "cell_type": "markdown", "id": "64a45fd4", "metadata": {}, "source": [ "# 1. 
Parameters and helper function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce628e8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each run logs into its own timestamped directory.\n",
    "log_path = os.path.join('pnsfwtweettext_model_runs', datetime.now().strftime('%Y-%m-%d_%H.%M.%S'))\n",
    "\n",
    "# Name of the model's string input; also the key of the feature dict fed to it.\n",
    "tweet_text_feature = 'text'\n",
    "\n",
    "# Hyperparameters read by the dataset, encoder and training cells.\n",
    "params = {\n",
    "    'batch_size': 32,\n",
    "    'max_seq_lengths': 256,\n",
    "    'model_type': 'twitter_bert_base_en_uncased_augmented_mlm',\n",
    "    'trainable_text_encoder': True,\n",
    "    'lr': 5e-5,\n",
    "    'epochs': 10,\n",
    "}\n",
    "\n",
    "# Substrings removed from every tweet, applied in this order.\n",
    "REGEX_PATTERNS = [\n",
    "    r'^RT @[A-Za-z0-9_]+: ',  # leading retweet prefix\n",
    "    r\"@[A-Za-z0-9_]+\",  # user mentions\n",
    "    r'https:\\/\\/t\\.co\\/[A-Za-z0-9]{10}',  # t.co links with a 10-character id\n",
    "    r'@\\?\\?\\?\\?\\?',  # literal '@?????' token\n",
    "]\n",
    "\n",
    "# A single capture group around a character class of the code point ranges\n",
    "# this notebook treats as emoji; the group lets clean_tweet re-insert the match.\n",
    "EMOJI_PATTERN = re.compile(\n",
    "    \"([\"\n",
    "    \"\\U0001F1E0-\\U0001F1FF\"\n",
    "    \"\\U0001F300-\\U0001F5FF\"\n",
    "    \"\\U0001F600-\\U0001F64F\"\n",
    "    \"\\U0001F680-\\U0001F6FF\"\n",
    "    \"\\U0001F700-\\U0001F77F\"\n",
    "    \"\\U0001F780-\\U0001F7FF\"\n",
    "    \"\\U0001F800-\\U0001F8FF\"\n",
    "    \"\\U0001F900-\\U0001F9FF\"\n",
    "    \"\\U0001FA00-\\U0001FA6F\"\n",
    "    \"\\U0001FA70-\\U0001FAFF\"\n",
    "    \"\\U00002702-\\U000027B0\"\n",
    "    \"])\"\n",
    ")\n",
    "\n",
    "\n",
    "def clean_tweet(text):\n",
    "    \"\"\"Normalize a raw tweet string before it is fed to the model.\n",
    "\n",
    "    Removes every match of REGEX_PATTERNS, pads each character matched by\n",
    "    EMOJI_PATTERN with one space on either side, turns newlines into spaces,\n",
    "    then trims both ends and lowercases the result.\n",
    "    \"\"\"\n",
    "    without_noise = reduce(lambda acc, pattern: re.sub(pattern, '', acc), REGEX_PATTERNS, text)\n",
    "    emoji_spaced = EMOJI_PATTERN.sub(r' \\1 ', without_noise)\n",
    "    single_line = emoji_spaced.replace('\\n', ' ')\n",
    "    return single_line.strip().lower()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b55ad56e",
   "metadata": {},
   "source": [
    "# 2. 
Create data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c491ecf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample data just for demonstration\n",
    "\n",
    "df = pd.DataFrame(\n",
    "    {\n",
    "        \"text\": [\n",
    "            \"Vanilla Tweet\",\n",
    "            \"Love is Blind season 4 is 🔥\",\n",
    "            \"I like Apples\",\n",
    "            \"Big dick for you\",\n",
    "            \"Wanna see my puss*?\",\n",
    "            \"Elon Musk is the best\",\n",
    "        ],\n",
    "        \"is_nsfw\": [\n",
    "            0, 0, 0, 1, 1, 1,\n",
    "        ]\n",
    "    }\n",
    ")\n",
    "\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2369e2f",
   "metadata": {},
   "source": [
    "# 3. Train model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83a2d177",
   "metadata": {},
   "source": [
    "## 3.1. Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7843012c",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['processed_text'] = df['text'].astype(str).map(clean_tweet)\n",
    "# DataFrame.sample raises ValueError when asked for more rows than the frame holds\n",
    "# (sampling is without replacement), and the demo frame has fewer than 10 rows,\n",
    "# so cap the preview size at the frame's length.\n",
    "df.sample(min(10, len(df)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e5ceb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 10% of the rows for validation; random_state fixes the split.\n",
    "X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1)\n",
    "\n",
    "\n",
    "def df_to_ds(X, y, shuffle=False):\n",
    "    \"\"\"Build a batched tf.data.Dataset from text features and integer labels.\n",
    "\n",
    "    Args:\n",
    "        X: object exposing `.values` with the texts (here a one-column DataFrame).\n",
    "        y: object exposing `.values` with integer class ids; they are one-hot\n",
    "            encoded with depth 2.\n",
    "        shuffle: when True, shuffle with a buffer of 1000 and seed 1,\n",
    "            reshuffling on every iteration.\n",
    "\n",
    "    Returns:\n",
    "        A dataset of `({tweet_text_feature: text}, one_hot_label)` pairs,\n",
    "        batched by `params['batch_size']`.\n",
    "    \"\"\"\n",
    "    ds = tf.data.Dataset.from_tensor_slices((\n",
    "        X.values,\n",
    "        tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1)\n",
    "    ))\n",
    "\n",
    "    if shuffle:\n",
    "        ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True)\n",
    "\n",
    "    # Wrap the text in a dict keyed by the model's input name before batching.\n",
    "    return ds.map(lambda text, label: ({ tweet_text_feature: text }, label)).batch(params['batch_size'])\n",
    "\n",
    "\n",
    "ds_train = df_to_ds(X_train, y_train, shuffle=True)\n",
    "ds_val = df_to_ds(X_val, y_val)\n",
    "X_train.values\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "974c1248",
   "metadata": {},
   "source": [
    "## 3.2. 
Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd6fc8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scalar string input named after tweet_text_feature, the key used in the dataset's feature dict.\n",
    "inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature)\n",
    "# TextEncoder is a project-internal layer; presumably it tokenizes and encodes the raw\n",
    "# strings itself - TODO confirm against twitter.cuad.representation.models.text_encoder.\n",
    "encoder = TextEncoder(\n",
    "    max_seq_lengths=params['max_seq_lengths'],\n",
    "    model_type=params['model_type'],\n",
    "    trainable=params['trainable_text_encoder'],\n",
    "    local_preprocessor_path='demo-preprocessor'\n",
    ")\n",
    "# The encoder output is indexed by key; only its \"pooled_output\" entry feeds the classifier head.\n",
    "embedding = encoder([inputs])[\"pooled_output\"]\n",
    "# Two-unit softmax head, matching the depth-2 one-hot labels produced by df_to_ds.\n",
    "predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding)\n",
    "model = tf.keras.models.Model(inputs=inputs, outputs=predictions)\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f490f43e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the positional arguments look like (learning rate, total training steps,\n",
    "# warmup steps) - confirm against the project-internal create_optimizer signature.\n",
    "# len(ds_train) is the number of batches per epoch.\n",
    "optimizer = create_optimizer(\n",
    "    params['lr'],\n",
    "    params['epochs'] * len(ds_train),\n",
    "    0,\n",
    "    weight_decay_rate=0.01,\n",
    "    optimizer_type='adamw'\n",
    ")\n",
    "# from_logits=False because the model already ends in a softmax.\n",
    "bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)\n",
    "# Area under the precision-recall curve, approximated with 1000 thresholds.\n",
    "pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False)\n",
    "model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc])\n",
    "\n",
    "callbacks = [\n",
    "    # Stop once val_loss has failed to improve for one epoch and roll back to the best weights.\n",
    "    tf.keras.callbacks.EarlyStopping(\n",
    "        monitor='val_loss',\n",
    "        mode='min',\n",
    "        patience=1,\n",
    "        restore_best_weights=True\n",
    "    ),\n",
    "    # Save a checkpoint after every epoch under <log_path>/checkpoints/<epoch number>.\n",
    "    tf.keras.callbacks.ModelCheckpoint(\n",
    "        filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'),\n",
    "        save_freq='epoch'\n",
    "    ),\n",
    "    # Write per-batch TensorBoard logs under <log_path>/scalars, without the graph.\n",
    "    tf.keras.callbacks.TensorBoard(\n",
    "        log_dir=os.path.join(log_path, 'scalars'),\n",
    "        update_freq='batch',\n",
    "        write_graph=False\n",
    "    )\n",
    "]\n",
    "# One epoch is one full pass over ds_train (steps_per_epoch equals its batch count).\n",
    "history = model.fit(\n",
    "    ds_train,\n",
    "    epochs=params['epochs'],\n",
    "    callbacks=callbacks,\n",
    "    validation_data=ds_val,\n",
    "    steps_per_epoch=len(ds_train)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2779f1e1",
   "metadata": {},
   "source": [
    "## 3.3. 
Check output and performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a6383db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the same cleaning the training texts went through, then expect the NSFW\n",
    "# probability (second column of the output) to be close to 1.\n",
    "model.predict([clean_tweet(\"xxx 🍑\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0245eb0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the whole validation set in one batched call. Column 1 of the softmax output\n",
    "# is the NSFW probability; ds_val is built without shuffling, so rows line up with y_val.\n",
    "preds = model.predict(ds_val)[:, 1]\n",
    "print(classification_report(y_val, preds >= 0.90, digits=4))\n",
    "\n",
    "precision, recall, thresholds = precision_recall_curve(y_val, preds)\n",
    "\n",
    "fig = plt.figure(figsize=(15, 10))\n",
    "# Recall on the x-axis and precision on the y-axis, matching the axis labels below.\n",
    "plt.plot(recall, precision, lw=2)\n",
    "plt.grid()\n",
    "plt.xlim(0.2, 1)\n",
    "plt.ylim(0.3, 1)\n",
    "plt.xlabel(\"Recall\", size=20)\n",
    "plt.ylabel(\"Precision\", size=20)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ffb1ba3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average precision summarizes the precision-recall curve as a single number.\n",
    "average_precision_score(y_val, preds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}