diff --git a/trust_and_safety_models/nsfw/NSFW_text.ipynb b/trust_and_safety_models/nsfw/NSFW_text.ipynb
new file mode 100644
index 000000000..67f8e3886
--- /dev/null
+++ b/trust_and_safety_models/nsfw/NSFW_text.ipynb
@@ -0,0 +1,348 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "faab08b8",
+   "metadata": {},
+   "source": [
+    "# NSFW text"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b919917f",
+   "metadata": {},
+   "source": [
+    "Builds a TensorFlow model that predicts the probability that a given text is NSFW."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7ee52623",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "from functools import reduce\n",
+    "import os\n",
+    "import pandas as pd\n",
+    "import re\n",
+    "from sklearn.metrics import (\n",
+    "    average_precision_score,\n",
+    "    classification_report,\n",
+    "    precision_recall_curve,\n",
+    "    PrecisionRecallDisplay,\n",
+    ")\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "import tensorflow as tf\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from twitter.cuad.representation.models.optimization import create_optimizer  # TODO: provide installation instructions\n",
+    "from twitter.cuad.representation.models.text_encoder import TextEncoder  # TODO: provide installation instructions\n",
+    "\n",
+    "pd.set_option('display.max_colwidth', None)\n",
+    "pd.set_option('display.expand_frame_repr', False)\n",
+    "\n",
+    "print(tf.__version__)\n",
+    "print(tf.config.list_physical_devices())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "64a45fd4",
+   "metadata": {},
+   "source": [
+    "# 1. Parameters and helper function"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ce628e8d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "log_path = os.path.join('pnsfwtweettext_model_runs', datetime.now().strftime('%Y-%m-%d_%H.%M.%S'))\n",
+    "\n",
+    "tweet_text_feature = 'text'\n",
+    "\n",
+    "params = {\n",
+    "  'batch_size': 32,\n",
+    "  'max_seq_lengths': 256,\n",
+    "  'model_type': 'twitter_bert_base_en_uncased_augmented_mlm',\n",
+    "  'trainable_text_encoder': True,\n",
+    "  'lr': 5e-5,\n",
+    "  'epochs': 10,\n",
+    "}\n",
+    "\n",
+    "REGEX_PATTERNS = [\n",
+    "  r'^RT @[A-Za-z0-9_]+: ',\n",
+    "  r\"@[A-Za-z0-9_]+\",\n",
+    "  r'https:\\\\/\\\\/t\\\\.co\\\\/[A-Za-z0-9]{10}',\n",
+    "  r'@\\\\?\\\\?\\\\?\\\\?\\\\?',\n",
+    "]\n",
+    "\n",
+    "EMOJI_PATTERN = re.compile(\n",
+    "  \"([\"\n",
+    "  \"\U0001F1E0-\U0001F1FF\"\n",
+    "  \"\U0001F300-\U0001F5FF\"\n",
+    "  \"\U0001F600-\U0001F64F\"\n",
+    "  \"\U0001F680-\U0001F6FF\"\n",
+    "  \"\U0001F700-\U0001F77F\"\n",
+    "  \"\U0001F780-\U0001F7FF\"\n",
+    "  \"\U0001F800-\U0001F8FF\"\n",
+    "  \"\U0001F900-\U0001F9FF\"\n",
+    "  \"\U0001FA00-\U0001FA6F\"\n",
+    "  \"\U0001FA70-\U0001FAFF\"\n",
+    "  \"\U00002702-\U000027B0\"\n",
+    "  \"])\"\n",
+    ")\n",
+    "\n",
+    "def clean_tweet(text):\n",
+    "  for pattern in REGEX_PATTERNS:\n",
+    "    text = re.sub(pattern, '', text)\n",
+    "\n",
+    "  text = re.sub(EMOJI_PATTERN, r' \\\\1 ', text)\n",
+    "\n",
+    "  text = re.sub(r'\\\\n', ' ', text)\n",
+    "\n",
+    "  return text.strip().lower()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b55ad56e",
+   "metadata": {},
+   "source": [
+    "# 2. Create data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c491ecf7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sample data just for demonstration\n",
+    "\n",
+    "df = pd.DataFrame(\n",
+    "  {\n",
+    "    \"text\": [\n",
+    "      \"Vanilla Tweet\",\n",
+    "      \"Love is Blind season 4 is 🔥\",\n",
+    "      \"I like Apples\",\n",
+    "      \"Big dick for you\",\n",
+    "      \"Wanna see my puss*?\",\n",
+    "      \"Elon Musk is the best\",\n",
+    "    ],\n",
+    "    \"is_nsfw\": [\n",
+    "      0, 0, 0, 1, 1, 1,\n",
+    "    ]\n",
+    "  }\n",
+    ")\n",
+    "\n",
+    "df"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e2369e2f",
+   "metadata": {},
+   "source": [
+    "# 3. Train model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "83a2d177",
+   "metadata": {},
+   "source": [
+    "## 3.1. Preprocessing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7843012c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df['processed_text'] = df['text'].astype(str).map(clean_tweet)\n",
+    "df.sample(len(df))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e3e5ceb5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1)\n",
+    "\n",
+    "def df_to_ds(X, y, shuffle=False):\n",
+    "  ds = tf.data.Dataset.from_tensor_slices((\n",
+    "    X.values,\n",
+    "    tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1)\n",
+    "  ))\n",
+    "\n",
+    "  if shuffle:\n",
+    "    ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True)\n",
+    "\n",
+    "  return ds.map(lambda text, label: ({ tweet_text_feature: text }, label)).batch(params['batch_size'])\n",
+    "\n",
+    "ds_train = df_to_ds(X_train, y_train, shuffle=True)\n",
+    "ds_val = df_to_ds(X_val, y_val)\n",
+    "X_train.values"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "974c1248",
+   "metadata": {},
+   "source": [
+    "## 3.2. Training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dfd6fc8f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature)\n",
+    "encoder = TextEncoder(\n",
+    "  max_seq_lengths=params['max_seq_lengths'],\n",
+    "  model_type=params['model_type'],\n",
+    "  trainable=params['trainable_text_encoder'],\n",
+    "  local_preprocessor_path='demo-preprocessor'\n",
+    ")\n",
+    "embedding = encoder([inputs])[\"pooled_output\"]\n",
+    "predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding)\n",
+    "model = tf.keras.models.Model(inputs=inputs, outputs=predictions)\n",
+    "\n",
+    "model.summary()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f490f43e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer = create_optimizer(\n",
+    "  params['lr'],\n",
+    "  params['epochs'] * len(ds_train),\n",
+    "  0,\n",
+    "  weight_decay_rate=0.01,\n",
+    "  optimizer_type='adamw'\n",
+    ")\n",
+    "bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)\n",
+    "pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False)\n",
+    "model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc])\n",
+    "\n",
+    "callbacks = [\n",
+    "  tf.keras.callbacks.EarlyStopping(\n",
+    "    monitor='val_loss',\n",
+    "    mode='min',\n",
+    "    patience=1,\n",
+    "    restore_best_weights=True\n",
+    "  ),\n",
+    "  tf.keras.callbacks.ModelCheckpoint(\n",
+    "    filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'),\n",
+    "    save_freq='epoch'\n",
+    "  ),\n",
+    "  tf.keras.callbacks.TensorBoard(\n",
+    "    log_dir=os.path.join(log_path, 'scalars'),\n",
+    "    update_freq='batch',\n",
+    "    write_graph=False\n",
+    "  )\n",
+    "]\n",
+    "history = model.fit(\n",
+    "  ds_train,\n",
+    "  epochs=params['epochs'],\n",
+    "  callbacks=callbacks,\n",
+    "  validation_data=ds_val,\n",
+    "  steps_per_epoch=len(ds_train)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2779f1e1",
+   "metadata": {},
+   "source": [
+    "## 3.3. Check output and performance"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8a6383db",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Expect the NSFW class probability (second column) to be close to 1\n",
+    "model.predict([\"xxx 🍑\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0245eb0d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "preds = X_val.processed_text.apply(apply_model)\n",
+    "print(classification_report(y_val, preds >= 0.90, digits=4))\n",
+    "\n",
+    "precision, recall, thresholds = precision_recall_curve(y_val, preds)\n",
+    "\n",
+    "fig = plt.figure(figsize=(15, 10))\n",
+    "plt.plot(recall, precision, lw=2)\n",
+    "plt.grid()\n",
+    "plt.xlim(0.2, 1)\n",
+    "plt.ylim(0.3, 1)\n",
+    "plt.xlabel(\"Recall\", size=20)\n",
+    "plt.ylabel(\"Precision\", size=20)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0ffb1ba3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "average_precision_score(y_val, preds)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/trust_and_safety_models/nsfw/nsfw_text.py b/trust_and_safety_models/nsfw/nsfw_text.py
deleted file mode 100644
index 980fc8fd4..000000000
--- a/trust_and_safety_models/nsfw/nsfw_text.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from datetime import datetime
-from functools import reduce
-import os
-import pandas as pd
-import re
-from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay
-from sklearn.model_selection import train_test_split
-import tensorflow as tf
-import matplotlib.pyplot as plt
-import re
-
-from twitter.cuad.representation.models.optimization import create_optimizer
-from twitter.cuad.representation.models.text_encoder import TextEncoder
-
-pd.set_option('display.max_colwidth', None)
-pd.set_option('display.expand_frame_repr', False)
-
-print(tf.__version__)
-print(tf.config.list_physical_devices())
-
-log_path = os.path.join('pnsfwtweettext_model_runs', datetime.now().strftime('%Y-%m-%d_%H.%M.%S'))
-
-tweet_text_feature = 'text'
-
-params = {
-  'batch_size': 32,
-  'max_seq_lengths': 256,
-  'model_type': 'twitter_bert_base_en_uncased_augmented_mlm',
-  'trainable_text_encoder': True,
-  'lr': 5e-5,
-  'epochs': 10,
-}
-
-REGEX_PATTERNS = [
-  r'^RT @[A-Za-z0-9_]+: ',
-  r"@[A-Za-z0-9_]+",
-  r'https:\/\/t\.co\/[A-Za-z0-9]{10}',
-  r'@\?\?\?\?\?',
-]
-
-EMOJI_PATTERN = re.compile(
-  "(["
-  "\U0001F1E0-\U0001F1FF"
-  "\U0001F300-\U0001F5FF"
-  "\U0001F600-\U0001F64F"
-  "\U0001F680-\U0001F6FF"
-  "\U0001F700-\U0001F77F"
-  "\U0001F780-\U0001F7FF"
-  "\U0001F800-\U0001F8FF"
-  "\U0001F900-\U0001F9FF"
-  "\U0001FA00-\U0001FA6F"
-  "\U0001FA70-\U0001FAFF"
-  "\U00002702-\U000027B0"
-  "])"
-)
-
-def clean_tweet(text):
-  for pattern in REGEX_PATTERNS:
-    text = re.sub(pattern, '', text)
-
-  text = re.sub(EMOJI_PATTERN, r' \1 ', text)
-
-  text = re.sub(r'\n', ' ', text)
-
-  return text.strip().lower()
-
-
-df['processed_text'] = df['text'].astype(str).map(clean_tweet)
-df.sample(10)
-
-X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1)
-
-def df_to_ds(X, y, shuffle=False):
-  ds = tf.data.Dataset.from_tensor_slices((
-    X.values,
-    tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1)
-  ))
-
-  if shuffle:
-    ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True)
-
-  return ds.map(lambda text, label: ({ tweet_text_feature: text }, label)).batch(params['batch_size'])
-
-ds_train = df_to_ds(X_train, y_train, shuffle=True)
-ds_val = df_to_ds(X_val, y_val)
-X_train.values
-
-inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature)
-encoder = TextEncoder(
-  max_seq_lengths=params['max_seq_lengths'],
-  model_type=params['model_type'],
-  trainable=params['trainable_text_encoder'],
-  local_preprocessor_path='demo-preprocessor'
-)
-embedding = encoder([inputs])["pooled_output"]
-predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding)
-model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
-
-model.summary()
-
-optimizer = create_optimizer(
-  params['lr'],
-  params['epochs'] * len(ds_train),
-  0,
-  weight_decay_rate=0.01,
-  optimizer_type='adamw'
-)
-bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
-pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False)
-model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc])
-
-callbacks = [
-  tf.keras.callbacks.EarlyStopping(
-    monitor='val_loss',
-    mode='min',
-    patience=1,
-    restore_best_weights=True
-  ),
-  tf.keras.callbacks.ModelCheckpoint(
-    filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'),
-    save_freq='epoch'
-  ),
-  tf.keras.callbacks.TensorBoard(
-    log_dir=os.path.join(log_path, 'scalars'),
-    update_freq='batch',
-    write_graph=False
-  )
-]
-history = model.fit(
-  ds_train,
-  epochs=params['epochs'],
-  callbacks=callbacks,
-  validation_data=ds_val,
-  steps_per_epoch=len(ds_train)
-)
-
-model.predict(["xxx 🍑"])
-
-preds = X_val.processed_text.apply(apply_model)
-print(classification_report(y_val, preds >= 0.90, digits=4))
-
-precision, recall, thresholds = precision_recall_curve(y_val, preds)
-
-fig = plt.figure(figsize=(15, 10))
-plt.plot(precision, recall, lw=2)
-plt.grid()
-plt.xlim(0.2, 1)
-plt.ylim(0.3, 1)
-plt.xlabel("Recall", size=20)
-plt.ylabel("Precision", size=20)
-
-average_precision_score(y_val, preds)
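
Note: the evaluation cell in the notebook calls apply_model, which is not defined anywhere in this change. A minimal sketch of such a helper, assuming it is meant to score one cleaned text with the trained Keras model above and return the NSFW-class probability (the name and signature are hypothetical, not part of the diff):

def apply_model(text: str) -> float:
  # Hypothetical helper assumed by the evaluation cell; not defined in this change.
  # Scores a single cleaned text and returns the probability of the NSFW class
  # (index 1 of the 2-way softmax output produced by the Dense(2, 'softmax') head).
  probs = model.predict([text], verbose=0)  # shape (1, 2): [p_sfw, p_nsfw]
  return float(probs[0][1])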