{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Probabilistic Language Models\n", "\n", "\n", "Feng Li\n", "\n", "School of Statistics and Mathematics\n", "\n", "Central University of Finance and Economics\n", "\n", "[feng.li@cufe.edu.cn](mailto:feng.li@cufe.edu.cn)\n", "\n", "[https://feng.li/python](https://feng.li/python)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Probabilistic Language Models\n", "\n", "\n", "How can we assign a probability to a sentence?\n", "\n", " P(high winds tonight) > P(large winds tonight)\n", " \n", " \n", "How do we do proper spelling correction?\n", "\n", "- The office is about fifteen **minuets** from my house\n", "\n", " P(about fifteen minutes from) > P(about fifteen minuets from)\n", " \n", " \n", "How do we improve the precision of speech recognition?\n", "\n", "\n", " P(I saw a van) >> P(eyes awe of an)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## The Goal of a Language Model:\n", "\n", "We compute the probability of a sentence or sequence of words:\n", "\n", "$$P(W) = P(w_1,w_2,w_3,w_4,w_5,...,w_n)$$\n", " \n", "A related task is computing the probability of an upcoming word:\n", "\n", "$$P(w_5|w_1,w_2,w_3,w_4)$$\n", " \n", "A model that computes either of these, $P(W)$ or $P(w_n|w_1,w_2,...,w_{n-1})$, is called a **language model**." 
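, "\n", "\n", "The two quantities are linked by the definition of conditional probability, so a model that supplies one can in principle supply the other. As a worked instance in the five-word notation above (this ratio is standard probability rather than anything specific to these slides, and it assumes $P(w_1,w_2,w_3,w_4)$ denotes the probability of the four-word prefix):\n", "\n", "$$P(w_5|w_1,w_2,w_3,w_4) = \\frac{P(w_1,w_2,w_3,w_4,w_5)}{P(w_1,w_2,w_3,w_4)}$$"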
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## How to compute this joint probability:\n", "\n", "$$P(its, water, is, so, transparent)$$\n", "\n", "- Intuition: let’s rely on the Chain Rule of Probability\n", "\n", "- The Chain Rule in General\n", "\n", "$$P(x_1,x_2,x_3,\\ldots,x_n) = P(x_1)P(x_2|x_1)P(x_3|x_1,x_2)\\cdots P(x_n|x_1,\\ldots,x_{n-1})$$\n", "\n", "- The joint probability then factorizes as\n", "\n", "$$P(its~water~is~so~transparent) =\n", " P(its) \\times P(water|its) \\times P(is|its~water)\n", " \\times P(so|its~water~is) \\times P(transparent|its~water~is~so)\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Simplifying assumption:\n", "\n", "- Markov Assumption\n", "\n", "$$\n", "P(w_1,w_2,\\ldots,w_{n}) \\approx \\prod_{i=1}^n P(w_i|w_{i-k},\\ldots,w_{i-1})\n", "$$\n", "where $k$ is a positive integer: each word is conditioned only on the $k$ words before it.\n", "\n", "- In other words, we approximate each factor in the chain-rule product\n", "\n", "$$\n", " P(w_i|w_{1},\\ldots,w_{i-1}) \\approx P(w_i|w_{i-k},\\ldots,w_{i-1})\n", "$$" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "from nltk.tokenize import RegexpTokenizer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# `__file__` is not defined inside a notebook, so the quoted name is resolved\n", "# against the current working directory and this call leaves it unchanged.\n", "os.chdir(os.path.dirname(os.path.realpath('__file__')))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Load the 'Merged' sheet and keep the 'Explanation' column as a list of strings.\n", "merged = pd.read_excel('data/guba.xlsx', sheet_name='Merged')\n", "\n", "merged = merged['Explanation'].tolist()\n", "# Split on runs of word characters, which drops punctuation.\n", "tokenizer = RegexpTokenizer(r'\\w+')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Tokenize\n", "\n", "merged = [tokenizer.tokenize(merged[i]) for i in 
range(len(merged))]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Lowercase the tokens, drop purely numeric tokens, and remove stop words.\n", "def word_clean(sentence, stop_words):\n", "    sentence = [token.lower() for token in sentence]\n", "    sentence = [token for token in sentence if not token.isnumeric()]\n", "    sentence = [token for token in sentence if token not in stop_words]\n", "    return sentence\n", "\n", "stop_words = pd.read_csv('data/stopwords.txt', header=None)[0].to_list()\n", "merged = [word_clean(sentence, stop_words) for sentence in merged]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "import nltk\n", "from nltk.stem.wordnet import WordNetLemmatizer" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Lemmatize the documents.\n", "\n", "lemmatizer = WordNetLemmatizer()\n", "merged = [[lemmatizer.lemmatize(token) for token in doc] for doc in merged]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Compute bigrams.\n", "from gensim.models import Phrases\n", "\n", "# Detect frequent bigrams and merge each one into a single token such as real_estate.\n", "bigram = Phrases(merged, min_count=1)\n", "merged = [bigram[lst] for lst in merged]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Remove rare and common tokens.\n", "from gensim.corpora import Dictionary\n", "\n", "# Create a dictionary representation of the documents.\n", "merged_dictionary = Dictionary(merged)\n", "\n", "# Keep tokens that appear in at least 1 document (no_below=1) and in no more than 60% of the documents (no_above=0.6).\n", "merged_dictionary.filter_extremes(no_below=1, no_above=0.6)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "slideshow": { "slide_type": "slide" } }, 
"outputs": [], "source": [ "# Bag-of-words representation of the documents.\n", "\n", "merged_corpus = [merged_dictionary.doc2bow(doc) for doc in merged]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Train merged model.\n", "from gensim.models import LdaModel\n", "\n", "# Set training parameters.\n", "num_topics = 5\n", "chunksize = 20\n", "passes = 5\n", "iterations = 200\n", "eval_every = None # Don't evaluate model perplexity, takes too much time.\n", "\n", "# Make an index to word dictionary.\n", "temp = merged_dictionary[0] # This is only to \"load\" the dictionary.\n", "id2word = merged_dictionary.id2token\n", "\n", "merged_model = LdaModel(\n", " corpus=merged_corpus,\n", " id2word=id2word,\n", " chunksize=chunksize,\n", " alpha='auto',\n", " eta='auto', \n", " iterations=iterations,\n", " num_topics=num_topics,\n", " passes=passes,\n", " eval_every=eval_every\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://mirrors.163.com/pypi/simple/\n", "Collecting pyLDAvis\n", " Using cached pyLDAvis-3.3.1-py2.py3-none-any.whl\n", "Collecting joblib\n", " Using cached https://mirrors.163.com/pypi/packages/55/85/70c6602b078bd9e6f3da4f467047e906525c355a4dacd4f71b97a35d9897/joblib-1.0.1-py3-none-any.whl (303 kB)\n", "Collecting pandas>=1.2.0\n", " Using cached https://mirrors.163.com/pypi/packages/48/b4/1081d66b71c4dfc1bc1e19d6f2abbf93ed42f69df7703eb323742d45423e/pandas-1.3.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.5 MB)\n", "Collecting future\n", " Using cached future-0.18.2-py3-none-any.whl\n", "Collecting sklearn\n", " Using cached sklearn-0.0-py2.py3-none-any.whl\n", "Collecting setuptools\n", " Using cached 
https://mirrors.163.com/pypi/packages/41/f4/a7ca4859317232b1efb64a826b8d2d7299bb77fb60bdb08e2bd1d61cf80d/setuptools-58.2.0-py3-none-any.whl (946 kB)\n", "Collecting gensim\n", " Using cached https://mirrors.163.com/pypi/packages/61/e8/ddf62a31b4f97f543a38233047865d02be97c192f7f8d849bbf3353bc094/gensim-4.1.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (24.0 MB)\n", "Collecting numexpr\n", " Using cached https://mirrors.163.com/pypi/packages/4e/88/ccd8973d0dde04e95f6fbc7818f770a18293de7348fc3f9b66e9bf44a2f9/numexpr-2.7.3-cp39-cp39-manylinux2010_x86_64.whl (471 kB)\n", "Collecting numpy>=1.20.0\n", " Using cached https://mirrors.163.com/pypi/packages/7a/4c/dd00ce768b0f0f7de5c486cbd9f5b922bc3af2f3a5da30121d7f7dc03130/numpy-1.21.2-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (15.8 MB)\n", "Collecting scikit-learn\n", " Using cached https://mirrors.163.com/pypi/packages/53/8b/99d0658d74a2e6277dbe40b6759581badb2790f6422369ae6a3d606b9164/scikit_learn-1.0.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (24.7 MB)\n", "Collecting jinja2\n", " Using cached https://mirrors.163.com/pypi/packages/94/42/d8bca8e99789bcc35dfa9b03acaa8b518720d6e060163745bc2bf2ead842/Jinja2-3.0.2-py3-none-any.whl (133 kB)\n", "Collecting scipy\n", " Using cached https://mirrors.163.com/pypi/packages/83/4a/13f813a86b7f510954bdae649f0fda718e8210320b4108656ddaf96442c9/scipy-1.7.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (39.8 MB)\n", "Collecting funcy\n", " Using cached https://mirrors.163.com/pypi/packages/66/89/479de0afbbfb98d1c4b887936808764627300208bb771fcd823403645a36/funcy-1.15-py2.py3-none-any.whl (32 kB)\n", "Collecting python-dateutil>=2.7.3\n", " Using cached https://mirrors.163.com/pypi/packages/d4/70/d60450c3dd48ef87586924207ae8907090de0b306af2bce5d134d78615cb/python_dateutil-2.8.1-py2.py3-none-any.whl (227 kB)\n", "Collecting pytz>=2017.3\n", " Using cached 
https://mirrors.163.com/pypi/packages/70/94/784178ca5dd892a98f113cdd923372024dc04b8d40abe77ca76b5fb90ca6/pytz-2021.1-py2.py3-none-any.whl (510 kB)\n", "Collecting six>=1.5\n", " Using cached https://mirrors.163.com/pypi/packages/ee/ff/48bde5c0f013094d729fe4b0316ba2a24774b3ff1c52d924a8a4cb04078a/six-1.15.0-py2.py3-none-any.whl (10 kB)\n", "Collecting smart-open>=1.8.1\n", " Using cached https://mirrors.163.com/pypi/packages/cd/11/05f68ea934c24ade38e95ac30a38407767787c4e3db1776eae4886ad8c95/smart_open-5.2.1-py3-none-any.whl (58 kB)\n", "Collecting MarkupSafe>=2.0\n", " Using cached https://mirrors.163.com/pypi/packages/c2/db/314df69668f582d5173922bded7b58126044bb77cfce6347c5d992074d2e/MarkupSafe-2.0.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (30 kB)\n", "Collecting threadpoolctl>=2.0.0\n", " Using cached https://mirrors.163.com/pypi/packages/f7/12/ec3f2e203afa394a149911729357aa48affc59c20e2c1c8297a60f33f133/threadpoolctl-2.1.0-py3-none-any.whl (12 kB)\n", "Installing collected packages: numpy, threadpoolctl, six, scipy, joblib, smart-open, scikit-learn, pytz, python-dateutil, MarkupSafe, sklearn, setuptools, pandas, numexpr, jinja2, gensim, future, funcy, pyLDAvis\n", "Successfully installed MarkupSafe-2.0.1 funcy-1.15 future-0.18.2 gensim-4.1.2 jinja2-3.0.2 joblib-1.0.1 numexpr-2.7.3 numpy-1.21.2 pandas-1.3.4 pyLDAvis-3.3.1 python-dateutil-2.8.1 pytz-2021.1 scikit-learn-1.0.1 scipy-1.7.2 setuptools-58.2.0 six-1.15.0 sklearn-0.0 smart-open-5.2.1 threadpoolctl-2.1.0\n" ] } ], "source": [ "# We need a specific version of pyLDAvis. Let's install it to current notebook directory.\n", "! 
pip3 install pyLDAvis -I -t modules" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "['modules',\n", " 'modules',\n", " '/home/fli/cloud/teaching/python/python-slides/P08-Advanced-Topics',\n", " '/usr/lib/python39.zip',\n", " '/usr/lib/python3.9',\n", " '/usr/lib/python3.9/lib-dynload',\n", " '',\n", " '/home/fli/.local/lib/python3.9/site-packages',\n", " '/usr/local/lib/python3.9/dist-packages',\n", " '/usr/lib/python3/dist-packages',\n", " '/usr/local/lib/python3.9/dist-packages/IPython/extensions',\n", " '/home/fli/.ipython']" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# prepend the `modules` folder to Python's search path \n", "import sys\n", "sys.path.insert(0, 'modules')\n", "sys.path" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "['modules',\n", " '/home/fli/cloud/teaching/python/python-slides/P08-Advanced-Topics',\n", " '/usr/lib/python39.zip',\n", " '/usr/lib/python3.9',\n", " '/usr/lib/python3.9/lib-dynload',\n", " '',\n", " '/home/fli/.local/lib/python3.9/site-packages',\n", " '/usr/local/lib/python3.9/dist-packages',\n", " '/usr/lib/python3/dist-packages',\n", " '/usr/local/lib/python3.9/dist-packages/IPython/extensions',\n", " '/home/fli/.ipython']" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sys.path" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/fli/cloud/teaching/python/python-slides/P08-Advanced-Topics/modules/pyLDAvis/_prepare.py:246: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only\n", " default_term_info = 
default_term_info.sort_values(\n" ] } ], "source": [ "import pyLDAvis\n", "from pyLDAvis import gensim_models\n", "\n", "# Prepare the visualization data from the fitted LDA model, corpus, and dictionary.\n", "vis = gensim_models.prepare(merged_model, merged_corpus, dictionary=merged_dictionary)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# Save the visualization as a standalone HTML file.\n", "pyLDAvis.save_html(vis, 'lda.html')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/html": [ "
\n",
"| | Term | Freq | Total | Category | logprob | loglift | relevance |\n",
"|---|---|---|---|---|---|---|---|\n",
"| 377 | product | 23.913529 | 30.062396 | Topic1 | -3.2928 | 0.9540 | -3.2928 |\n",
"| 627 | real_estate | 10.393194 | 12.188516 | Topic1 | -4.1261 | 1.0235 | -4.1261 |\n",
"| 246 | company | 7.408749 | 9.892665 | Topic1 | -4.4646 | 0.8937 | -4.4646 |\n",
"| 309 | financial | 6.197204 | 8.785535 | Topic1 | -4.6432 | 0.8339 | -4.6432 |\n",
"| 800 | zhangjiang_tech | 5.380072 | 5.830456 | Topic1 | -4.7846 | 1.1025 | -4.7846 |\n",
"| 558 | expected_return | 5.316008 | 5.792948 | Topic1 | -4.7966 | 1.0970 | -4.7966 |\n",
"| 520 | bank | 4.582158 | 5.053505 | Topic1 | -4.9451 | 1.0850 | -4.9451 |\n",
"| 251 | future | 4.360551 | 4.810861 | Topic1 | -4.9947 | 1.0846 | -4.9947 |\n",
"| 102 | stock | 3.904119 | 11.538281 | Topic1 | -5.1053 | 0.0992 | -5.1053 |\n",
"| 540 | customer | 3.847856 | 4.313823 | Topic1 | -5.1198 | 1.0686 | -5.1198 |\n",
"| 265 | term | 3.805859 | 5.013711 | Topic1 | -5.1307 | 0.9072 | -5.1307 |\n",
"| 619 | project | 3.664846 | 5.127102 | Topic1 | -5.1685 | 0.8471 | -5.1685 |\n",
"| 245 | business | 3.434275 | 6.798940 | Topic1 | -5.2335 | 0.4999 | -5.2335 |\n",
"| 567 | fund | 3.249944 | 6.864512 | Topic1 | -5.2887 | 0.4352 | -5.2887 |\n",
"| 582 | invest | 3.138705 | 3.589726 | Topic1 | -5.3235 | 1.0486 | -5.3235 |\n",
"| 636 | return | 3.127770 | 3.583477 | Topic1 | -5.3270 | 1.0469 | -5.3270 |\n",
"| 660 | trust | 3.127770 | 3.583477 | Topic1 | -5.3270 | 1.0469 | -5.3270 |\n",
"| 562 | financing | 3.112968 | 3.573827 | Topic1 | -5.3317 | 1.0448 | -5.3317 |\n",
"| 638 | risk | 3.112968 | 3.573827 | Topic1 | -5.3317 | 1.0448 | -5.3317 |\n",
"| 666 | wealth_management | 3.112968 | 3.573827 | Topic1 | -5.3317 | 1.0448 | -5.3317 |\n",
"| 241 | annual | 2.878278 | 3.321743 | Topic1 | -5.4101 | 1.0396 | -5.4101 |\n",
"| 782 | board | 2.854602 | 4.410119 | Topic1 | -5.4184 | 0.7479 | -5.4184 |\n",
"| 24 | control | 2.583636 | 3.367196 | Topic1 | -5.5181 | 0.9180 | -5.5181 |\n",
"| 799 | zhangjiang | 2.400752 | 2.847076 | Topic1 | -5.5915 | 1.0124 | -5.5915 |\n",
"| 790 | park | 2.400752 | 2.847076 | Topic1 | -5.5915 | 1.0124 | -5.5915 |\n",
"| 655 | threshold | 2.396182 | 2.844386 | Topic1 | -5.5934 | 1.0114 | -5.5934 |\n",
"| 625 | rate | 2.388913 | 2.841308 | Topic1 | -5.5965 | 1.0095 | -5.5965 |\n",
"| 661 | type | 2.388913 | 2.841308 | Topic1 | -5.5965 | 1.0095 | -5.5965 |\n",
"| 565 | focus | 2.377280 | 2.833401 | Topic1 | -5.6013 | 1.0074 | -5.6013 |\n",
"| 665 | wealth | 2.377280 | 2.833401 | Topic1 | -5.6013 | 1.0074 | -5.6013 |\n",
"| 604 | ningbo_branch | 2.377280 | 2.833401 | Topic1 | -5.6013 | 1.0074 | -5.6013 |\n",
"| 603 | ningbo | 2.377280 | 2.833401 | Topic1 | -5.6013 | 1.0074 | -5.6013 |\n",
"| 523 | buy_product | 2.377280 | 2.833401 | Topic1 | -5.6013 | 1.0074 | -5.6013 |\n",
"| 518 | average_annualized | 2.377280 | 2.833401 | Topic1 | -5.6013 | 1.0074 | -5.6013 |\n",
"| 766 | size | 2.358410 | 2.833941 | Topic1 | -5.6093 | 0.9992 | -5.6093 |\n",
"| 721 | infrastructure | 2.358410 | 2.833941 | Topic1 | -5.6093 | 0.9992 | -5.6093 |\n",
"| 716 | huge | 2.358410 | 2.833941 | Topic1 | -5.6093 | 0.9992 | -5.6093 |\n",
"| 672 | anxin | 2.358410 | 2.833941 | Topic1 | -5.6093 | 0.9992 | -5.6093 |\n",
"| 717 | improve | 2.358410 | 2.833941 | Topic1 | -5.6093 | 0.9992 | -5.6093 |\n",
"| 794 | senior_executive | 2.299076 | 2.742155 | Topic1 | -5.6348 | 1.0066 | -5.6348 |\n", "