mirror of
https://github.com/twitter/the-algorithm.git
synced 2025-01-05 09:01:54 +01:00
Compare commits
2 Commits
0791cdd822
...
a1d7698eea
Author | SHA1 | Date | |
---|---|---|---|
|
a1d7698eea | ||
|
49f9f7fd82 |
348
trust_and_safety_models/nsfw/NSFW_text.ipynb
Normal file
348
trust_and_safety_models/nsfw/NSFW_text.ipynb
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "faab08b8",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# NSFW text"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b919917f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Builds a TensorFlow model that predicts the probability of a text to be NSFW."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7ee52623",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from datetime import datetime\n",
|
||||||
|
"from functools import reduce\n",
|
||||||
|
"import os\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import re\n",
|
||||||
|
"from sklearn.metrics import (\n",
|
||||||
|
" average_precision_score,\n",
|
||||||
|
" classification_report, \n",
|
||||||
|
" precision_recall_curve,\n",
|
||||||
|
" PrecisionRecallDisplay,\n",
|
||||||
|
")\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"import tensorflow as tf\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import re\n",
|
||||||
|
"\n",
|
||||||
|
"from twitter.cuad.representation.models.optimization import create_optimizer #TODO: provide instruction for install\n",
|
||||||
|
"from twitter.cuad.representation.models.text_encoder import TextEncoder #TODO: provide instruction for install\n",
|
||||||
|
"\n",
|
||||||
|
"pd.set_option('display.max_colwidth', None)\n",
|
||||||
|
"pd.set_option('display.expand_frame_repr', False)\n",
|
||||||
|
"\n",
|
||||||
|
"print(tf.__version__)\n",
|
||||||
|
"print(tf.config.list_physical_devices())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "64a45fd4",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# 1. Parameters and helper function"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ce628e8d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"log_path = os.path.join('pnsfwtweettext_model_runs', datetime.now().strftime('%Y-%m-%d_%H.%M.%S'))\n",
|
||||||
|
"\n",
|
||||||
|
"tweet_text_feature = 'text'\n",
|
||||||
|
"\n",
|
||||||
|
"params = {\n",
|
||||||
|
" 'batch_size': 32,\n",
|
||||||
|
" 'max_seq_lengths': 256,\n",
|
||||||
|
" 'model_type': 'twitter_bert_base_en_uncased_augmented_mlm',\n",
|
||||||
|
" 'trainable_text_encoder': True,\n",
|
||||||
|
" 'lr': 5e-5,\n",
|
||||||
|
" 'epochs': 10,\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"REGEX_PATTERNS = [\n",
|
||||||
|
" r'^RT @[A-Za-z0-9_]+: ', \n",
|
||||||
|
" r\"@[A-Za-z0-9_]+\",\n",
|
||||||
|
" r'https:\\/\\/t\\.co\\/[A-Za-z0-9]{10}',\n",
|
||||||
|
" r'@\\?\\?\\?\\?\\?',\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"EMOJI_PATTERN = re.compile(\n",
|
||||||
|
" \"([\"\n",
|
||||||
|
" \"\\U0001F1E0-\\U0001F1FF\"\n",
|
||||||
|
" \"\\U0001F300-\\U0001F5FF\"\n",
|
||||||
|
" \"\\U0001F600-\\U0001F64F\"\n",
|
||||||
|
" \"\\U0001F680-\\U0001F6FF\"\n",
|
||||||
|
" \"\\U0001F700-\\U0001F77F\"\n",
|
||||||
|
" \"\\U0001F780-\\U0001F7FF\"\n",
|
||||||
|
" \"\\U0001F800-\\U0001F8FF\"\n",
|
||||||
|
" \"\\U0001F900-\\U0001F9FF\"\n",
|
||||||
|
" \"\\U0001FA00-\\U0001FA6F\"\n",
|
||||||
|
" \"\\U0001FA70-\\U0001FAFF\"\n",
|
||||||
|
" \"\\U00002702-\\U000027B0\"\n",
|
||||||
|
" \"])\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"def clean_tweet(text):\n",
|
||||||
|
" for pattern in REGEX_PATTERNS:\n",
|
||||||
|
" text = re.sub(pattern, '', text)\n",
|
||||||
|
"\n",
|
||||||
|
" text = re.sub(EMOJI_PATTERN, r' \\1 ', text)\n",
|
||||||
|
" \n",
|
||||||
|
" text = re.sub(r'\\n', ' ', text)\n",
|
||||||
|
" \n",
|
||||||
|
" return text.strip().lower()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "b55ad56e",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# 2. Create data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c491ecf7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Sample data just for demonstration\n",
|
||||||
|
"\n",
|
||||||
|
"df = pd.DataFrame(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"text\": [\n",
|
||||||
|
" \"Vanilla Tweet\",\n",
|
||||||
|
" \"Love is Blind season 4 is 🔥\",\n",
|
||||||
|
" \"I like Apples\",\n",
|
||||||
|
" \"Big dick for you\",\n",
|
||||||
|
" \"Wanna see my puss*?\",\n",
|
||||||
|
" \"Elon Musk is the best\",\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"is_nsfw\": [\n",
|
||||||
|
" 0, 0, 0, 1, 1, 1,\n",
|
||||||
|
" ]\n",
|
||||||
|
" }\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"df"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "e2369e2f",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# 3. Train model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "83a2d177",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 3.1. Preprocessing"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7843012c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"df['processed_text'] = df['text'].astype(str).map(clean_tweet)\n",
|
||||||
|
"df.sample(10)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e3e5ceb5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1)\n",
|
||||||
|
"\n",
|
||||||
|
"def df_to_ds(X, y, shuffle=False):\n",
|
||||||
|
" ds = tf.data.Dataset.from_tensor_slices((\n",
|
||||||
|
" X.values,\n",
|
||||||
|
" tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1)\n",
|
||||||
|
" ))\n",
|
||||||
|
" \n",
|
||||||
|
" if shuffle:\n",
|
||||||
|
" ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True)\n",
|
||||||
|
" \n",
|
||||||
|
" return ds.map(lambda text, label: ({ tweet_text_feature: text }, label)).batch(params['batch_size'])\n",
|
||||||
|
"\n",
|
||||||
|
"ds_train = df_to_ds(X_train, y_train, shuffle=True)\n",
|
||||||
|
"ds_val = df_to_ds(X_val, y_val)\n",
|
||||||
|
"X_train.values\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "974c1248",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 3.2. Training"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "dfd6fc8f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature)\n",
|
||||||
|
"encoder = TextEncoder(\n",
|
||||||
|
" max_seq_lengths=params['max_seq_lengths'],\n",
|
||||||
|
" model_type=params['model_type'],\n",
|
||||||
|
" trainable=params['trainable_text_encoder'],\n",
|
||||||
|
" local_preprocessor_path='demo-preprocessor'\n",
|
||||||
|
")\n",
|
||||||
|
"embedding = encoder([inputs])[\"pooled_output\"]\n",
|
||||||
|
"predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding)\n",
|
||||||
|
"model = tf.keras.models.Model(inputs=inputs, outputs=predictions)\n",
|
||||||
|
"\n",
|
||||||
|
"model.summary()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f490f43e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"optimizer = create_optimizer(\n",
|
||||||
|
" params['lr'],\n",
|
||||||
|
" params['epochs'] * len(ds_train),\n",
|
||||||
|
" 0,\n",
|
||||||
|
" weight_decay_rate=0.01,\n",
|
||||||
|
" optimizer_type='adamw'\n",
|
||||||
|
")\n",
|
||||||
|
"bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)\n",
|
||||||
|
"pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False)\n",
|
||||||
|
"model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc])\n",
|
||||||
|
"\n",
|
||||||
|
"callbacks = [\n",
|
||||||
|
" tf.keras.callbacks.EarlyStopping(\n",
|
||||||
|
" monitor='val_loss',\n",
|
||||||
|
" mode='min',\n",
|
||||||
|
" patience=1,\n",
|
||||||
|
" restore_best_weights=True\n",
|
||||||
|
" ),\n",
|
||||||
|
" tf.keras.callbacks.ModelCheckpoint(\n",
|
||||||
|
" filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'),\n",
|
||||||
|
" save_freq='epoch'\n",
|
||||||
|
" ),\n",
|
||||||
|
" tf.keras.callbacks.TensorBoard(\n",
|
||||||
|
" log_dir=os.path.join(log_path, 'scalars'),\n",
|
||||||
|
" update_freq='batch',\n",
|
||||||
|
" write_graph=False\n",
|
||||||
|
" )\n",
|
||||||
|
"]\n",
|
||||||
|
"history = model.fit(\n",
|
||||||
|
" ds_train,\n",
|
||||||
|
" epochs=params['epochs'],\n",
|
||||||
|
" callbacks=callbacks,\n",
|
||||||
|
" validation_data=ds_val,\n",
|
||||||
|
" steps_per_epoch=len(ds_train)\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "2779f1e1",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 3.3. Check output and performance"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8a6383db",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Expect prediction to be 1\n",
|
||||||
|
"model.predict([\"xxx 🍑\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0245eb0d",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"preds = X_val.processed_text.apply(apply_model)\n",
|
||||||
|
"print(classification_report(y_val, preds >= 0.90, digits=4))\n",
|
||||||
|
"\n",
|
||||||
|
"precision, recall, thresholds = precision_recall_curve(y_val, preds)\n",
|
||||||
|
"\n",
|
||||||
|
"fig = plt.figure(figsize=(15, 10))\n",
|
||||||
|
"plt.plot(precision, recall, lw=2)\n",
|
||||||
|
"plt.grid()\n",
|
||||||
|
"plt.xlim(0.2, 1)\n",
|
||||||
|
"plt.ylim(0.3, 1)\n",
|
||||||
|
"plt.xlabel(\"Recall\", size=20)\n",
|
||||||
|
"plt.ylabel(\"Precision\", size=20)\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "0ffb1ba3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"average_precision_score(y_val, preds)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.11"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -1,152 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from functools import reduce
|
|
||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import re
|
|
||||||
from sklearn.metrics import average_precision_score, classification_report, precision_recall_curve, PrecisionRecallDisplay
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
import tensorflow as tf
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import re
|
|
||||||
|
|
||||||
from twitter.cuad.representation.models.optimization import create_optimizer
|
|
||||||
from twitter.cuad.representation.models.text_encoder import TextEncoder
|
|
||||||
|
|
||||||
pd.set_option('display.max_colwidth', None)
|
|
||||||
pd.set_option('display.expand_frame_repr', False)
|
|
||||||
|
|
||||||
print(tf.__version__)
|
|
||||||
print(tf.config.list_physical_devices())
|
|
||||||
|
|
||||||
log_path = os.path.join('pnsfwtweettext_model_runs', datetime.now().strftime('%Y-%m-%d_%H.%M.%S'))
|
|
||||||
|
|
||||||
tweet_text_feature = 'text'
|
|
||||||
|
|
||||||
params = {
|
|
||||||
'batch_size': 32,
|
|
||||||
'max_seq_lengths': 256,
|
|
||||||
'model_type': 'twitter_bert_base_en_uncased_augmented_mlm',
|
|
||||||
'trainable_text_encoder': True,
|
|
||||||
'lr': 5e-5,
|
|
||||||
'epochs': 10,
|
|
||||||
}
|
|
||||||
|
|
||||||
REGEX_PATTERNS = [
|
|
||||||
r'^RT @[A-Za-z0-9_]+: ',
|
|
||||||
r"@[A-Za-z0-9_]+",
|
|
||||||
r'https:\/\/t\.co\/[A-Za-z0-9]{10}',
|
|
||||||
r'@\?\?\?\?\?',
|
|
||||||
]
|
|
||||||
|
|
||||||
EMOJI_PATTERN = re.compile(
|
|
||||||
"(["
|
|
||||||
"\U0001F1E0-\U0001F1FF"
|
|
||||||
"\U0001F300-\U0001F5FF"
|
|
||||||
"\U0001F600-\U0001F64F"
|
|
||||||
"\U0001F680-\U0001F6FF"
|
|
||||||
"\U0001F700-\U0001F77F"
|
|
||||||
"\U0001F780-\U0001F7FF"
|
|
||||||
"\U0001F800-\U0001F8FF"
|
|
||||||
"\U0001F900-\U0001F9FF"
|
|
||||||
"\U0001FA00-\U0001FA6F"
|
|
||||||
"\U0001FA70-\U0001FAFF"
|
|
||||||
"\U00002702-\U000027B0"
|
|
||||||
"])"
|
|
||||||
)
|
|
||||||
|
|
||||||
def clean_tweet(text):
|
|
||||||
for pattern in REGEX_PATTERNS:
|
|
||||||
text = re.sub(pattern, '', text)
|
|
||||||
|
|
||||||
text = re.sub(EMOJI_PATTERN, r' \1 ', text)
|
|
||||||
|
|
||||||
text = re.sub(r'\n', ' ', text)
|
|
||||||
|
|
||||||
return text.strip().lower()
|
|
||||||
|
|
||||||
|
|
||||||
df['processed_text'] = df['text'].astype(str).map(clean_tweet)
|
|
||||||
df.sample(10)
|
|
||||||
|
|
||||||
X_train, X_val, y_train, y_val = train_test_split(df[['processed_text']], df['is_nsfw'], test_size=0.1, random_state=1)
|
|
||||||
|
|
||||||
def df_to_ds(X, y, shuffle=False):
|
|
||||||
ds = tf.data.Dataset.from_tensor_slices((
|
|
||||||
X.values,
|
|
||||||
tf.one_hot(tf.cast(y.values, tf.int32), depth=2, axis=-1)
|
|
||||||
))
|
|
||||||
|
|
||||||
if shuffle:
|
|
||||||
ds = ds.shuffle(1000, seed=1, reshuffle_each_iteration=True)
|
|
||||||
|
|
||||||
return ds.map(lambda text, label: ({ tweet_text_feature: text }, label)).batch(params['batch_size'])
|
|
||||||
|
|
||||||
ds_train = df_to_ds(X_train, y_train, shuffle=True)
|
|
||||||
ds_val = df_to_ds(X_val, y_val)
|
|
||||||
X_train.values
|
|
||||||
|
|
||||||
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string, name=tweet_text_feature)
|
|
||||||
encoder = TextEncoder(
|
|
||||||
max_seq_lengths=params['max_seq_lengths'],
|
|
||||||
model_type=params['model_type'],
|
|
||||||
trainable=params['trainable_text_encoder'],
|
|
||||||
local_preprocessor_path='demo-preprocessor'
|
|
||||||
)
|
|
||||||
embedding = encoder([inputs])["pooled_output"]
|
|
||||||
predictions = tf.keras.layers.Dense(2, activation='softmax')(embedding)
|
|
||||||
model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
|
|
||||||
|
|
||||||
model.summary()
|
|
||||||
|
|
||||||
optimizer = create_optimizer(
|
|
||||||
params['lr'],
|
|
||||||
params['epochs'] * len(ds_train),
|
|
||||||
0,
|
|
||||||
weight_decay_rate=0.01,
|
|
||||||
optimizer_type='adamw'
|
|
||||||
)
|
|
||||||
bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
|
|
||||||
pr_auc = tf.keras.metrics.AUC(curve='PR', num_thresholds=1000, from_logits=False)
|
|
||||||
model.compile(optimizer=optimizer, loss=bce, metrics=[pr_auc])
|
|
||||||
|
|
||||||
callbacks = [
|
|
||||||
tf.keras.callbacks.EarlyStopping(
|
|
||||||
monitor='val_loss',
|
|
||||||
mode='min',
|
|
||||||
patience=1,
|
|
||||||
restore_best_weights=True
|
|
||||||
),
|
|
||||||
tf.keras.callbacks.ModelCheckpoint(
|
|
||||||
filepath=os.path.join(log_path, 'checkpoints', '{epoch:02d}'),
|
|
||||||
save_freq='epoch'
|
|
||||||
),
|
|
||||||
tf.keras.callbacks.TensorBoard(
|
|
||||||
log_dir=os.path.join(log_path, 'scalars'),
|
|
||||||
update_freq='batch',
|
|
||||||
write_graph=False
|
|
||||||
)
|
|
||||||
]
|
|
||||||
history = model.fit(
|
|
||||||
ds_train,
|
|
||||||
epochs=params['epochs'],
|
|
||||||
callbacks=callbacks,
|
|
||||||
validation_data=ds_val,
|
|
||||||
steps_per_epoch=len(ds_train)
|
|
||||||
)
|
|
||||||
|
|
||||||
model.predict(["xxx 🍑"])
|
|
||||||
|
|
||||||
preds = X_val.processed_text.apply(apply_model)
|
|
||||||
print(classification_report(y_val, preds >= 0.90, digits=4))
|
|
||||||
|
|
||||||
precision, recall, thresholds = precision_recall_curve(y_val, preds)
|
|
||||||
|
|
||||||
fig = plt.figure(figsize=(15, 10))
|
|
||||||
plt.plot(precision, recall, lw=2)
|
|
||||||
plt.grid()
|
|
||||||
plt.xlim(0.2, 1)
|
|
||||||
plt.ylim(0.3, 1)
|
|
||||||
plt.xlabel("Recall", size=20)
|
|
||||||
plt.ylabel("Precision", size=20)
|
|
||||||
|
|
||||||
average_precision_score(y_val, preds)
|
|
Loading…
Reference in New Issue
Block a user