11 Feb 2024 - tsp
Last update 16 Feb 2024
45 mins
Another old notebook from my playground that got a little bit polished and transformed into a blog post. This time I’m going to take a look at different ways to generate sentence embeddings (with a quick glance at word embeddings). These are often used for search and data mining in unstructured data. The idea behind word and sentence embeddings is to identify a single word or a sequence of words (a sentence) with a vector in an abstract n-dimensional vector space. To identify similar segments of text one then either uses the distance in this vector space as a measure of similarity or compares vectors with a measure like cosine similarity (i.e. one looks at how one vector projects onto the other - the better this projection works, the more similar the sentences have to be). Different methods exist to generate those embeddings.
In the beginning of the history of word embeddings in the context of natural language processing, and also in its simplest method, bag of words, each word in a vocabulary got assigned a unique random ID (by the way, the same methods are also used in image processing, it’s not restricted to natural language). The simplest word embedding vector is then the sequence of the IDs corresponding to the words. This allows quick search for keywords but neither contextual identification of the word usage nor the discrimination of homonyms. It’s also not possible to detect words with similar meaning. Bag of words is still often used for image classification and it’s one of the easiest classifiers to build when getting started with machine learning.
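A minimal sketch of the ID assignment described above. The helper names `build_vocabulary` and `encode` are made up for illustration, and for reproducibility the IDs are assigned in order of first appearance instead of randomly. Note how both usages of "bank" end up with the same ID, which is exactly the homonym problem:

```python
def build_vocabulary(sentences):
    """Assign every distinct lowercase word an integer ID (order of first appearance)."""
    vocab = {}
    for sentence in sentences:
        for word in sentence.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(sentence, vocab):
    """Map a sentence to its sequence of word IDs; unknown words become -1."""
    return [vocab.get(word, -1) for word in sentence.lower().split()]


corpus = ["the bank of the river", "the bank lends money"]
vocab = build_vocabulary(corpus)
ids_river = encode(corpus[0], vocab)
ids_money = encode(corpus[1], vocab)

print(vocab)
print(ids_river)
print(ids_money)
```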
To counter those limitations the concept has been extended into continuous vector spaces. The first historical step has been continuous bag of words - in this case one trains the values to maximize predictability based on the words before and after a removed 1-gram. Such maps are usually trained using gradient descent based methods such as backpropagation or, for recurrent models, backpropagation through time (BPTT), a backpropagation algorithm that has been tailored for sequential data. They’re still not fully capable of handling homonyms correctly so the next step has been to develop some kind of contextual word embedding. A typical method has been TagLM (cascaded LSTM networks whose output is convoluted - one mapping the context and the other the word itself). The sequential processing unfortunately limits the efficiency, particularly for longer text segments.
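To make the "words before and after a removed 1-gram" idea a bit more concrete, here is a small sketch that only generates the training pairs (context words, removed word) a continuous bag of words model would be trained on. The function name `cbow_pairs` and the `window` parameter are assumptions for illustration; the actual training of the embedding values is not shown:

```python
def cbow_pairs(tokens, window=2):
    """Return (context, target) pairs: the target is the removed word,
    the context consists of up to `window` words on each side of it."""
    pairs = []
    for i, target in enumerate(tokens):
        left = tokens[max(0, i - window):i]
        right = tokens[i + 1:i + 1 + window]
        pairs.append((left + right, target))
    return pairs


tokens = "the quick brown fox jumps".split()
pairs = cbow_pairs(tokens, window=1)

for context, target in pairs:
    print(context, "->", target)
```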
Since training of such methods is cumbersome the world switched to the usage of Transformers. Those are based on the multi-head attention mechanism - usually using softmax based scaled dot-product attention heads - and started to emerge around 2017. Unlike LSTMs they do not contain recurrent structures, which allows one to speed up training and also processing. Since they process entire text segments at once they’re also able to catch more nuances of context. They extend the typical word embedding and also include dimensionality reduction - in addition the attention mechanism provides an amplification of important words and a suppression of less important words.
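The amplification and suppression can be sketched with a single attention head. This is a strongly simplified illustration and not the implementation of any particular model: it assumes the usual scaled dot-product form softmax(QK^T / sqrt(d)) V, leaves out the learned projection matrices and the multiple heads, and uses random vectors as stand-ins for token embeddings:

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v):
    """Single attention head: each output row is a weighted mix of the rows of v."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)      # pairwise similarity, scaled
    weights = softmax(scores, axis=-1)   # every row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 8))         # 4 stand-in tokens with 8 dimensions each
output, weights = scaled_dot_product_attention(tokens, tokens, tokens)

print(output.shape)
print(weights.round(3))
```

Each row of `weights` tells how strongly one token attends to every token of the segment; large entries amplify a token, small entries suppress it.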
In the following blog post I’m going to use some pre-trained transformer based systems to compare similarity of texts to allow lookup of information or similar text in a database - which can be used for example to aid building knowledge bases for generative pre-trained transformers used as chat bots, searching for similar papers in paper repositories, building local search engines to search for similar information instead of keywords, etc.
Since typical word embeddings are high dimensional objects (even though one performs dimensional reduction in the process) the main problem when building real life applications - assuming the transformer is already pre-trained and available (if it’s not the main problem is getting enough training data and computational resources) - is the indexing of huge datasets in high dimensional spaces. How to do this properly for large scale datasets will not be addressed in this blog article.
The word and sentence embeddings used here will be:
all-mpnet-base-v2
which is based on Microsoft’s mpnet-base (based on 1B training pairs)
Since this is just a quick glance the only similarity metric that I’m going to use is the cosine similarity:
\[ \begin{aligned} \cos(\alpha) = \frac{v_1 \cdot v_2}{\mid v_1 \mid \mid v_2 \mid} \end{aligned} \]
As mentioned earlier one can imagine this as a projection of one normalized vector onto the other normalized vector - or also as the cosine of the angle between both vectors.
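The cosine similarity can be written down in a few lines of numpy. This is just a sketch with made-up two-dimensional example vectors, the function name `cosine_similarity` is my own choice:

```python
import numpy as np


def cosine_similarity(v1, v2):
    """Cosine of the angle between v1 and v2: dot product divided by both norms."""
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))


a = [1.0, 0.0]
b = [1.0, 1.0]   # 45 degrees away from a
c = [0.0, 2.0]   # orthogonal to a

print(cosine_similarity(a, b))
print(cosine_similarity(a, c))
```

Since the norms are divided out the length of the vectors does not matter, only their direction does.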
OpenAI’s embedding API is a commercial service. It may look silly to use a commercial service if the quality of open embeddings is more than sufficient or even comparable, but under some circumstances, like building cloud applications, it actually makes sense: when you can use the API for example from Lambdas and your total amount of function calls is cheaper than running a machine 24/7 to host your embedding transformer for the lookups, it may be more attractive and economical to use a commercial service. It also scales better than a self-hosted solution.
As of the time of writing or refurbishing this notebook there were two main embedding models available:
text-embedding-3-small
working in up to 1536 dimensions, having a context window of 8191 tokens
text-embedding-3-large
working in up to 3072 dimensions
One can of course also downscale the dimensional space (by setting the dimensions parameter in the request) as a trade-off between vector size and accuracy.
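If the vectors have already been stored at full size, a similar trade-off can, as far as I know, also be made afterwards by cutting off the trailing components and rescaling the remainder to unit length. The following is only a sketch under that assumption: `shorten_embedding` is a hypothetical helper and the vector is random stand-in data, not real API output:

```python
import numpy as np


def shorten_embedding(embedding, dims):
    """Keep the first `dims` components and rescale the result to unit length."""
    v = np.asarray(embedding, dtype=float)[:dims]
    norm = np.linalg.norm(v)
    if norm == 0.0:
        return v
    return v / norm


# Stand-in for a stored 1536 dimensional unit length embedding (random data)
rng = np.random.default_rng(42)
full = rng.normal(size=1536)
full /= np.linalg.norm(full)

short = shorten_embedding(full, 256)
print(short.shape, np.linalg.norm(short))
```

The rescaling matters as soon as one compares vectors with a plain dot product instead of the cosine similarity.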
OpenAI’s embedding API is extremely simple to use through their OpenAI Python client:
pip install openai
Initialization works as usual - import the library and create the client. If it is not already set in the environment, set the API key in the OPENAI_API_KEY environment variable. In real applications the key should either be provided through the environment or read from a configuration file. Make sure to never push it into a source repository!
from openai import OpenAI
import urllib.request
import os
import json
import numpy as np
# This key should usually be loaded from the environment
APIKEY="sk-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
os.environ["OPENAI_API_KEY"] = APIKEY
client = OpenAI()
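For anything beyond a playground notebook one would drop the hardcoded key and fail early when the variable is missing. A small sketch of such a check follows; `load_api_key` is a hypothetical helper and not part of the OpenAI library, the `env` argument only exists so that the function can be tried out with a plain dictionary:

```python
import os


def load_api_key(env=os.environ, name="OPENAI_API_KEY"):
    """Return the API key from the given environment mapping or fail loudly."""
    key = env.get(name)
    if not key:
        raise RuntimeError(f"{name} is not set; export it before starting the notebook")
    return key


# Demonstration with a plain dictionary instead of the real environment
print(load_api_key({"OPENAI_API_KEY": "sk-example"}))
```

When the variable is set in the real environment the client picks it up on its own, so the check only serves to produce a clear error message.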
So before doing anything useful let’s just take a look at how one can create embeddings using this API. It’s extremely simple - just pass a list of the sentences you want to embed as the input parameter to the client.embeddings.create method and supply the desired model as model.
res = client.embeddings.create(input = ["Hello world! I want to generate embedding vectors"], model="text-embedding-3-small") # .data[0].embedding
Let’s create two more vectors - one for a similar text and one for an entirely different one - and take a look at the length of such a vector and its content.
res2 = client.embeddings.create(input = ["Oh I know embedding vectors!"], model="text-embedding-3-small") # .data[0].embedding
res3 = client.embeddings.create(input = ["A totally different topic"], model="text-embedding-3-small") # .data[0].embedding
print(len(res.data[0].embedding))
print(res.data[0].embedding)
1536
[0.004372989293187857, -0.024218644946813583, 0.022598065435886383, 0.0065852077677845955, -0.04244372248649597, -0.0455305390059948, 0.02316398173570633, 0.027369769290089607, -0.025762053206562996, 0.008906750939786434, -8.9881410531234e-05, -0.016643082723021507, -0.0037299026735126972, -0.0561286062002182, 0.014533759094774723, -0.010694531723856926, -0.03133118525147438, 0.038173627108335495, 0.008951766416430473, 0.016771700233221054, 0.0355498306453228, -0.018045011907815933, -0.013877810910344124, 0.0111639853566885, 0.032823145389556885, -0.007787779904901981, 0.005260448902845383, 0.058392271399497986, 0.05576847866177559, -0.02348552644252777, -0.024951763451099396, -0.04954339936375618, -0.02357555739581585, -0.01299035083502531, -0.039819926023483276, 0.012070736847817898, 0.0008561091381125152, 0.025028934702277184, -0.03588423877954483, 0.02940192259848118, 0.023189706727862358, -0.008546622470021248, 0.036475878208875656, -0.00035610925988294184, -0.009344049729406834, 0.045607708394527435, -0.030353691428899765, -0.0029549833852797747, 0.002276526764035225, 0.011099676601588726, -0.04336976632475853, -0.02604501135647297, -0.021363340318202972, 0.04740834981203079, -0.031073948368430138, 0.012212215922772884, -0.014160769060254097, -0.020578773692250252, -0.035652726888656616, 0.014778131619095802, 0.042752403765916824, 0.0069581978023052216, 0.015472665429115295, 0.009594853036105633, 0.02565915882587433, 0.03138263151049614, -0.027909962460398674, 0.018726684153079987, -0.01320900022983551, 0.029479093849658966, -0.008836011402308941, 0.02354983426630497, -0.014379418455064297, 0.006266879849135876, 0.0016768485074862838, -0.017144691199064255, -0.05360770598053932, 0.0025980703067034483, -0.02393568679690361, -0.06240513175725937, -0.006604500114917755, 0.01693890243768692, -0.026173628866672516, -0.0011390673462301493, -0.04529902711510658, -0.025311892852187157, 0.003700963919982314, 0.0001710208598524332, -0.0052893878892064095, 
-0.02514468878507614, -0.025800637900829315, -0.0037749188486486673, -0.03145980089902878, -0.01606430485844612, 0.002530546160414815, 0.05900963395833969, 0.0018038582056760788, -0.018803855404257774, 0.012720254249870777, -0.015588421374559402, -0.011877811513841152, 0.025839222595095634, 0.02927330508828163, 0.020475879311561584, 0.02185208536684513, -0.0321800597012043, -0.0019999996293336153, -0.044836003333330154, -0.017594851553440094, -0.0005622989265248179, -0.061736322939395905, -0.04583922028541565, 0.005437297746539116, 0.0028279738035053015, 0.036141470074653625, -0.009189709089696407, -0.0034501601476222277, -0.00473954901099205, -0.009755625389516354, -0.008546622470021248, -0.015549835748970509, 0.008077168837189674, 0.04349838197231293, -0.007562699262052774, -0.04805143550038338, -0.03498391434550285, -0.05777490884065628, -0.018803855404257774, -0.006913181859999895, -0.019961411133408546, 0.004877812694758177, -0.006147908512502909, 0.020733114331960678, 0.017221860587596893, -0.0637427493929863, -0.04612217843532562, 0.002006430411711335, 0.04915754497051239, -0.02729259990155697, 0.05000641942024231, -0.01692604087293148, -0.02102893404662609, -0.002139871008694172, -0.037736326456069946, -0.02263665199279785, -0.013942119665443897, -0.028347261250019073, 0.015627006068825722, -0.02402571775019169, -0.015627006068825722, -0.00465273205190897, 0.00927974097430706, -0.05257876589894295, 0.02183922380208969, -0.010887457989156246, -0.03382635861635208, -0.011131830513477325, 0.007639870047569275, 0.03526687249541283, 0.02845015563070774, 0.03922828659415245, 0.02641800045967102, -0.0010860126931220293, -0.003929259721189737, -0.007890673354268074, 0.02060449682176113, -0.0017411572625860572, -0.022006426006555557, 0.009209001436829567, -0.07007072865962982, 0.04223793372511864, -0.025710605084896088, -0.015524112619459629, -0.028347261250019073, 0.029864946380257607, -0.008315110579133034, -0.016192922368645668, -0.058443717658519745, 
0.0029855298344045877, -0.07269451767206192, -0.010681670159101486, 0.06477168947458267, -0.02144050970673561, 0.013311894610524178, -0.019215431064367294, 0.07825078815221786, 0.02191639505326748, 0.02185208536684513, 0.016694530844688416, 0.0637427493929863, -0.05335047096014023, 0.047639861702919006, 0.032025717198848724, -0.03441799804568291, 0.023884238675236702, 0.0031237935181707144, 0.05942121148109436, -0.007909966632723808, 0.030482308939099312, -0.03722185641527176, 0.009003213606774807, -0.01611575298011303, -0.04756269231438637, -0.026212213560938835, -0.004411574453115463, -0.03207716345787048, -0.007858519442379475, 0.01699035055935383, -0.04488744959235191, 0.014392280019819736, 0.021157551556825638, -0.012758839875459671, 0.011260448023676872, -0.014945334754884243, -0.0096977474167943, 0.028578773140907288, 0.0006724275299347937, 0.05160127580165863, 0.02101607248187065, -0.030019287019968033, -0.06441156566143036, 0.028810283169150352, 0.05149838328361511, -0.011980704963207245, -0.006035368423908949, -0.020128613337874413, 0.016591636463999748, -0.06204500421881676, 0.04861735180020332, 0.018675237894058228, 0.022520896047353745, 0.006553053390234709, 0.03701607137918472, -0.054790984839200974, -0.031974270939826965, 0.0091061070561409, 0.018379418179392815, 0.005813503637909889, -0.006353696342557669, 0.00476527214050293, -0.020385848358273506, -0.021337617188692093, 0.038225073367357254, 0.03236012160778046, 0.01857234351336956, 0.006327973213046789, -0.025749191641807556, -0.031562693417072296, 0.03812217712402344, 0.016437295824289322, -0.011241155676543713, 0.015627006068825722, 0.00950482115149498, 0.010430865921080112, 0.019099675118923187, 0.03688745200634003, -0.05972989276051521, 0.024090027436614037, 0.026122180745005608, 0.020540188997983932, 0.029659157618880272, -0.011652730405330658, -0.008302249014377594, 0.024154335260391235, 0.04133761301636696, -0.053916387259960175, -0.052295807749032974, -0.05870095267891884, 
-0.0018086812924593687, -0.04504179209470749, -0.002009645802900195, -0.04205786809325218, 0.03840513527393341, -0.039819926023483276, -0.05782635509967804, 0.012212215922772884, -0.009125400334596634, -0.030713820829987526, 0.007016075775027275, 0.024450154975056648, -0.029633434489369392, -0.0007094049942679703, 0.036115746945142746, 0.06554339826107025, -0.027086811140179634, -0.014829578809440136, -0.00527974171563983, -0.0011503213318064809, -0.051807064563035965, -0.03585851192474365, 0.05149838328361511, 0.0018102889880537987, 0.029170412570238113, -0.008578776381909847, 0.00515755545347929, 0.009395496919751167, 0.011099676601588726, 0.047768477350473404, 0.036990344524383545, -0.007980706170201302, 0.051832787692546844, -0.060038574039936066, 0.045993559062480927, 0.03511253371834755, -0.053144682198762894, 0.021144689992070198, 0.03210288658738136, 0.0019999996293336153, 0.06981348991394043, 6.003817179589532e-05, -0.03755626082420349, -5.441116445581429e-05, 0.012199354358017445, 0.01014790777117014, -0.025414785370230675, 0.03894532844424248, 0.019601281732320786, -0.02565915882587433, 0.017839225009083748, -0.044836003333330154, -0.022829577326774597, -0.01736333966255188, 0.01467523816972971, -0.060090020298957825, 0.004836011677980423, -0.01735047809779644, -0.005324757657945156, 0.015807071700692177, 0.04046301543712616, -0.02474597655236721, 0.0293761994689703, 0.027344046160578728, 0.050958190113306046, 0.007999998517334461, -0.008398711681365967, -0.04079741984605789, -0.018366556614637375, -0.000228094810154289, -0.011581990867853165, 0.020887454971671104, -0.05162699893116951, -0.006990352179855108, 0.038996778428554535, -0.0026961409021168947, -0.03145980089902878, 0.038688093423843384, 0.011292601935565472, 0.03580706566572189, 0.001935690874233842, -0.04092603549361229, 0.013929258100688457, 0.011691316030919552, 0.00527652632445097, -0.03148552402853966, 0.02523472160100937, -0.04138905927538872, 0.037762049585580826, -0.022945333272218704, 
-0.06641799211502075, -0.04180063307285309, -0.0177106074988842, 0.03920256346464157, 0.025491956621408463, -0.01470096129924059, 0.008540190756320953, -0.021170413121581078, 0.020951764658093452, 0.023318322375416756, 0.012533759698271751, 0.04679098725318909, -0.02932475320994854, 0.019729899242520332, -0.0025771697983145714, 0.0037781342398375273, 0.005411574617028236, -0.02394854836165905, 0.035318322479724884, 0.012874595820903778, 0.03140835464000702, 0.04169774055480957, 0.01401928998529911, -0.02060449682176113, -0.06317683309316635, -0.0528874509036541, -0.01385208684951067, 0.00443408265709877, 0.028244366869330406, -0.019691314548254013, 0.00014720656326971948, 0.03930545970797539, 0.01529260165989399, -0.017491957172751427, -0.03176848217844963, 0.030379414558410645, 0.056745968759059906, -0.010282956063747406, 0.007549837697297335, -0.01117684692144394, -0.011260448023676872, -0.014366556890308857, -0.018482310697436333, -0.05278455466032028, 0.0380450077354908, -0.0005337619222700596, -0.014546620659530163, -0.024514464661478996, 0.026572341099381447, 0.022881023585796356, 0.02554340288043022, 0.016270093619823456, -0.008405143395066261, 0.001304662087932229, 0.028810283169150352, -0.019536973908543587, -0.0028536971658468246, 0.025890670716762543, 3.770598050323315e-05, -0.03526687249541283, -0.005144693423062563, 0.03719613328576088, -0.04720256105065346, -0.02306108921766281, -0.034109316766262054, 0.047254011034965515, 0.01736333966255188, -0.051035359501838684, 0.04841156676411629, 0.0010739547433331609, -0.0040225074626505375, -0.009247586131095886, -0.01448231190443039, -0.005115754436701536, 0.018366556614637375, -0.013530544005334377, -0.04908037558197975, -0.04205786809325218, 0.052681662142276764, 0.01731189340353012, 0.03763343393802643, 0.020514465868473053, 0.029607711359858513, 0.01780064031481743, -0.053453367203474045, -0.017003212124109268, -0.007852088660001755, -0.04547908902168274, -0.021376201882958412, 0.032411567866802216, 
-0.020141474902629852, -0.01654019020497799, 0.01509967539459467, -0.05751767382025719, 0.02605787292122841, -0.013350479304790497, 0.010096460580825806, -0.048591628670692444, -3.2330179237760603e-05, -0.00505466153845191, 0.031974270939826965, 0.01978134736418724, -0.0016286170575767756, -0.027035364881157875, -0.03524114936590195, 0.046919602900743484, 0.02186494693160057, -0.012591637670993805, -0.020784562453627586, -0.010508036240935326, 0.04504179209470749, -0.022495172917842865, -0.011048229411244392, 0.010019290260970592, 0.01278456300497055, -0.01985851675271988, 0.0036752403248101473, 0.0061897095292806625, 0.016437295824289322, -0.0305594801902771, -0.036501601338386536, 0.011196139268577099, 0.02354983426630497, -0.006964629050344229, -0.02641800045967102, 0.014739546924829483, -0.00527009554207325, 0.1042829379439354, -0.014405141584575176, -0.0436527244746685, 0.027498386800289154, -0.03356912359595299, -0.04884886369109154, 2.0686791685875505e-05, 0.015176845714449883, -0.014366556890308857, 0.0348038524389267, -0.009864949621260166, -0.045659154653549194, -0.006363342981785536, 0.0064662364311516285, -0.025003209710121155, 0.004356912337243557, -0.003954983316361904, -0.012906749732792377, -0.01691317930817604, 0.03215433284640312, -0.006919612642377615, -0.02353697270154953, 0.0077620563097298145, -0.004138262942433357, 0.01903536543250084, -0.015434080734848976, -0.0007017683237791061, -0.025324754416942596, -0.003951767925173044, 0.027009641751646996, 0.011993566527962685, -0.014135044999420643, 0.04298391565680504, -0.021954979747533798, -0.01936977170407772, 0.02559485100209713, -0.014932473190128803, -0.018868163228034973, -0.03428938239812851, -0.009607715532183647, -0.007421220187097788, -0.04077169671654701, -0.02891317754983902, -0.01551125105470419, -0.00953054428100586, 0.011260448023676872, -0.027472663670778275, -0.00026266073109582067, 0.0017845655092969537, -0.04676526412367821, -0.010855303145945072, -0.009228293783962727, 
0.05371059849858284, 0.023189706727862358, 0.016424434259533882, -0.03333761543035507, 0.03264307975769043, 0.02760128118097782, 0.0050996774807572365, -0.012411572970449924, 0.023305460810661316, -0.01645015738904476, -0.007601284421980381, -0.015202568843960762, 0.05080384761095047, -0.001618970767594874, -0.010906750336289406, 0.0005180867156013846, -0.03094533085823059, 0.006009645294398069, 0.0019726683385670185, -0.003491960698738694, -0.010475882329046726, -0.019176844507455826, 0.027163982391357422, 0.019215431064367294, 0.00432154256850481, 0.05131831765174866, 0.0587523989379406, -0.04002571478486061, 0.06837297976016998, -0.002244372619315982, -0.027807068079710007, 0.012006428092718124, -0.06559484452009201, -0.03645015507936478, -0.049774911254644394, 0.04218648746609688, 0.006816718727350235, -0.010784563608467579, 0.01773633062839508, -0.005045014899224043, 0.0025675236247479916, -0.03678455948829651, 0.020463017746806145, -0.00548552954569459, 0.001745980349369347, 0.012591637670993805, -0.027987131848931313, -0.02140192501246929, 0.02109324373304844, 0.001431671786122024, 0.004999998956918716, 0.0034501601476222277, -0.011781347915530205, 0.02809002622961998, 0.00031953369034454226, -0.0024389063473790884, -0.007942120544612408, -0.009035367518663406, -0.02526044473052025, 0.017466234043240547, -0.020437294617295265, -0.0156527291983366, -0.022392278537154198, 0.0404115691781044, 0.020424433052539825, -0.01692604087293148, -0.024643082171678543, 0.02099034935235977, 0.010372987948358059, -0.010881027206778526, 0.04280385002493858, 0.0003092845145147294, -0.01656591333448887, -0.029993563890457153, -0.01738906465470791, 0.0006024918402545154, -0.005237941164523363, 0.010617361404001713, 0.0007303053280338645, -0.007742763496935368, 0.015575559809803963, 0.006913181859999895, 0.0016141475643962622, 0.00931189488619566, -0.043163977563381195, -0.029864946380257607, 0.008739547803997993, -0.015781346708536148, 0.012302248738706112, 
0.007903535850346088, -0.0289389006793499, -0.00485208909958601, 0.04257233813405037, -0.012623791582882404, 0.0013754016254097223, 0.014829578809440136, -0.016745977103710175, -0.0086366543546319, -0.011852087453007698, -0.01508681382983923, 0.0016575559275224805, 0.018726684153079987, -0.03627008944749832, 0.013311894610524178, -0.01654019020497799, 0.04036011919379234, -0.015395495109260082, -0.0016495173331350088, 0.014726685360074043, -0.01935691013932228, 0.030430862680077553, 0.008102891966700554, 0.0031141473446041346, -0.01736333966255188, -0.002299034968018532, 0.002766880439594388, -0.025376200675964355, 0.04499034583568573, -0.006836011540144682, 0.021363340318202972, -0.02514468878507614, -0.0006241960218176246, -0.008964627981185913, -0.005495175719261169, 0.034057870507240295, -0.01400642842054367, -0.03133118525147438, 0.016655944287776947, 0.019292600452899933, -0.0355498306453228, 0.02973632887005806, -0.02567202039062977, 0.0035401922650635242, 0.011196139268577099, 0.02932475320994854, -0.004729902371764183, -0.009594853036105633, 0.011401927098631859, -0.044836003333330154, -0.022070735692977905, 0.02852732501924038, 0.022456586360931396, -0.02390996366739273, -0.011054660193622112, 0.022610927000641823, -0.017234724014997482, -0.006257233675569296, 0.00475241057574749, 0.027961408719420433, 0.011466235853731632, 0.0028906746301800013, 0.004180063493549824, -0.014327971264719963, -0.042778126895427704, -0.011787778697907925, -0.012135045602917671, 0.01645015738904476, 0.011331187561154366, 0.038636647164821625, -0.015807071700692177, -0.004183278884738684, -0.006964629050344229, -0.03254018723964691, 0.001486334134824574, 0.0005192924872972071, -0.002429259940981865, 0.00463343970477581, -0.007382635027170181, -0.010443727485835552, 0.004874597303569317, 0.012610930018126965, 0.04794854298233986, 0.060501594096422195, -0.024180060252547264, -0.023009641095995903, -0.00442443648353219, -0.002845658687874675, -0.04015433415770531, 
0.003321542637422681, 0.0177749153226614, -0.007729901932179928, -0.014173630625009537, -0.021543404087424278, -0.0362958125770092, 0.0012668807758018374, -0.03663021698594093, -0.03462378680706024, -0.022070735692977905, -0.032437290996313095, -0.015421219170093536, -0.020810285583138466, 0.022610927000641823, 6.214829772943631e-05, -0.0007395496941171587, 0.01988423988223076, 0.021273307502269745, 0.011401927098631859, 0.02182636223733425, -0.0047041792422533035, -0.031099671497941017, -0.0037974268198013306, -0.021221861243247986, -0.029093241319060326, -0.05034082755446434, 0.03215433284640312, 0.004160770680755377, 0.047331180423498154, -0.01894533447921276, 0.0169003177434206, 0.03663021698594093, -0.011427650228142738, 0.024038581177592278, -0.009652731008827686, -0.03058520331978798, 0.00906109157949686, 0.011286171153187752, -0.012038582935929298, -0.016231507062911987, -0.020758837461471558, 0.06158198043704033, -0.03958841785788536, -0.01445658877491951, 0.01383922528475523, 0.027575558051466942, 0.033286165446043015, -0.0064662364311516285, 0.042160764336586, 0.01567845419049263, -0.017954980954527855, 0.010778132826089859, 0.016321539878845215, 0.018276523798704147, 0.033646296709775925, 0.008598068729043007, 0.002948552370071411, 0.01448231190443039, 0.002056269673630595, 0.050958190113306046, -0.0397942028939724, -0.009022505953907967, -0.007189709227532148, -0.058032143861055374, 0.03349195420742035, -0.006739548407495022, 9.892482194118202e-05, -0.010308679193258286, -0.05535690113902092, -0.020411571487784386, -0.0012025721371173859, -0.012122184038162231, 0.013260447420179844, 0.022032149136066437, -0.06101606413722038, -0.03212860971689224, -0.0036913175135850906, 0.014816717244684696, -0.012385849840939045, 0.0007375400164164603, 0.03989709913730621, -0.01571703888475895, -0.01900964230298996, -0.021208999678492546, 0.010790994390845299, -0.01445658877491951, 0.016784561797976494, 0.010938904248178005, 0.04496462270617485, 
0.0069581978023052216, -0.025684881955385208, 0.0004509645514190197, -0.0010024113580584526, -0.027935685589909554, 0.03207716345787048, 0.000782958057243377, 0.039022501558065414, -0.04630224034190178, 0.0029019287321716547, -0.026302246376872063, -0.046482305973768234, 0.003765272442251444, -0.0035112532787024975, 0.02309967391192913, 0.03837941214442253, 0.010572344996035099, -0.01117684692144394, -0.04015433415770531, 0.0033344044350087643, 0.022327970713377, -0.014520897530019283, 0.012617360800504684, -0.01979420892894268, 2.149065039702691e-05, 0.007549837697297335, 0.02721542865037918, 0.0008135046809911728, 0.005964628886431456, 0.007434082217514515, 0.012861733324825764, 0.0643601194024086, 0.017041796818375587, 0.004562700167298317, -0.040205780416727066, -0.0009228293783962727, -0.014366556890308857, 0.0189581960439682, 0.012752409093081951, 0.051395487040281296, 0.009762056171894073, -0.005254018120467663, -0.010289386846125126, 0.013954981230199337, 0.013453373685479164, 0.006823149975389242, -0.0330546572804451, -0.001380224828608334, 0.009369772858917713, 0.010218647308647633, -0.00433761952444911, 0.03264307975769043, 0.005401927977800369, -0.0038199350237846375, 0.00012178454198874533, 0.0007998390938155353, -0.02473311312496662, 0.005106108263134956, -0.0010892280843108892, 0.012591637670993805, 0.030790990218520164, -0.049337610602378845, -0.007569130510091782, -0.013466235250234604, 0.00515755545347929, 0.01015433855354786, 0.02721542865037918, -0.002141478704288602, -0.004987137392163277, 0.039434075355529785, 0.052244361490011215, 0.029967838898301125, -0.02852732501924038, 0.025337615981698036, -0.013762054964900017, -0.019909963011741638, 0.016823148354887962, -0.020167198032140732, 0.02935047633945942, 0.004189709667116404, -0.05602571368217468, -0.029607711359858513, 0.018482310697436333, 0.018276523798704147, -0.040643077343702316, -0.027421215549111366, 0.020836008712649345, 0.004147909115999937, -0.017839225009083748, 
-0.011048229411244392, 0.02018005959689617, 0.05062378570437431, -0.011511251330375671, 0.08216075599193573, 0.0023054657503962517, -0.013273308984935284, 0.011832795105874538, -0.05808359012007713, 0.03377491235733032, 0.014829578809440136, -0.0019437294686213136, 0.01603858172893524, 0.008340834639966488, 0.009434081614017487, 0.018418002873659134, -0.04429581016302109, -0.0006057072896510363, -0.04208359122276306, 0.022418001666665077, 0.004832796286791563, 0.057208992540836334, 0.024180060252547264, -0.0032491954043507576, 0.04455304518342018, 0.0057974266819655895, 0.0289389006793499, -0.028295814990997314, -0.020887454971671104, -0.03187137469649315, 0.010900319553911686, -0.012668807990849018, -0.004678455647081137, -0.04969773814082146, 0.011202570050954819, 0.02397427149116993, 0.02105465903878212, -0.01341478805989027, -0.01773633062839508, 0.0092604486271739, 0.008218647912144661, -0.001903536613099277, -0.07876525819301605, -0.025903532281517982, 0.014906749129295349, -0.018662376329302788, -0.01812218315899372, -0.01947266422212124, 0.0009155946900136769, 0.024565910920500755, -0.024938901886343956, 0.03603857755661011, 0.001935690874233842, 0.005736333318054676, -0.0036205779761075974, 0.01810932159423828, -0.04144050553441048, 0.03349195420742035, 0.021954979747533798, 0.011530544608831406, 0.04409002512693405, -0.01819935254752636, 0.0011020897654816508, 0.05705465003848076, -0.019138259813189507, -0.016784561797976494, -0.008784564211964607, 0.018353693187236786, 0.00967845506966114, 0.010096460580825806, 0.02189067006111145, -0.008173631504178047, 0.00464630126953125, -0.0020192922092974186, -5.9636240621330217e-05, 0.011536975391209126, 0.011839225888252258, 0.02315112017095089, 0.007594853639602661, 0.03753053769469261, 0.0029533756896853447, 0.03225722908973694, 0.01551125105470419, -0.033646296709775925, 0.04429581016302109, 0.03886815905570984, -0.023215429857373238, 0.01988423988223076, -0.027061088010668755, 0.0182122141122818, 
-0.020141474902629852, -0.008823148906230927, 0.002638263162225485, 0.021620575338602066, -0.021762054413557053, -0.00046382626169361174, -0.013260447420179844, 0.018418002873659134, 0.01549838948994875, -0.026675235480070114, 0.00926687940955162, 0.028372984379529953, -0.02598070167005062, -0.027935685589909554, -0.025749191641807556, -0.039819926023483276, -0.0027331183664500713, 0.009864949621260166, -0.002638263162225485, -0.017131829634308815, 0.018893886357545853, 0.020823147147893906, 0.031176842749118805, -0.011672023683786392, 0.02317684330046177, 0.004083600360900164, 0.010186493396759033, 0.013016074895858765, -0.025054657831788063, -0.01773633062839508, -0.003807073226198554, 0.012745978310704231, 0.012546621263027191, 0.005736333318054676, -0.016282955184578896, 0.035318322479724884, 0.020861731842160225, -0.007607715670019388, -0.03753053769469261, -0.017569128423929214, 0.017569128423929214, -0.006045015063136816, -0.0021302246022969484, 0.012308679521083832, -0.0521414689719677, -0.020527327433228493, 0.032874591648578644, -0.023768484592437744, 0.021954979747533798, -0.0177106074988842, -0.017080381512641907, -0.0389196053147316, 0.03418648988008499, 0.005299034528434277, -0.004935690201818943, 0.0014734723372384906, 0.03593568503856659, -0.016218645498156548, -0.03917684033513069, -0.019678452983498573, 0.024038581177592278, 0.006157555151730776, -0.002815112005919218, -0.014379418455064297, 0.01470096129924059, -0.001072347047738731, 0.02030867710709572, -0.01034083403646946, 0.02604501135647297, -0.005363342817872763, -0.0038971053436398506, 0.024231506511569023, 0.037813495844602585, 0.011408357881009579, -0.0016342439921572804, -0.017659161239862442, 0.026186490431427956, -0.06225079298019409, -0.003427651943638921, -0.05726043879985809, 0.03544693812727928, -0.015588421374559402, -0.027163982391357422, 0.015485526993870735, -0.004450160078704357, 0.027061088010668755, 0.004832796286791563, -0.018700961023569107, 0.0012130222748965025, 
-0.01297748927026987, 0.01605144329369068, -0.019331185147166252, 0.013260447420179844, 0.015575559809803963, -0.04043729230761528, -0.008578776381909847, 0.0014951765770092607, -0.01362057588994503, 0.0019421217730268836, 0.011665592901408672, -0.00926687940955162, 0.002297427272424102, 0.012225077487528324, -0.002874597441405058, 0.022430863231420517, 0.017247585579752922, 0.01609002985060215, 4.147406798438169e-05, -0.0050803846679627895, 0.018893886357545853, -0.020913179963827133, -0.018237939104437828, 0.014585206285119057, 0.020424433052539825, 0.025903532281517982, -0.0029839221388101578, -0.006784564349800348, -0.01385208684951067, 0.019125398248434067, 0.022855300456285477, 0.0019646298605948687, 0.010012859478592873, 0.010617361404001713, -0.03236012160778046, 0.0017282954649999738, -0.002996783936396241, 0.007864950224757195, -0.014109321869909763, 0.04331832006573677, -0.005209002178162336, -0.019228292629122734, -0.01074597891420126, 0.0020160768181085587, 0.004511252976953983, -0.035292595624923706, 0.006527329795062542, -0.027729898691177368, 0.022829577326774597, 0.00931832566857338, -0.003186494577676058, 0.009202570654451847, 0.0035819928161799908, -0.002877812832593918, 0.03341478481888771, 0.0293761994689703, 0.0026157551910728216, -0.019138259813189507, -0.010842441581189632, -0.029967838898301125, 0.024180060252547264, -0.021376201882958412, 0.006643085274845362, -0.0037427644710987806, -0.014096460305154324, -0.013273308984935284, -0.01466237660497427, 0.004045015200972557, 0.04547908902168274, 0.005324757657945156, -0.008199355565011501, -0.01976848393678665, 0.03444372117519379, 0.01529260165989399, 0.02308681234717369, 0.08360126614570618, -0.013157553970813751, -0.01530546322464943, -0.0032299028243869543, 0.007832796312868595, -0.009202570654451847, -0.011530544608831406, 0.029144689440727234, -0.028836006298661232, -0.0010024113580584526, -0.005221863742917776, -0.017826363444328308, -0.010636653751134872, 0.0035337612498551607, 
0.0034565909299999475, 0.0011503213318064809, 0.010823149234056473, -0.030019287019968033, 0.010520897805690765, -0.013221862725913525, 0.027987131848931313, -0.026675235480070114, -0.00905466079711914, 0.019254015758633614, -0.029453370720148087, 0.01055948343127966, -0.0397942028939724, 0.030405137687921524, 0.010411573573946953, 0.02721542865037918, -0.00041559478268027306, -0.01740192621946335, 0.01648874208331108, -0.014340832829475403, 0.020887454971671104, 0.0251318272203207, -0.03477812930941582, -0.006604500114917755, -0.020540188997983932, 0.026237936690449715, -0.009202570654451847, -0.000625401793513447, -0.009594853036105633, -0.016745977103710175, 0.002844050759449601, -0.020578773692250252, 0.004183278884738684, 0.028244366869330406, 0.025774914771318436, -0.03377491235733032, -0.00464308587834239, 0.07053374499082565, -0.024540187790989876, 0.004372989293187857, 0.010913181118667126, -0.012167200446128845, -0.00803858321160078, 0.023434078320860863, 0.028810283169150352, -0.01781350187957287, 0.01778777688741684, 0.009961413219571114, -0.04594211280345917, 0.03259163349866867, 0.02803857997059822, -0.013762054964900017, -0.01735047809779644, 0.019073951989412308, 0.021954979747533798, 0.04082314297556877, -0.002612539567053318, 0.017414787784218788, -0.01976848393678665, -0.010186493396759033, 0.004681671038269997, 0.036527324467897415, 0.0035627002362161875, -0.001747588044963777, -0.03807073086500168, 0.04051446169614792, -0.020874593406915665, -0.025093242526054382, -0.011678454466164112, 0.03632153570652008, 0.011742763221263885, 0.012675238773226738, 0.021581988781690598, 0.03377491235733032, 0.013993565924465656, 0.01810932159423828, -0.041517674922943115, -0.017659161239862442, 0.024051442742347717, 0.00473633361980319, -0.008475882932543755, 0.004488745238631964, -0.00412861630320549, -0.006369773764163256, 0.011639868840575218, 0.03434082865715027, 0.011736332438886166, 0.031125396490097046, 0.016836009919643402, 0.029479093849658966, 
-0.003135047620162368, -0.015125398524105549, -0.04090031236410141, -0.008630223572254181, -0.014610929414629936, -0.021080382168293, 0.034855298697948456, -0.013466235250234604, -0.0008159162243828177, 0.0033472662325948477, 0.010070737451314926, -0.016810286790132523, -0.02439870871603489, 0.025016071274876595, 0.006887458264827728, -0.02554340288043022, 0.03683600574731827, 0.026520894840359688, -0.00453376118093729, 0.005900320131331682, 0.07156268507242203, -0.025929255411028862, 0.034083593636751175, 0.00473311822861433, 0.02929903008043766, -0.0007459805347025394, 0.012668807990849018, 0.009781348519027233, 0.0059035359881818295, 0.011665592901408672, 0.037762049585580826, 0.0013416395522654057, -0.01696462742984295, -0.006093246396631002, -0.0177106074988842, -0.00444694422185421, -0.013967842794954777, -0.029067518189549446, 0.01768488436937332, -0.0011414788896217942, -0.0011197746498510242, -0.01737620308995247, 0.02811574935913086, 0.00548552954569459, -0.016823148354887962, -0.018070735037326813, -0.0028874592389911413, -0.027524109929800034, 0.02307395078241825, -0.008424435742199421, 0.025427646934986115, -0.019524112343788147, -0.003681671340018511, 0.005540191661566496, 0.011620576493442059, -0.04133761301636696, 0.02018005959689617, 0.00040353689109906554, 0.006514468230307102, -0.011125399731099606, 0.023241152986884117, -0.02803857997059822, -0.01612861454486847, -0.021723467856645584, -0.02357555739581585, 0.010578775778412819, -0.011581990867853165, -0.004199355840682983, -0.02978777512907982, 0.015125398524105549, -0.007054660934954882, -0.017839225009083748, 0.02811574935913086, 0.011286171153187752, 0.0092926025390625, -0.0064823138527572155, -0.025697743520140648, -0.007408358622342348, -0.02978777512907982, -0.009446943178772926, 0.0022395492997020483, 0.01987137831747532, -0.015228292904794216, -0.00527009554207325, -0.011446942575275898, -0.020115751773118973, -0.010398712009191513, -0.050212208181619644, -0.01571703888475895, 
-0.010025721043348312, 0.00884887296706438, -0.020450156182050705, -0.011491958983242512, -0.027755621820688248, -0.0059839216992259026, 0.02517041377723217, -0.0018842440331354737, -0.009665592573583126, 0.02028295397758484, 0.00029662373708561063, 0.011080383323132992, 0.01854662038385868, 0.00973633211106062, 0.022186491638422012, 0.025530541315674782, -0.02803857997059822, 0.06075882911682129, -0.012186492793262005, 0.026520894840359688, -0.015138261020183563, 0.029170412570238113, -0.0030112534295767546, 0.009112538769841194, -0.015549835748970509, 0.01862378977239132, -0.017093244940042496, 0.011723469942808151, 0.026520894840359688, -0.010405142791569233, -0.019935688003897667, -0.04661092162132263, 0.03439227491617203, 0.018700961023569107, -0.025800637900829315, 0.013363341800868511, 0.029993563890457153, -0.02565915882587433, 0.005736333318054676, -0.010662376880645752, 0.028270089998841286, -0.017479095607995987, 0.01854662038385868, -0.015871379524469376, 0.03701607137918472, 0.01651446521282196, 0.003495176089927554, -0.027935685589909554, -0.011016074568033218, 0.02801285684108734, 0.029659157618880272, -0.01605144329369068, 0.052604492753744125, 0.0020964625291526318, 0.019935688003897667, -0.016347263008356094, -0.006636654492467642, -0.007877811789512634, 0.04450159892439842, -0.03441799804568291, -0.0008705786312930286, -0.029144689440727234, 0.023446939885616302, -0.0006237940979190171, 0.010385850444436073, -0.019549835473299026, 0.005411574617028236, -0.014495174400508404, 0.013736331835389137, 0.04336976632475853, -0.03709324076771736, -0.0009445335599593818, -0.023189706727862358, 0.022881023585796356, 0.024926040321588516, -0.01445658877491951, 0.024591634050011635, 0.04550481215119362, -0.020424433052539825, 0.016360124573111534, -0.03634725883603096, -0.020051442086696625, 0.05360770598053932, 0.019138259813189507, -0.022906748577952385, 0.01976848393678665, -0.019228292629122734, -0.017993565648794174, -0.003030546009540558, 
-0.008784564211964607, 0.017581989988684654, 0.02636655420064926, -0.02113182842731476, -0.021144689992070198, 0.007549837697297335, 0.004620577674359083, -0.00865594670176506, 0.04385851323604584, -0.012302248738706112, 0.011633438058197498, 0.017594851553440094, -0.03179420530796051, 0.013234724290668964, -0.006700963247567415, 0.03176848217844963, -0.026340831071138382, 0.018057873472571373, -0.01425080094486475, 0.011318325996398926, 0.03310610353946686, -0.025749191641807556, 0.0006149516557343304, 0.0436527244746685, -0.017131829634308815, 0.011414788663387299, -0.042752403765916824, 0.02317684330046177, -0.014842440374195576, -0.024192921817302704, 0.030662372708320618, -0.011517682112753391, 0.002030546311289072, 0.00035269284853711724, -0.01509967539459467, -0.003572346642613411, 0.011614145711064339, 0.04511896148324013, -0.007041799370199442, 0.004450160078704357, 0.027421215549111366, 0.043549831956624985, 0.04535047337412834, -0.06430866569280624, -0.0007729097851552069, -0.01893247291445732, 0.026135042309761047, 0.006913181859999895, 0.027344046160578728, 0.01034083403646946, 0.017890671268105507, -0.0099163968116045, 0.008829580619931221, -0.03513825684785843, 0.012913180515170097, 0.010836010798811913, 0.027163982391357422, -0.04980063438415527, 0.006990352179855108, 0.04535047337412834, 0.0028344045858830214, -0.012327971868216991, -0.013659161515533924, 0.006842442322522402, -0.024591634050011635, -0.02726687490940094, 0.0025080381892621517, -0.0023553050123155117, 0.020025718957185745, 0.031125396490097046, 0.0387909896671772, -0.01012218464165926, -0.02847587876021862, 0.01906108856201172, -0.01486816443502903, -0.010083599016070366, 0.007903535850346088, -0.03593568503856659, -0.03593568503856659, -0.01692604087293148, -0.0069581978023052216, -0.025491956621408463, -0.00993568915873766, 0.02803857997059822, -0.04074597358703613, -0.01443086564540863, -0.003051446285098791, -0.037299029529094696, 0.017607713118195534, 0.031099671497941017, 
0.03590996190905571, 0.0008963020518422127, -0.005337619688361883, -0.05895818769931793, -0.026932470500469208, -0.03771060332655907, 0.01693890243768692, -0.016257232055068016, 0.0030096457339823246, -0.022610927000641823, -0.013543405570089817, 0.010276525281369686, 0.026109319180250168, 0.011202570050954819, 8.495777728967369e-05, -0.004350481554865837, 0.010225078091025352]
As one can see the small model has returned a 1536 dimensional vector in continuous vector space ($\mathbb{R}^{1536}$) representing the text.
To get a quick idea of how to use those vectors to compare texts we're going to take a quick look at their cosine similarity. In this case let's calculate a similarity matrix that compares each of the three texts with the other two (and with itself). We're just using plain numpy methods for this:
chatgptEmbedding3Small = [
np.asarray(res.data[0].embedding),
np.asarray(res2.data[0].embedding),
np.asarray(res3.data[0].embedding)
]
# Quick cosine similarity ...
chatgptEmbedding3Small_Similarities = np.empty((len(chatgptEmbedding3Small), len(chatgptEmbedding3Small)))
for i in range(len(chatgptEmbedding3Small)):
    for j in range(len(chatgptEmbedding3Small)):
        chatgptEmbedding3Small_Similarities[i,j] = np.dot(chatgptEmbedding3Small[i], chatgptEmbedding3Small[j]) / (np.linalg.norm(chatgptEmbedding3Small[i]) * np.linalg.norm(chatgptEmbedding3Small[j]))
print(chatgptEmbedding3Small_Similarities)
[[1. 0.63782866 0.14981901]
[0.63782866 1. 0.22271813]
[0.14981901 0.22271813 1. ]]
As one can see any sentence projected on itself is of course perfectly similar (1). The two sentences on embeddings share a high similarity (0.64) whereas the third one has a significantly lower similarity to both of them (0.15 and 0.22).
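The double loop is fine for three vectors. For more vectors the same matrix can be computed in one go by scaling every vector to unit length and taking a matrix product. The following is just a minimal sketch with made-up toy vectors instead of real embeddings; the function name `cosine_similarity_matrix` is chosen for illustration only:

```python
import numpy as np

def cosine_similarity_matrix(vectors):
    """Return the matrix of pairwise cosine similarities for a list of vectors."""
    m = np.asarray(vectors, dtype=float)               # shape (n, d)
    norms = np.linalg.norm(m, axis=1, keepdims=True)   # shape (n, 1)
    unit = m / norms                                   # every row scaled to length 1
    return unit @ unit.T                               # entry (i, j) is the cosine of the angle between i and j

# Toy vectors standing in for real embeddings
toy_vectors = [
    [1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0],
    [0.0, 0.0, 2.0],
]
toy_similarities = cosine_similarity_matrix(toy_vectors)
print(np.round(toy_similarities, 3))
```

The diagonal is again 1, the matrix is symmetric, and orthogonal vectors (the first and the third one) end up with a similarity of 0.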
Now let’s take a look at how one could use BERT for word embeddings. This is extremely simple using the HuggingFace transformers package. Note that this usually also requires either Tensorflow or torch to be installed.
When using FreeBSD keep in mind that you should install packages like numpy and torch from the packages or ports. Using pip won’t work for those packages since they unfortunately require patches for some Linux specific stuff in their source code:
# For FreeBSD install torch from packages or ports, not via PIP!
pkg install py-pytorch
The transformers package can simply be installed via pip as usual:
# Install PIP packages
pip install transformers
Now let’s generate a word embedding vector using BERT. To tokenize we’re also going to use the BertTokenizer. This separates a chain of words into a list of tokens that will then be transformed using BERT:
from transformers import BertModel, BertTokenizer
import torch
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
First we have to tokenize the sentence:
text = "This is an example sentence."
tokens = tokenizer.tokenize(text)
print(tokens)
['this', 'is', 'an', 'example', 'sentence', '.']
Then we can transform the tokens into IDs using a dictionary (the vocabulary of the model). Each ID uniquely identifies one token:
input_ids = tokenizer.convert_tokens_to_ids(tokens)
print(input_ids)
[2023, 2003, 2019, 2742, 6251, 1012]
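Conceptually this step is just a dictionary lookup. A minimal sketch, assuming a toy vocabulary that only contains the six tokens from above plus an entry for unknown tokens: the IDs of the six tokens are copied from the output above, while the helper name `toy_tokens_to_ids` and the fallback handling are made up for illustration and are not the implementation used by the tokenizer:

```python
# Toy vocabulary: a handful of entries, not the real BERT vocabulary file.
toy_vocab = {
    "[UNK]": 100,      # fallback entry for tokens that are not in the vocabulary
    "this": 2023,
    "is": 2003,
    "an": 2019,
    "example": 2742,
    "sentence": 6251,
    ".": 1012,
}

def toy_tokens_to_ids(tokens, vocab, unknown_token="[UNK]"):
    """Map each token to its ID, falling back to the ID of the unknown token."""
    return [vocab.get(token, vocab[unknown_token]) for token in tokens]

print(toy_tokens_to_ids(["this", "is", "an", "example", "sentence", "."], toy_vocab))
print(toy_tokens_to_ids(["this", "is", "zzz"], toy_vocab))
```

The IDs themselves carry no meaning, they are only an index into the vocabulary. All the semantics come from the vectors the network assigns to them.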
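Now that we have the IDs we can use the transformer to generate the word embedding vectors. In addition to simply identifying the word they also encode the context in which the word has been used. Note that this minimal example passes the raw token IDs without the [CLS] and [SEP] special tokens that BERT normally expects around a sentence; calling the tokenizer directly (tokenizer(text, return_tensors='pt')) would add them. Each token will get its own 768 dimensional vector assigned by BERT: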
input_ids = torch.tensor(input_ids).unsqueeze(0)
with torch.no_grad():
    outputs = model(input_ids)
embeddings = outputs.last_hidden_state[0]
print(embeddings)
tensor([[-1.8824e-01, -7.6458e-04, 1.0336e-01, ..., -2.0809e-01,
-3.9280e-01, 7.9072e-01],
[-4.1441e-01, -1.7607e-01, 5.4728e-02, ..., -1.0659e-01,
-3.8406e-01, 7.9451e-01],
[-5.5160e-01, 1.7656e-01, 2.4592e-01, ..., 1.5593e-01,
-5.1120e-01, 1.3524e+00],
[-2.6705e-01, 2.0308e-01, -3.6435e-02, ..., -9.6218e-02,
-5.8836e-01, 6.6819e-01],
[-1.7557e-01, 1.8462e-01, 3.5970e-02, ..., -1.4965e-01,
-3.1363e-01, 6.9866e-01],
[-1.6752e-01, -3.2122e-01, 7.4659e-02, ..., -2.1811e-01,
-3.7288e-01, 7.0560e-01]])
print(embeddings.size())
torch.Size([6, 768])
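One simple way to turn such a matrix of per-token vectors into a single vector for the whole sentence is to average over the tokens (mean pooling). This is only one common choice, not the only one. The sketch below uses random numbers in place of real BERT output so it runs without downloading a model; with the torch tensor from above the equivalent call would be `embeddings.mean(dim=0)`:

```python
import numpy as np

# Stand-in for the (6, 768) matrix of per-token vectors; random numbers, not BERT output.
rng = np.random.default_rng(0)
toy_token_vectors = rng.normal(size=(6, 768))

# Mean pooling: average over the token axis to get one vector for the sentence.
toy_sentence_vector = toy_token_vectors.mean(axis=0)

print(toy_sentence_vector.shape)
```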
OK, those have been word embeddings, which might be interesting for statistics or for searching synonyms. But for search purposes sentence embeddings are way more attractive and interesting. One can of course use statistics on word embeddings - but using HuggingFace SentenceTransformer with sentence embedding networks is even simpler - just like using OpenAI’s API:
from sentence_transformers import SentenceTransformer
First let’s try to use all-MiniLM-L12-v2. A bunch of interesting networks would be:
all-MiniLM-L12-v2
all-mpnet-base-v2
paraphrase-MiniLM-L3-v2
paraphrase-multilingual-MiniLM-L12-v2
Let’s initialize the model - on the first run this will download the pre-trained network files from a central repository (this is a step that I usually don’t like - it makes life easier but you’re dependent on the repository). I personally prefer to download them and keep them together with source code in my repositories - but keep in mind the files can be pretty large (on the order of 512 MB to more than a GByte).
model = SentenceTransformer('all-MiniLM-L12-v2')
The usage is extremely simple: Pass the encode function either a sentence or an array of sentences and the wrapper will take care of how to use the transformer to generate the sentence embeddings:
sentences = [
"Hello world! I want to generate embedding vectors",
"Oh I know embedding vectors!",
"A totally different topic"
]
embeddings = model.encode(sentences)
Taking a quick look reveals the vectors look unsurprisingly similar to the vectors returned by OpenAI’s embedding API before:
print(embeddings)
[[-0.06076973 -0.12821947 -0.00794378 ... 0.00089256 0.04703912
-0.06429676]
[ 0.00551554 -0.11793077 0.00410863 ... -0.03375708 0.09271023
-0.02079425]
[ 0.00426334 0.08150572 -0.01534932 ... -0.00426619 0.00174925
0.07479197]]
One has to keep in mind though that every model projects into an entirely different vector space. Even if the dimensions match one cannot use vectors from different models in the same database. If you swap models you have to recalculate the vectors for all of your data and rebuild all indices.
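One way to guard against accidentally mixing vector spaces is to store the model name next to every vector and refuse to compare vectors with different tags. The following is only a sketch under that assumption; the record layout with the keys `model` and `vector` and the function name are hypothetical, and the toy vectors are made up and far shorter than real embeddings:

```python
import math

def tagged_cosine_similarity(a, b):
    """Cosine similarity of two tagged vectors; refuses to mix embedding spaces.

    Each argument is a dict with the hypothetical keys 'model' and 'vector'.
    """
    if a["model"] != b["model"]:
        raise ValueError(f"Cannot compare vectors from {a['model']} and {b['model']}")
    if len(a["vector"]) != len(b["vector"]):
        raise ValueError("Vector dimensions differ")
    dot = sum(x * y for x, y in zip(a["vector"], b["vector"]))
    norm_a = math.sqrt(sum(x * x for x in a["vector"]))
    norm_b = math.sqrt(sum(y * y for y in b["vector"]))
    return dot / (norm_a * norm_b)

# Toy records; the vectors are made up.
record_a = {"model": "all-MiniLM-L12-v2", "vector": [0.1, 0.3, -0.2]}
record_b = {"model": "all-MiniLM-L12-v2", "vector": [0.2, 0.1, -0.1]}
record_c = {"model": "all-mpnet-base-v2", "vector": [0.2, 0.1, -0.1]}

print(tagged_cosine_similarity(record_a, record_b))
try:
    tagged_cosine_similarity(record_a, record_c)
except ValueError as err:
    print(f"Rejected: {err}")
```

Note that the third record has the same dimension as the other two, so only the tag catches the mismatch.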
Now, as for OpenAI’s API above, let’s run a quick similarity check using the cosine similarity:
# Quick cosine similarity ...
embeddings_MiniLM_L12_v2 = [
np.asarray(embeddings[0]),
np.asarray(embeddings[1]),
np.asarray(embeddings[2])
]
similarities_MiniLM_L12_v2 = np.empty((len(embeddings_MiniLM_L12_v2), len(embeddings_MiniLM_L12_v2)))
for i in range(len(embeddings_MiniLM_L12_v2)):
    for j in range(len(embeddings_MiniLM_L12_v2)):
        similarities_MiniLM_L12_v2[i,j] = np.dot(embeddings_MiniLM_L12_v2[i], embeddings_MiniLM_L12_v2[j]) / (np.linalg.norm(embeddings_MiniLM_L12_v2[i]) * np.linalg.norm(embeddings_MiniLM_L12_v2[j]))
print(similarities_MiniLM_L12_v2)
[[ 1.00000012 0.7471174 -0.01569488]
[ 0.7471174 1.00000012 0.04449366]
[-0.01569488 0.04449366 1. ]]
print(embeddings_MiniLM_L12_v2[0].shape)
(384,)
As we can see this also relates similar sentences in our sample extremely well - even with the small vector size.
Since the model used before has been a smaller one let’s take a look at how one could switch to a more complex and higher dimensional model. Nothing easier than that - just change the model parameter. In this case it will be all-mpnet-base-v2 which will return a 768 dimensional vector to describe the context of the sentence:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-mpnet-base-v2')
sentences2 = [
"Hello world! I want to generate embedding vectors",
"Oh I know embedding vectors!",
"A totally different topic"
]
embeddings2 = model.encode(sentences2)
print(embeddings2)
[[-0.0161423 -0.05080677 -0.01390223 ... 0.01687648 -0.05224716
-0.03931584]
[-0.02392078 -0.04040321 -0.02289272 ... 0.03996665 -0.01487556
-0.06380195]
[ 0.04085479 0.01936843 -0.01162221 ... -0.00914402 -0.00163121
-0.0032566 ]]
# Quick cosine similarity ...
embeddings_all_mpnet_base_v2 = [
np.asarray(embeddings2[0]),
np.asarray(embeddings2[1]),
np.asarray(embeddings2[2])
]
similarities_all_mpnet_base_v2 = np.empty((len(embeddings_all_mpnet_base_v2), len(embeddings_all_mpnet_base_v2)))
for i in range(len(embeddings_all_mpnet_base_v2)):
    for j in range(len(embeddings_all_mpnet_base_v2)):
        similarities_all_mpnet_base_v2[i,j] = np.dot(embeddings_all_mpnet_base_v2[i], embeddings_all_mpnet_base_v2[j]) / (np.linalg.norm(embeddings_all_mpnet_base_v2[i]) * np.linalg.norm(embeddings_all_mpnet_base_v2[j]))
print(similarities_all_mpnet_base_v2)
[[1. 0.73634386 0.02250435]
[0.73634386 1. 0.19576186]
[0.02250435 0.19576186 1. ]]
print(embeddings_all_mpnet_base_v2[0].shape)
(768,)
As one can see this embedding also identifies similar content pretty well - unsurprisingly again.
Now that we’ve seen how simply word and sentence embeddings can be generated, let’s build a very simple (and inefficient) “database” that allows nearest neighbor search. As the very first solution we’re going to implement a linear scan over an in-memory index that is loaded from a JSON file on startup and stored back there. This is not really usable in practice with larger datasets or when queries have to be efficient, but it is sufficient to show how stuff works. But first let’s define the wrapper we’re going to use. This wrapper will be used around all of the databases during the experiments, since in the end one will likely use some kind of native vector database for performance reasons anyway - and I don’t want to change code later on.
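The basic wrapper will allow us to:
Insert a vector together with its data and metadata using the insert method
Query the nearest neighbors of a vector using the query method
Write dirty data to disk using the flush method, for any database that does not persist changes on its own (like the linear one)
Close resources using the close method, which will be used later on
The linear database will support file backing to make experiments easier though.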
# This is only the interface. Different databases might support different ways of instantiation and
# configuration parameters (like switching word embeddings, etc.)
class VectorDatabase:
    def __init__(self, cfg = None):
        pass

    def insert(
        self,
        vector = None,
        data = None,
        metadata = None,
        update = True
    ):
        raise NotImplementedError()

    def query(
        self,
        vector,
        includeIdentical = False,
        neighborCount = 8,
        maxDistance = None
    ):
        raise NotImplementedError()

    def flush(self):
        # This will not raise a NotImplementedError since this method might not be needed by some implementations
        pass

    def close(self):
        raise NotImplementedError()
To be able to switch between different embeddings let’s also define a wrapper for our sentence embedding backend and provide implementations for the sentence transformer provided by all-mpnet-base-v2 as well as OpenAI’s text-embedding-3-small model. Each embedding will only provide:
A constructor with a cfg parameter that will become useful later on
An internal _get_embedding method that is wrapped by a public get_embedding method to support input parameter validation and other functionality later on
class Embedding:
    def __init__(self, max_input_size = None, embedding_size = None, embedding_name = None, cfg = None):
        if max_input_size is None:
            raise ValueError("Maximum input size for embedding has to be supplied!")
        if embedding_size is None:
            raise ValueError("A size of the embedding vector has to be supplied!")
        if embedding_name is None:
            raise ValueError("An unique embedding name has to be supplied!")
        self._embedding_size = embedding_size
        self._embedding_name = embedding_name

    def _get_embedding(self, data):
        # Has to be provided by the concrete embedding implementation
        raise NotImplementedError()

    def get_embedding(self, data):
        return self._get_embedding(data)

    def get_vector_size(self):
        return self._embedding_size

    def get_embedding_name(self):
        return self._embedding_name
from sentence_transformers import SentenceTransformer
import numpy as np

class Embedding_Sentence_all_mpnet_base_v2(Embedding):
    def __init__(self, cfg = None):
        self._model = SentenceTransformer('all-mpnet-base-v2')
        # Encode a dummy sentence once to determine the size of the embedding vector
        testEmbedding = self._model.encode([ "Test Embedding" ])
        super().__init__(
            max_input_size = self._model.max_seq_length,
            embedding_size = np.asarray(testEmbedding[0]).shape[0],
            embedding_name = "Embedding_Sentence_all_mpnet_base_v2"
        )

    def _get_embedding(self, data):
        # A single sentence yields a single vector, a list of sentences yields one vector per sentence
        if not isinstance(data, list):
            return self._model.encode([data])[0]
        else:
            return self._model.encode(data)
from openai import OpenAI
import numpy as np

class Embedding_OpenAI_3Small(Embedding):
    def __init__(self, cfg = None):
        self._client = OpenAI()
        super().__init__(
            max_input_size = 8191,
            embedding_size = 1536,
            embedding_name = "Embedding_OpenAI_3Small"
        )

    def _get_embedding(self, data):
        if isinstance(data, list):
            # One API request per list entry
            res = []
            for indata in data:
                res.append(self._get_embedding(indata))
            return res
        else:
            return np.asarray((self._client.embeddings.create(input = data, model="text-embedding-3-small")).data[0].embedding)
Let’s just do a quick test using the all-mpnet-base-v2 embedding wrapper:
tstEmbedder = Embedding_Sentence_all_mpnet_base_v2()
print(tstEmbedder.get_embedding([ "Test sentence!", "Second sentence" ]))
print(tstEmbedder.get_vector_size())
[[-0.0845526 -0.02084206 -0.01640675 ... -0.03267304 0.00293803
-0.0234915 ]
[-0.0228159 0.00919766 0.04134046 ... -0.03268338 -0.0037409
-0.00664491]]
768
Now let’s implement the linear database. This will just use an in-memory list containing vectors, metadata and data. On insert it will check for collisions, which makes the operation $O(n)$. On query it will iterate over every entry, calculate the cosine similarity between the entry and the query vector and keep track of the nearest entries (up to neighborCount of them). This is also an $O(n)$ approach and requires a significant amount of processing time - so it is not suited for most real world applications. The database will also allow storing its current state in a JSON file and loading it again from there.
import json_numpy
import os
import numpy as np
class VectorDatabaseLinearJSON(VectorDatabase):
    def __init__(
        self,
        filename = None,
        store_data = True,
        embedder = None
    ):
        self._filename = filename
        self._data = []
        self._store_data = store_data
        self._embedder = embedder
        if (embedder is not None) and not isinstance(embedder, Embedding):
            raise ValueError("Embedding provider has to be a subclass of Embedding!")
        if self._filename is not None:
            # Load from file ...
            try:
                with open(self._filename) as infile:
                    dta = json_numpy.load(infile)
                    if (dta['dbtype'] != "LinearJSON") or (dta['dbversion'] > 1):
                        raise ValueError(f"Unsupported database type {dta['dbtype']} or version {dta['dbversion']}")
                    self._data = dta['data']
                    self._store_data = dta['store_data']
            except FileNotFoundError:
                # No file yet: start with an empty database
                self._data = []
                self._store_data = store_data

    def insert(
        self,
        vector = None,
        data = None,
        metadata = None,
        update = True
    ):
        # In case no vector is supplied generate it (as long as data is specified)
        if (vector is None) and (data is None):
            raise ValueError("Neither data nor vector is supplied")
        if vector is None:
            if self._embedder is None:
                raise ValueError("No vector supplied and no embedder configured")
            vector = self._embedder.get_embedding(data)
        # Check if this entry is already in the list - first linear scan ...
        for i in range(len(self._data)):
            if np.array_equal(self._data[i]['vector'], vector):
                if update:
                    del self._data[i]
                    break
                else:
                    raise ValueError("Key already in database")
        self._data.append({
            'vector' : vector,
            'data' : data,
            'metadata' : metadata
        })

    def _get_similarity(
        self,
        v1,
        v2
    ):
        # Cosine similarity: 1 means same direction, 0 means orthogonal
        return np.dot(np.asarray(v1), np.asarray(v2)) / (np.linalg.norm(v1) * np.linalg.norm(v2))

    def query(
        self,
        vector = None,
        data = None,
        includeIdentical = False,
        neighborCount = 8,
        maxDistance = None
    ):
        # If no vector is supplied calculate from supplied data ...
        if (vector is None) and (data is None):
            raise ValueError("Neither vector nor data is supplied")
        if vector is None:
            if self._embedder is None:
                raise ValueError("No vector supplied and no embedder configured")
            vector = self._embedder.get_embedding(data)
        currentNearestNeighbors = []
        currentNearestDistances = np.empty((neighborCount,))
        for i in range(len(self._data)):
            # Compare against the stored vector, not against the whole entry
            if np.array_equal(self._data[i]['vector'], vector):
                if not includeIdentical:
                    continue
                else:
                    if len(currentNearestNeighbors) < neighborCount:
                        currentNearestNeighbors.append(self._data[i])
                        currentNearestDistances[len(currentNearestNeighbors)-1] = 1.0
            else:
                # Note: d is a cosine similarity, so larger values mean closer entries
                d = self._get_similarity(self._data[i]['vector'], vector)
                # As written this skips entries whose similarity exceeds maxDistance
                if (maxDistance is not None) and (d > maxDistance):
                    continue
                if len(currentNearestNeighbors) < neighborCount:
                    currentNearestNeighbors.append(self._data[i])
                    currentNearestDistances[len(currentNearestNeighbors)-1] = d
                else:
                    if d > np.min(currentNearestDistances):
                        currentNearestNeighbors[np.argmin(currentNearestDistances)] = self._data[i]
                        currentNearestDistances[np.argmin(currentNearestDistances)] = d
        # Nothing found: zip(*...) below cannot unpack an empty sequence
        if len(currentNearestNeighbors) == 0:
            return (), ()
        # Sort by similarity only (largest first); the key avoids comparing the dict entries on ties
        ranked = sorted(zip(currentNearestDistances, currentNearestNeighbors), key=lambda pair: pair[0], reverse=True)
        currentNearestDistances, currentNearestNeighbors = zip(*ranked)
        return currentNearestNeighbors, currentNearestDistances

    def flush(self):
        if self._filename is None:
            return True
        dta = {
            'dbtype' : "LinearJSON",
            'dbversion' : 1,
            'store_data' : self._store_data,
            'data' : self._data
        }
        with open(self._filename, 'w', encoding='utf-8') as outfile:
            json_numpy.dump(dta, outfile, ensure_ascii=False, indent=4)
        return True

    def close(self):
        self.flush()
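The query method above keeps track of the best entries by hand. As a side note, the same "scan everything and keep the best few" idea can be written more compactly with heapq.nlargest from the standard library. This is only a sketch and not the implementation used in this post: the function names are made up, the entries reuse the 'vector' and 'data' keys from above, and the toy vectors are two dimensional to keep it readable:

```python
import heapq
import math

def cosine_similarity(v1, v2):
    """Plain Python cosine similarity of two equally long sequences."""
    dot = sum(a * b for a, b in zip(v1, v2))
    return dot / (math.sqrt(sum(a * a for a in v1)) * math.sqrt(sum(b * b for b in v2)))

def top_k_by_similarity(entries, query_vector, k):
    """Linear scan returning the k (similarity, entry) pairs with the highest similarity."""
    scored = ((cosine_similarity(e["vector"], query_vector), e) for e in entries)
    # key= avoids comparing the dict entries themselves when two scores tie
    return heapq.nlargest(k, scored, key=lambda pair: pair[0])

# Toy entries with made-up vectors
toy_entries = [
    {"vector": [1.0, 0.0], "data": "east"},
    {"vector": [0.0, 1.0], "data": "north"},
    {"vector": [1.0, 1.0], "data": "north-east"},
    {"vector": [-1.0, 0.0], "data": "west"},
]
for similarity, entry in top_k_by_similarity(toy_entries, [1.0, 0.2], k=2):
    print(f"{similarity:.3f} {entry['data']}")
```

This is still a linear scan over all entries, so it does not change the $O(n)$ behaviour per query. It only replaces the bookkeeping.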
Now let’s create a simple instance and fill it with random knowledge in the categories apples, trees and witches. Those sentences have been generated using a GPT to save time for this sample. In the end the database is stored to disk:
knowledgeSentences = [
"Apples are a member of the rose family, Rosaceae, and are one of the most widely cultivated tree fruits.",
"The apple tree originated in Central Asia, where its wild ancestor, Malus sieversii, can still be found today.",
"There are thousands of different varieties of apples, each with its own unique flavor, texture, and color.",
"Apple trees typically bloom in the spring and produce fruit in the late summer or early fall, depending on the variety.",
"Apples are rich in important nutrients, such as fiber, vitamin C, and various antioxidants, which contribute to good health.",
"The proverb \"an apple a day keeps the doctor away\" originated in Wales in the 19th century and highlights the health benefits associated with eating apples.",
"Apple seeds contain a small amount of cyanide, but you would need to eat a large number of seeds for it to be harmful.",
"The world's largest producer of apples is China, followed by the United States, Poland, and India.",
"The science of apple breeding, known as pomology, focuses on developing new apple varieties with desirable traits, such as improved flavor, disease resistance, and storage qualities.",
"Trees are woody plants that typically have a single main stem or trunk and branches extending from it.",
"Trees play a crucial role in the environment by absorbing carbon dioxide and releasing oxygen through the process of photosynthesis.",
"The tallest tree species in the world is the coast redwood (Sequoia sempervirens), which can grow to over 350 feet (107 meters) tall.",
"Trees provide habitat and food for a wide variety of wildlife, including birds, mammals, insects, and fungi.",
"The rings found in a tree trunk, known as growth rings or annual rings, can provide valuable information about the tree's age and environmental conditions during its lifetime.",
"Trees help regulate the climate by providing shade, reducing the heat island effect in urban areas, and influencing local weather patterns.",
"Many tree species have medicinal properties, and extracts from certain trees are used in traditional medicine around the world.",
"Deforestation, the clearing of forests for agriculture, urbanization, and other purposes, is a major threat to global biodiversity and contributes to climate change.",
"Arboriculture is the practice of cultivating and managing trees for their aesthetic, ecological, and economic benefits, including urban forestry, orchard management, and agroforestry.",
"Witches possess innate magical abilities, allowing them to manipulate the elements, brew potions, and cast spells.",
"Traditionally, witches are depicted as female, although male practitioners of magic may also be referred to as warlocks or wizards in some cultures.",
"Witches often live in secluded areas such as deep forests, misty swamps, or hidden caves, where they can practice their craft away from prying eyes.",
"Familiars, supernatural animal companions, are commonly associated with witches and assist them in their magical endeavors.",
"Witches may gather in covens, groups of individuals who come together to perform rituals, share knowledge, and support one another in their magical pursuits.",
"Witches often draw their magical power from sources such as nature, the moon, or ancient spirits, depending on their individual beliefs and traditions.",
"Spells cast by witches can range from simple charms for good luck or protection to more complex incantations for healing, transformation, or even curses.",
"In some fantasy worlds, witches may face persecution and discrimination from non-magical societies, leading them to conceal their abilities or seek refuge among their own kind.",
"Despite their mystical powers, witches are often depicted as complex characters with their own hopes, fears, and moral dilemmas, blurring the lines between light and dark magic."
]
testDb = VectorDatabaseLinearJSON("EmbeddingBasics01_db.json", embedder = Embedding_Sentence_all_mpnet_base_v2())
for sentence in knowledgeSentences:
    testDb.insert(data = sentence)
testDb.flush()
testDb = None
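The json_numpy package takes care of serializing the numpy vectors here. To illustrate what has to happen during such a round trip, here is a sketch that uses only the standard json module and converts the vectors to plain lists on write and back to arrays on read. The helper names, the toy entry and the file name are made up, and the resulting file layout is not the one json_numpy produces:

```python
import json
import os
import tempfile

import numpy as np

def save_entries(filename, entries):
    """Write entries to JSON, converting each numpy vector to a plain list."""
    serializable = [
        {"vector": np.asarray(e["vector"]).tolist(), "data": e["data"], "metadata": e["metadata"]}
        for e in entries
    ]
    with open(filename, "w", encoding="utf-8") as outfile:
        json.dump({"dbtype": "LinearJSON", "dbversion": 1, "data": serializable}, outfile)

def load_entries(filename):
    """Read entries back and restore the vectors as numpy arrays."""
    with open(filename, encoding="utf-8") as infile:
        dta = json.load(infile)
    return [
        {"vector": np.asarray(e["vector"]), "data": e["data"], "metadata": e["metadata"]}
        for e in dta["data"]
    ]

# Toy entry with a made-up three dimensional vector
toy_db_entries = [{"vector": np.array([0.25, -0.5, 1.0]), "data": "toy sentence", "metadata": None}]
with tempfile.TemporaryDirectory() as tmpdir:
    toy_path = os.path.join(tmpdir, "toy_db.json")
    save_entries(toy_path, toy_db_entries)
    restored_entries = load_entries(toy_path)

print(restored_entries[0]["vector"], restored_entries[0]["data"])
```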
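Now let’s load the database again, run a number of queries for neighbors and fetch the best matching, contextually close sentences from the database. Note that the value printed as Distance is the cosine similarity returned by the query, so larger values mean closer matches: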
testDb2 = VectorDatabaseLinearJSON("EmbeddingBasics01_db.json", embedder = Embedding_Sentence_all_mpnet_base_v2())
queries = [
"Apple cider, made from pressed apples, is a popular beverage in many countries, especially during the autumn season.",
"The oldest known living tree is a bristlecone pine (Pinus longaeva) named Methuselah, which is over 4,800 years old.",
"The stereotypical image of a witch includes a pointed hat, a broomstick for flying, and a cauldron for brewing potions, though these elements may vary depending on cultural and regional differences."
]
for quer in queries:
    print(f"Query: {quer}")
    res, res_distances = testDb2.query(data = quer, neighborCount = 4)
    for r, rd in zip(res, res_distances):
        print(f"Distance {rd}: {r['data']}")
    print(" ")
    print(" ")
    print(" ")
Query: Apple cider, made from pressed apples, is a popular beverage in many countries, especially during the autumn season.
Distance 0.5641804933547974: There are thousands of different varieties of apples, each with its own unique flavor, texture, and color.
Distance 0.5511329770088196: Apples are a member of the rose family, Rosaceae, and are one of the most widely cultivated tree fruits.
Distance 0.5276336073875427: Apples are rich in important nutrients, such as fiber, vitamin C, and various antioxidants, which contribute to good health.
Distance 0.5178161859512329: The world's largest producer of apples is China, followed by the United States, Poland, and India.
Query: The oldest known living tree is a bristlecone pine (Pinus longaeva) named Methuselah, which is over 4,800 years old.
Distance 0.5430047512054443: The tallest tree species in the world is the coast redwood (Sequoia sempervirens), which can grow to over 350 feet (107 meters) tall.
Distance 0.45321375131607056: Trees are woody plants that typically have a single main stem or trunk and branches extending from it.
Distance 0.4483247697353363: The apple tree originated in Central Asia, where its wild ancestor, Malus sieversii, can still be found today.
Distance 0.433364599943161: Trees play a crucial role in the environment by absorbing carbon dioxide and releasing oxygen through the process of photosynthesis.
Query: The stereotypical image of a witch includes a pointed hat, a broomstick for flying, and a cauldron for brewing potions, though these elements may vary depending on cultural and regional differences.
Distance 0.6778343915939331: Despite their mystical powers, witches are often depicted as complex characters with their own hopes, fears, and moral dilemmas, blurring the lines between light and dark magic.
Distance 0.6336530447006226: Witches possess innate magical abilities, allowing them to manipulate the elements, brew potions, and cast spells.
Distance 0.6256386637687683: Traditionally, witches are depicted as female, although male practitioners of magic may also be referred to as warlocks or wizards in some cultures.
Distance 0.5962900519371033: Spells cast by witches can range from simple charms for good luck or protection to more complex incantations for healing, transformation, or even curses.
Now trying the same with OpenAI embeddings:
knowledgeSentences = [
"Apples are a member of the rose family, Rosaceae, and are one of the most widely cultivated tree fruits.",
"The apple tree originated in Central Asia, where its wild ancestor, Malus sieversii, can still be found today.",
"There are thousands of different varieties of apples, each with its own unique flavor, texture, and color.",
"Apple trees typically bloom in the spring and produce fruit in the late summer or early fall, depending on the variety.",
"Apples are rich in important nutrients, such as fiber, vitamin C, and various antioxidants, which contribute to good health.",
"The proverb \"an apple a day keeps the doctor away\" originated in Wales in the 19th century and highlights the health benefits associated with eating apples.",
"Apple seeds contain a small amount of cyanide, but you would need to eat a large number of seeds for it to be harmful.",
"The world's largest producer of apples is China, followed by the United States, Poland, and India.",
"The science of apple breeding, known as pomology, focuses on developing new apple varieties with desirable traits, such as improved flavor, disease resistance, and storage qualities.",
"Trees are woody plants that typically have a single main stem or trunk and branches extending from it.",
"Trees play a crucial role in the environment by absorbing carbon dioxide and releasing oxygen through the process of photosynthesis.",
"The tallest tree species in the world is the coast redwood (Sequoia sempervirens), which can grow to over 350 feet (107 meters) tall.",
"Trees provide habitat and food for a wide variety of wildlife, including birds, mammals, insects, and fungi.",
"The rings found in a tree trunk, known as growth rings or annual rings, can provide valuable information about the tree's age and environmental conditions during its lifetime.",
"Trees help regulate the climate by providing shade, reducing the heat island effect in urban areas, and influencing local weather patterns.",
"Many tree species have medicinal properties, and extracts from certain trees are used in traditional medicine around the world.",
"Deforestation, the clearing of forests for agriculture, urbanization, and other purposes, is a major threat to global biodiversity and contributes to climate change.",
"Arboriculture is the practice of cultivating and managing trees for their aesthetic, ecological, and economic benefits, including urban forestry, orchard management, and agroforestry.",
"Witches possess innate magical abilities, allowing them to manipulate the elements, brew potions, and cast spells.",
"Traditionally, witches are depicted as female, although male practitioners of magic may also be referred to as warlocks or wizards in some cultures.",
"Witches often live in secluded areas such as deep forests, misty swamps, or hidden caves, where they can practice their craft away from prying eyes.",
"Familiars, supernatural animal companions, are commonly associated with witches and assist them in their magical endeavors.",
"Witches may gather in covens, groups of individuals who come together to perform rituals, share knowledge, and support one another in their magical pursuits.",
"Witches often draw their magical power from sources such as nature, the moon, or ancient spirits, depending on their individual beliefs and traditions.",
"Spells cast by witches can range from simple charms for good luck or protection to more complex incantations for healing, transformation, or even curses.",
"In some fantasy worlds, witches may face persecution and discrimination from non-magical societies, leading them to conceal their abilities or seek refuge among their own kind.",
"Despite their mystical powers, witches are often depicted as complex characters with their own hopes, fears, and moral dilemmas, blurring the lines between light and dark magic."
]
testDb = VectorDatabaseLinearJSON("EmbeddingBasics01_db_openai.json", embedder = Embedding_OpenAI_3Small())
for sentence in knowledgeSentences:
    testDb.insert(data = sentence)
testDb.flush()
testDb = None
testDb2 = VectorDatabaseLinearJSON("EmbeddingBasics01_db_openai.json", embedder = Embedding_OpenAI_3Small())
queries = [
"Apple cider, made from pressed apples, is a popular beverage in many countries, especially during the autumn season.",
"The oldest known living tree is a bristlecone pine (Pinus longaeva) named Methuselah, which is over 4,800 years old.",
"The stereotypical image of a witch includes a pointed hat, a broomstick for flying, and a cauldron for brewing potions, though these elements may vary depending on cultural and regional differences."
]
for quer in queries:
    print(f"Query: {quer}")
    res, res_distances = testDb2.query(data = quer, neighborCount = 4)
    for r, rd in zip(res, res_distances):
        print(f"Distance {rd}: {r['data']}")
    print(" ")
    print(" ")
    print(" ")
Query: Apple cider, made from pressed apples, is a popular beverage in many countries, especially during the autumn season.
Distance 0.4322422460396458: Apple trees typically bloom in the spring and produce fruit in the late summer or early fall, depending on the variety.
Distance 0.4185395618528946: The world's largest producer of apples is China, followed by the United States, Poland, and India.
Distance 0.40947432484412793: There are thousands of different varieties of apples, each with its own unique flavor, texture, and color.
Distance 0.4034569638808718: Apples are rich in important nutrients, such as fiber, vitamin C, and various antioxidants, which contribute to good health.
Query: The oldest known living tree is a bristlecone pine (Pinus longaeva) named Methuselah, which is over 4,800 years old.
Distance 0.47512168441444363: The tallest tree species in the world is the coast redwood (Sequoia sempervirens), which can grow to over 350 feet (107 meters) tall.
Distance 0.3627292936236696: The rings found in a tree trunk, known as growth rings or annual rings, can provide valuable information about the tree's age and environmental conditions during its lifetime.
Distance 0.3482689975267936: Many tree species have medicinal properties, and extracts from certain trees are used in traditional medicine around the world.
Distance 0.34680379241264475: The apple tree originated in Central Asia, where its wild ancestor, Malus sieversii, can still be found today.
Query: The stereotypical image of a witch includes a pointed hat, a broomstick for flying, and a cauldron for brewing potions, though these elements may vary depending on cultural and regional differences.
Distance 0.6027058763748782: Traditionally, witches are depicted as female, although male practitioners of magic may also be referred to as warlocks or wizards in some cultures.
Distance 0.5781743884723687: Despite their mystical powers, witches are often depicted as complex characters with their own hopes, fears, and moral dilemmas, blurring the lines between light and dark magic.
Distance 0.5555082055943925: Witches possess innate magical abilities, allowing them to manipulate the elements, brew potions, and cast spells.
Distance 0.5486788078306959: Witches often draw their magical power from sources such as nature, the moon, or ancient spirits, depending on their individual beliefs and traditions.
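A short note on reading those numbers: the values are printed from largest to smallest and the best matches come first, so they appear to behave like a similarity score (such as the cosine similarity mentioned in the introduction) rather than a distance in the strict sense. Assuming that is the case, the ranking can be sketched with a plain linear scan. The function names and the tiny three dimensional vectors below are made up for illustration and are not part of the database class used above; real embeddings have hundreds or thousands of dimensions.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors; 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def rank_by_similarity(query, stored, neighbor_count):
    """Linear scan: score every stored vector and return the best ones first."""
    scored = [(cosine_similarity(query, vec), label) for label, vec in stored]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:neighbor_count]

# Hypothetical toy vectors standing in for real sentence embeddings
stored = [
    ("apple", [0.9, 0.1, 0.0]),
    ("tree", [0.6, 0.8, 0.0]),
    ("witch", [0.0, 0.1, 0.9]),
]
query = [1.0, 0.2, 0.0]

for score, label in rank_by_similarity(query, stored, neighbor_count=2):
    print(f"Similarity {score:.3f}: {label}")
```

Such a scan touches every stored vector for every query, which is perfectly fine for a few dozen sentences but is exactly the part that stops scaling for large collections.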
As one can see, sentence embedding vectors work pretty well for contextual search and knowledge retrieval. This can of course be extended to arbitrary scale as long as one provides an efficient index for nearest neighbor search. Besides training the transformers (when one does not use pre-trained ones), this index is the largest obstacle to large scale usage. Since the curse of dimensionality leads to massive problems with high dimensional data, and indices for high dimensional data are not easy to build (especially dynamic ones), this is not an easy beginner's task though. Most approaches are based on multi-level quantization, locality sensitive hashing (LSH), dimensionality reduction that determines separating hyperplanes using singular value decomposition, or the assignment of vectors to local clusters determined by k-means or similar clustering mechanisms to speed up lookups and reduce the number of neighbor candidates. Usually these solutions also do not return the best matches but only approximate nearest neighbors.
There are a number of libraries and solutions that implement this pretty efficiently, but most either rely on keeping all data including the vectors in main memory or require at least SSD storage to achieve performance (like the famous DiskANN). The following is just a short list of references and by no means a complete one:
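To give a feeling for one of the index families mentioned above, here is a minimal sketch of locality sensitive hashing with random hyperplanes. It is an illustration under simplifying assumptions (a single hash table, pure Python, made-up class and function names), not how any particular library implements it. Each vector is reduced to a short bit signature that records on which side of a few random hyperplanes it lies; vectors pointing in similar directions tend to share a signature and therefore land in the same bucket.

```python
import random
from collections import defaultdict

def make_hyperplanes(dimensions, bit_count, seed=0):
    """Draw random hyperplane normals; the seed only makes runs repeatable."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(dimensions)]
            for _ in range(bit_count)]

def signature(vector, hyperplanes):
    """One bit per hyperplane: on which side of the plane does the vector lie?"""
    return tuple(
        sum(v * h for v, h in zip(vector, plane)) >= 0.0
        for plane in hyperplanes
    )

class HyperplaneLSHIndex:
    """Hypothetical single-table LSH index, for illustration only."""

    def __init__(self, dimensions, bit_count=8, seed=0):
        self.hyperplanes = make_hyperplanes(dimensions, bit_count, seed)
        self.buckets = defaultdict(list)

    def insert(self, label, vector):
        key = signature(vector, self.hyperplanes)
        self.buckets[key].append((label, vector))

    def candidates(self, vector):
        """Return only the entries that share the query's bucket."""
        return self.buckets.get(signature(vector, self.hyperplanes), [])

index = HyperplaneLSHIndex(dimensions=4)
index.insert("a", [1.0, 2.0, -0.5, 0.3])
index.insert("b", [-1.0, -2.0, 0.5, -0.3])

# A rescaled copy of "a" points in the same direction and hits its bucket
print([label for label, _ in index.candidates([2.0, 4.0, -1.0, 0.6])])
```

Only the candidates from the matching bucket would then be scored exactly. The price is that a true neighbor lying just across one of the hyperplanes is missed, which is why such indices usually deliver approximate results and why real implementations typically combine several tables or probe neighboring buckets.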
faiss by Facebook was one of the first libraries out there to perform efficient high dimensional data indexing
It's really fascinating how those indexes work, but in case one wants to build applications upon approximate nearest neighbor searches it's a very good idea to use an existing implementation. Since they're not as simple as many other data structures like AVL or RB trees or tries, implementing them, and especially implementing them properly, is a hard task.
And what can one do in real world using such databases?
This article is tagged: Programming, Machine learning, OpenAI, Tutorial, Basics
Dipl.-Ing. Thomas Spielauer, Wien (webcomplains389t48957@tspi.at)