From 78135453a2337d11a6af2089ff8b1c23d9234278 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 2 Aug 2020 08:41:38 +0200 Subject: [PATCH 01/62] initial commit --- .gitignore | 2 + README.md | 178 ++++++++++++++++++++++++++++++++++++++++++++++ orm/__init__.py | 0 tests/__init__.py | 0 4 files changed, 180 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 orm/__init__.py create mode 100644 tests/__init__.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d33695b --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +p38venv +.idea \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..80f90a8 --- /dev/null +++ b/README.md @@ -0,0 +1,178 @@ +# ORM + +

+ + Build Status + + + Coverage + + + Package version + +

+ +The `async-orm` package is an async ORM for Python, with support for Postgres, +MySQL, and SQLite. ORM is built with: + +* [SQLAlchemy core][sqlalchemy-core] for query building. +* [`databases`][databases] for cross-database async support. +* [`pydantic`][pydantic] for data validation. + +Because ORM is built on SQLAlchemy core, you can use Alembic to provide +database migrations. + +**ORM is still under development: We recommend pinning any dependencies with `orm~=0.1`** + +**Note**: Use `ipython` to try this from the console, since it supports `await`. + +```python +import databases +import orm +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Note(orm.Model): + __tablename__ = "notes" + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + text = orm.String(max_length=100) + completed = orm.Boolean(default=False) + +# Create the database +engine = sqlalchemy.create_engine(str(database.url)) +metadata.create_all(engine) + +# .create() +await Note.objects.create(text="Buy the groceries.", completed=False) +await Note.objects.create(text="Call Mum.", completed=True) +await Note.objects.create(text="Send invoices.", completed=True) + +# .all() +notes = await Note.objects.all() + +# .filter() +notes = await Note.objects.filter(completed=True).all() + +# exact, iexact, contains, icontains, lt, lte, gt, gte, in +notes = await Note.objects.filter(text__icontains="mum").all() + +# .get() +note = await Note.objects.get(id=1) + +# .update() +await note.update(completed=True) + +# .delete() +await note.delete() + +# 'pk' always refers to the primary key +note = await Note.objects.get(pk=2) +note.pk # 2 +``` + +ORM supports loading and filtering across foreign keys... + +```python +import databases +import orm +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(max_length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album) + title = orm.String(max_length=100) + position = orm.Integer() + + +# Create some records to work with. +malibu = await Album.objects.create(name="Malibu") +await Track.objects.create(album=malibu, title="The Bird", position=1) +await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) +await Track.objects.create(album=malibu, title="The Waters", position=3) + +fantasies = await Album.objects.create(name="Fantasies") +await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) +await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + + +# Fetch an instance, without loading a foreign key relationship on it. +track = await Track.objects.get(title="The Bird") + +# We have an album instance, but it only has the primary key populated +print(track.album) # Album(id=1) [sparse] +print(track.album.pk) # 1 +print(track.album.name) # Raises AttributeError + +# Load the relationship from the database +await track.album.load() +assert track.album.name == "Malibu" + +# This time, fetch an instance, loading the foreign key relationship. +track = await Track.objects.select_related("album").get(title="The Bird") +assert track.album.name == "Malibu" + +# Fetch instances, with a filter across an FK relationship. +tracks = Track.objects.filter(album__name="Fantasies") +assert len(tracks) == 2 + +# Fetch instances, with a filter and operator across an FK relationship. +tracks = Track.objects.filter(album__name__iexact="fantasies") +assert len(tracks) == 2 + +# Limit a query +tracks = await Track.objects.limit(1).all() +assert len(tracks) == 1 +``` + +## Data types + +The following keyword arguments are supported on all field types. + +* `primary_key` +* `allow_null` +* `default` +* `index` +* `unique` + +All fields are required unless one of the following is set: + +* `allow_null` - Creates a nullable column. Sets the default to `None`. +* `allow_blank` - Allow empty strings to validate. Sets the default to `""`. +* `default` - Set a default value for the field. + +* `orm.String(max_length)` +* `orm.Text()` +* `orm.Boolean()` +* `orm.Integer()` +* `orm.Float()` +* `orm.Date()` +* `orm.Time()` +* `orm.DateTime()` +* `orm.JSON()` + +[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ +[databases]: https://github.com/encode/databases +[pydantic]: https://pydantic-docs.helpmanual.io/ \ No newline at end of file diff --git a/orm/__init__.py b/orm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 From db2a0b3ddbec7fdecc5dd04a35071249cbbb2829 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 2 Aug 2020 08:44:20 +0200 Subject: [PATCH 02/62] license --- LICENSE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..79eb022 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Radosław Drążkiewicz + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file From c22de9684db2ae0a76fb56eb2f19ea80222882df Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 2 Aug 2020 08:56:53 +0200 Subject: [PATCH 03/62] add travis and codecov --- .codecov.yml | 11 +++++++++++ .travis.yml | 19 +++++++++++++++++++ LICENSE => LICENSE.md | 0 README.md | 6 +++++- requirements.txt | 7 +++++++ 5 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 .codecov.yml create mode 100644 .travis.yml rename LICENSE => LICENSE.md (100%) create mode 100644 requirements.txt diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000..033aafd --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,11 @@ +coverage: + precision: 2 + round: down + range: "80...100" + + status: + project: yes + patch: yes + changes: yes + +comment: off \ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..c313696 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,19 @@ +language: python + +dist: xenial + +cache: pip + +python: + - "3.6" + - "3.7" + - "3.8" + +install: + - pip install -U -r requirements.txt + +script: + - scripts/test + +after_script: + - codecov \ No newline at end of file diff --git a/LICENSE b/LICENSE.md similarity index 100% rename from LICENSE rename to LICENSE.md diff --git a/README.md b/README.md index 80f90a8..b3ebffa 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,9 @@ MySQL, and SQLite. ORM is built with: Because ORM is built on SQLAlchemy core, you can use Alembic to provide database migrations. +The goal was to create a simple orm that can be used directly with FastApi that bases it's data validation on pydantic. +Initial work was inspired by [`encode/orm`][encode/orm] + **ORM is still under development: We recommend pinning any dependencies with `orm~=0.1`** **Note**: Use `ipython` to try this from the console, since it supports `await`. @@ -175,4 +178,5 @@ All fields are required unless one of the following is set: [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases -[pydantic]: https://pydantic-docs.helpmanual.io/ \ No newline at end of file +[pydantic]: https://pydantic-docs.helpmanual.io/ +[encode/orm]: \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7a4f7a2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +databases[sqlite] +pydantic +sqlalchemy + +# Testing +pytest +pytest-cov \ No newline at end of file From d2444b4d0588ab884916fc72aca3bff750d90c2d Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 2 Aug 2020 09:00:13 +0200 Subject: [PATCH 04/62] test script --- scripts/tests | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 scripts/tests diff --git a/scripts/tests b/scripts/tests new file mode 100644 index 0000000..466a026 --- /dev/null +++ b/scripts/tests @@ -0,0 +1,12 @@ +#!/bin/sh -e + +PACKAGE="async-orm" + +PREFIX="" +if [ -d 'venv' ] ; then + PREFIX="venv/bin/" +fi + +set -x + +PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=${PACKAGE} --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@} \ No newline at end of file From c1b3b53875e9074debfda17bbe0db585d7976b76 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 2 Aug 2020 09:05:06 +0200 Subject: [PATCH 05/62] rename test, update readme --- README.md | 2 +- scripts/{tests => test} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename scripts/{tests => test} (73%) diff --git a/README.md b/README.md index b3ebffa..07b3ba6 100644 --- a/README.md +++ b/README.md @@ -179,4 +179,4 @@ All fields are required unless one of the following is set: [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases [pydantic]: https://pydantic-docs.helpmanual.io/ -[encode/orm]: \ No newline at end of file +[encode/orm]: https://github.com/encode/orm/ \ No newline at end of file diff --git a/scripts/tests b/scripts/test similarity index 73% rename from scripts/tests rename to scripts/test index 466a026..577f0e9 100644 --- a/scripts/tests +++ b/scripts/test @@ -9,4 +9,4 @@ fi set -x -PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=${PACKAGE} --cov=tests --cov-fail-under=100 --cov-report=term-missing ${@} \ No newline at end of file +PYTHONPATH=. ${PREFIX}pytest --ignore venv --cov=${PACKAGE} --cov=tests --cov-fail-under=100 --cov-report=term-missing "${@}" From 4f9dddfa0a3beaebdd79ac03410c4c0c87411b41 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 2 Aug 2020 09:08:56 +0200 Subject: [PATCH 06/62] make test.sh executable --- .travis.yml | 2 +- scripts/{test => test.sh} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename scripts/{test => test.sh} (100%) mode change 100644 => 100755 diff --git a/.travis.yml b/.travis.yml index c313696..416d999 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ install: - pip install -U -r requirements.txt script: - - scripts/test + - scripts/test.sh after_script: - codecov \ No newline at end of file diff --git a/scripts/test b/scripts/test.sh old mode 100644 new mode 100755 similarity index 100% rename from scripts/test rename to scripts/test.sh From 96ec33fe16590dff3a51d3bf6eea237b182c33b8 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 08:17:30 +0200 Subject: [PATCH 07/62] attribute access and setting for pydantic_model uned the hood --- .coverage | Bin 0 -> 53248 bytes README.md | 15 ++-- orm/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 128 bytes orm/__pycache__/exceptions.cpython-38.pyc | Bin 0 -> 438 bytes orm/__pycache__/fields.cpython-38.pyc | Bin 0 -> 2613 bytes orm/__pycache__/models.cpython-38.pyc | Bin 0 -> 2536 bytes orm/exceptions.py | 6 ++ orm/fields.py | 73 ++++++++++++++++++ orm/models.py | 69 +++++++++++++++++ tests/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 130 bytes .../test_columns.cpython-38-pytest-6.0.1.pyc | Bin 0 -> 3459 bytes tests/__pycache__/test_columns.cpython-38.pyc | Bin 0 -> 670 bytes tests/test_columns.py | 52 +++++++++++++ 13 files changed, 209 insertions(+), 6 deletions(-) create mode 100644 .coverage create mode 100644 orm/__pycache__/__init__.cpython-38.pyc create mode 100644 orm/__pycache__/exceptions.cpython-38.pyc create mode 100644 orm/__pycache__/fields.cpython-38.pyc create mode 100644 orm/__pycache__/models.cpython-38.pyc create mode 100644 orm/exceptions.py create mode 100644 orm/fields.py create mode 100644 orm/models.py create mode 100644 tests/__pycache__/__init__.cpython-38.pyc create mode 100644 tests/__pycache__/test_columns.cpython-38-pytest-6.0.1.pyc create mode 100644 tests/__pycache__/test_columns.cpython-38.pyc create mode 100644 tests/test_columns.py diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..922eaa1c16bc5e9ad0f3557871e19de633a21510 GIT binary patch literal 53248 zcmeI)+l$=R9S3k}q|tcxc2dkTbryHeOFdqDXT9T1i|b%or}cx|K#A8Rp`KRJXwK}) z8A)SFv)=8Mbqf9i3VrQML;r+Aq4YhpFCm4}hXhhc14$`?5(o})e?ML3V(%;=FUfAc zVppTfIY;Mw&PAhSN2( z7OUFYQ*vHBugzJ%)x=NcUlfhGSLZLx{&Q}p`kUE5RBl&$v;z|aAOHafKp+*^e{QZ) zJ9SFG_lv}7_f!%&u8Okx{C96$yL$7Qym|GRXRpa@n>;%sX=*g&RT+hMWnV?I>-CiM zf{y1pi5G0iWLp)X2eIn#8IQKnIgbObrSW;m_Bs?RQCpNmKk|Gh+LJG2G$TG_+w< zlfnLgD=EoY-8AU!IW*$6`+@;I~VZl0z$cLuLX`X|D-zM_g1HaFe zLou=gC*ic1)8Zy!Uf54vHwV<^sJs)^S&(y$X0Zl=cW0o6mA@)SF;~-K1mUXWLQo6% z`m9y^)=6EXR;hO5J3V^iWjli;OxJCye7i}rxnFRPl zH;0Qe`=_i*ZD~orS52EyUR(J~WugfcHF>}=TI}+b6Az+dAUSN1EMkmC$)bzMu679+&vPJW&YVGU@`4oCrRB_WMqEgb z%U12_#c3JmEZ9w+Td&+GrCGjTN%egi^!eYI$TDvkW%&ncud(CwXf)|NUPN7%#^CG> zAUR;ip}=pSUjDVc1hn zkew1J;Psde<;5~R^y4R^Z;#T5_c)EMoqC#;;+(>x`%UU=MYo~_-*#67Y{W0 z!vp~cKmY;|fB*y_009U<00Izz!0{6>^palT>wi;xRTG=y8F7XdFhKwU5P$##AOHaf zKmY;|fB*!(L;`cBaat5lMyx++7>gVHWWdGsi|dW`OO5pnxqj(#^P871ZitexRLk$H ze`D?9T66!LZl1n(y?7+zR{i^zTdjT+-d1iBw^~lT7q|_2-kvsYT-s5=POH`P+N~Bn zvZoq-$9#0TWxdg$Cj+if^wA23i-)U?uH*F@1A4|R zy3$-B>nW&4+g@@wr;?-kS#V8B}AOHafKmY;|fB*y_009U< zK-22Rrhf31^~CkRb>6sjNVtjX|MKB!4zK^s72}0N(j2=+H!c|Nl#mBr|Ci1gx2gx1 z>apv8{Qe)h0R$ib0SG_<0uX=z1Rwwb2teRi3mEj6z&zjo*Tu&g{b7Ot1Rwwb2tWV= z5P$##AOHafK;Re(7{+`x{rlWuK>PAE0`bv0SG_<0uX=z1Rwwb2tWV= z5cq-!ELf&q{hIZoU+( literal 0 HcmV?d00001 diff --git a/README.md b/README.md index 07b3ba6..d0c2e2a 100644 --- a/README.md +++ b/README.md @@ -15,17 +15,18 @@ The `async-orm` package is an async ORM for Python, with support for Postgres, MySQL, and SQLite. ORM is built with: -* [SQLAlchemy core][sqlalchemy-core] for query building. +* [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. -Because ORM is built on SQLAlchemy core, you can use Alembic to provide +Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide database migrations. -The goal was to create a simple orm that can be used directly with FastApi that bases it's data validation on pydantic. -Initial work was inspired by [`encode/orm`][encode/orm] +The goal was to create a simple orm that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +Initial work was inspired by [`encode/orm`][encode/orm]. +The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. -**ORM is still under development: We recommend pinning any dependencies with `orm~=0.1`** +**aysn-orm is still under development: We recommend pinning any dependencies with `async-orm~=0.1`** **Note**: Use `ipython` to try this from the console, since it supports `await`. @@ -179,4 +180,6 @@ All fields are required unless one of the following is set: [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases [pydantic]: https://pydantic-docs.helpmanual.io/ -[encode/orm]: https://github.com/encode/orm/ \ No newline at end of file +[encode/orm]: https://github.com/encode/orm/ +[alembic]: https://alembic.sqlalchemy.org/en/latest/ +[fastapi]: https://fastapi.tiangolo.com/ \ No newline at end of file diff --git a/orm/__pycache__/__init__.cpython-38.pyc b/orm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8d4b7aaa909737d54fac940d004045fd209070 GIT binary patch literal 128 zcmWIL<>g`kf=RJz@gVv!h(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2DN%POXzC_gJT uxuiHIvA8lXSvS8ZHwJ{`<1_OzOXB183My}L*yQG?l;)(`fwX@HVg>*mlN&n# literal 0 HcmV?d00001 diff --git a/orm/__pycache__/exceptions.cpython-38.pyc b/orm/__pycache__/exceptions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2515eefefe82c35a623377c3459a4582803970f7 GIT binary patch literal 438 zcmb79!Ab)$6ihZ-yOtId`~;!&pdTP26)b|{dXQeqLRdBtRyLb8X~CZLXNrH6tAF9i zdAqbKc=5qZ-XxF7y!mJ}V5r;ou{1b8Nc^lAVhZ?U&^1#`@i#!TSFzyGkrL2CbLdFL z4{XvaF!`%E)D-YTkY)?67-l(Uqi}fX*L894%W zuJmJBmy~wqoOP)qyYD<**QdJ#)C0C3l1?Wg?C+7{B~Q_bIj7y{H0m(5>1T)bB<$Ugb5|TelW8O!Zo##LZ0Pd?%$#%1{^5H$duVU?X+{ zUUM>LhmNKrdPEO_%Pq}Dl(P+@og*f+JUq6h190Bz2G2SEE$=z%O!bYnIg_&~$Afz(&PAl0NFj1WlV zgpSD})2!=^P;$p)myOsDkoc~LkN~OxAv~%;H6VcS)5;SpB(~G8JR7HZbQTe7H*8GI zxsK3rT0*>Kk`-E+n%d8kJiRY^L*q|?GTVfpiS@8A{>%A}d;#2=D5)TUc^r@&sv#q+>Usj0?snnU-55HI3{!TtnDVXz2YgZ6Q7ZsOLwJFNJJa7F*N zffV*K4V?gL8?KBgZraWydzWF+8YV;hdJbH#!u$+F7pF=Wm16PG4&RKxDeCD3j_df2km_=)@09i*75Dib_Tne-@&^a9c7YvnKg%%wlSE7Z7 zAz3tFGO)nAaLFz*ZfmX0@QM+`mdoH#GXxwBXO?nTl-84F1tY}T7@nE^@OtgEX<8MylmeA1mo~KGy9xmc0tu->BIUA5S#5U4!Oq%oX4X`W z)~7`3iC(xOlpOOnIPfFp%89?gq2+mJ?IcZEYu@>u_nG(O8UMJj5Ho!Dj&G)aMvVPU zowHwn&KD@=2MEa|FW7)5oKssAA`z&)!W;OB@5cNh7=%g4*|$svGCW{1RE>QviO_Gz z2>nP!viYl*G^JQ)o%kP|%R1b8Uk)d0sN1<#S*g;&j#LHu%?FuL-}F==b6R=!3uhg3 z3k}5*E?FX!r+gKtPr*AJjq|+a&GRuq1m`-55L2j%+ z%m&H^o6u-%kd>u2Hq@#d>tXJn(cDIxlzS-VFNl)uvqN67idUlYj{K=#d51#sBd)Jk z{yM8XDc)ey0JVoYlx*7AV$-Pd``)2{#J*;aZ*Pcca~hXH6_??@&=0D(-_So*&1s{G z4x=hQY*x)9oXFYk1L@1)fSvB%=hIdh_ghs^guyjU() zyvwWBGmhRRoN%eXgcb@66Rz2Inr54()LCApsa;6ZGTST^*#%up4q!8xm06nR92cK| zWC;qDbrf?OqQ@%s_y+F8r=nz)C^_!rlYROE#T(s_tSf#KyQG&d!zXdi^oFL)hPl#Y zY#TxOWJG!!QgyR&QFc5lmhalQOW}|bs3az5rEcJu>C4#n?>pb!-x%qgzRF9pk(tRb zzq6wU8&q6SndQ;M#w)Uym-)_MwE8AtM(neYhdh2R*xs#|!(%y3hw4e1F0T|jc~+R! z<?C$ewZJtDdjph-Mgtcbdn8DFY*kuEEYI5-$ zlw)KvfK#>s75PbDA@LdnF_8HPf_Whq%>7>Ye$0Il@d(oAOFV{iP>Ya9FGBwcdi1-% zW3hMj|J~JTk{4#73q#>5<6;)ZyW^~w&G-~#)y3 z+vf=gwBMBRyys{wMi#Yf;0*ETmCPiq-;-7%N{JlY4)yDmI7Iu17>o7!3i7oty2w;8 z14})LRPKm1D@NPMO>eX_($}$BzXf5#U$SDX4ARjzs_0rj(_3K5?I%>%&OPxk`JgxK zArtOmt8ttSr!MlnnuL*D^>jYB4VF1(B@&=9hZ@Ym`iyXw)Dpxb`yOYd2j+FWawO;h zT4%?6trO_CvC%dR*a|GJ-$zehr6FX`Y!Gqp=?`ekqeKkcowW1p33z=U^Bm)7MU}aBY5teY}A43mkE0a0mPZKS<#m{>3z%4fZz2 zcqY@1s4w9Y&76n73LJEW8T=*Q=Cf=8{Fl#pDQJk=3IA8LmzWC_jvyh!6K4GhhG4zV zP}R1*P6ON_Ip=?U4^y9`7-uB07uDt)ZNFkQa=6n-%y6o);^b&@?OgAtm_MVpwt!xt zQRli-|B-TcWUai@(j8hwd7$4V(IW9N39^$WgV_kL2N=DaIP5lrpG&3nd$i#6{8J8~ zz`lcG79ohQqM(hR;IZd(O|$erhIq67w-MVAoYNa_Il|G`upo)%d|5MgjMr+p=4bb4 jXnIo}rMUCXz&N?9jaF)gTP5!qBGoqdvn|@*;*I|RFZ@LF literal 0 HcmV?d00001 diff --git a/orm/exceptions.py b/orm/exceptions.py new file mode 100644 index 0000000..ebc4114 --- /dev/null +++ b/orm/exceptions.py @@ -0,0 +1,6 @@ +class AsyncOrmException(Exception): + pass + + +class ModelDefinitionError(AsyncOrmException): + pass diff --git a/orm/fields.py b/orm/fields.py new file mode 100644 index 0000000..221c9a2 --- /dev/null +++ b/orm/fields.py @@ -0,0 +1,73 @@ +import sqlalchemy + +from orm.exceptions import ModelDefinitionError + + +class BaseField: + __type__ = None + + def __init__(self, *args, **kwargs): + name = kwargs.pop('name', None) + args = list(args) + if args: + if isinstance(args[0], str): + if name is not None: + raise ModelDefinitionError( + 'Column name cannot be passed positionally and as a keyword.' + ) + name = args.pop(0) + + self.name = name + self.primary_key = kwargs.pop('primary_key', False) + self.autoincrement = kwargs.pop('autoincrement', 'auto') + + self.nullable = kwargs.pop('nullable', not self.primary_key) + self.default = kwargs.pop('default', None) + self.server_default = kwargs.pop('server_default', None) + + self.index = kwargs.pop('index', None) + self.unique = kwargs.pop('unique', None) + + def get_column(self, name=None) -> sqlalchemy.Column: + name = self.name or name + constraints = self.get_constraints() + return sqlalchemy.Column( + name, + self.get_column_type(), + *constraints, + primary_key=self.primary_key, + autoincrement=self.autoincrement, + nullable=self.nullable, + index=self.index, + unique=self.unique, + default=self.default, + server_default=self.server_default + ) + + def get_column_type(self) -> sqlalchemy.types.TypeEngine: + raise NotImplementedError() # pragma: no cover + + def get_constraints(self): + return [] + + +class String(BaseField): + __type__ = str + + def __init__(self, *args, **kwargs): + assert 'length' in kwargs, 'length is required' + self.length = kwargs.pop('length') + super().__init__(*args, **kwargs) + + def get_column_type(self): + return sqlalchemy.String(self.length) + + +class Integer(BaseField): + __type__ = int + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_column_type(self): + return sqlalchemy.Integer() diff --git a/orm/models.py b/orm/models.py new file mode 100644 index 0000000..605b23a --- /dev/null +++ b/orm/models.py @@ -0,0 +1,69 @@ +from typing import Any + +import sqlalchemy +from pydantic import create_model + +from orm.fields import BaseField + + +class ModelMetaclass(type): + def __new__( + mcs: type, name: str, bases: Any, attrs: dict + ) -> type: + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + + if attrs.get("__abstract__"): + return new_model + + tablename = attrs["__tablename__"] + metadata = attrs["__metadata__"] + pkname = None + + columns = [] + for field_name, field in new_model.__dict__.items(): + if isinstance(field, BaseField): + if field.primary_key: + pkname = field_name + columns.append(field.get_column(field_name)) + + pydantic_fields = {field_name: (base_field.__type__, base_field.default or ...) + for field_name, base_field in new_model.__dict__.items() + if isinstance(base_field, BaseField)} + + new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns) + new_model.__columns__ = columns + new_model.__pkname__ = pkname + new_model.__pydantic_fields__ = pydantic_fields + new_model.__pydantic_model__ = create_model(name, **pydantic_fields) + new_model.__fields__ = new_model.__pydantic_model__.__fields__ + + return new_model + + +class Model(metaclass=ModelMetaclass): + __abstract__ = True + + def __init__(self, *args, **kwargs): + if "pk" in kwargs: + kwargs[self.__pkname__] = kwargs.pop("pk") + self.values = self.__pydantic_model__(**kwargs) + + def __setattr__(self, key, value): + if key in self.__fields__: + setattr(self.values, key, value) + super().__setattr__(key, value) + + def __getattribute__(self, item): + if item != '__fields__' and item in self.__fields__: + return getattr(self.values, item) + return super().__getattribute__(item) + + @property + def pk(self): + return getattr(self.values, self.__pkname__) + + @pk.setter + def pk(self, value): + setattr(self.values, self.__pkname__, value) diff --git a/tests/__pycache__/__init__.cpython-38.pyc b/tests/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1f8f4d9f4595a84cd83d906b354bf5ad18de23c GIT binary patch literal 130 zcmWIL<>g`kg4wZZ@gVv!h(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2DV%POXzC_gJT yxuiHIvA8lXSvS8ZH>M=D7|4&0&&?$ zMoo%+p*JrGVfjY&ZD+iKbtp`@7$^3<$2fumR4z}0WBSqKVx&5HeI zF(~S>9haKrpiGHyH!3|69QcUv9M3-{JTJ|hLdhr7VLnm@=rUU*D`L)_;@cpv5^ zUWWG(PH&THWee;g^HrK1x!K|(zAhG{Bub-Jazo0Ng!#OlZ=l%d&jw}=9(5jw&`t#w zr-3cNA>k~$+~W2V8h9`+pn)*a^jVSQxhGaXmU&Ds za6b1SnuV-98qGQgnQ+!K2Thx*-`TXX97aPruz{Ja*QE=LMC<0IkHev(W^@GSjsiIZ z&xjWls%Lk2fm%@A2Pk^WA|!=_uJdnoNrx=cF6Fdqb!f+0vATB0R^O)9hRt&wdy(=y z&M#WM9Cy;(M!rW@9Q7%zI~~%a@(UQbI7%#7$1&|u{kD3PyFb|Dq?^Md-pNH2#`)BP z)j}uVqmSq}P-;2SrSRCWKUwy{5^L0wr&NuD{J<7}A6tHfbz{r#df%5~3omvEqy5+v zcUkn)h(-HE2W(CTHa(?s?Il}P}2lD22El%K^kIipUI9_rLNY&CWiAQsjZd*yB?WpD9v}L2!Y==^0C0J$)t5KXri2@95 zfrQGAF?FxC8gnhmmg3ez7;8w)3QUEI0ySin_jD8ncR+TSEdtPmDHCh$I7~vk9dzvF zDedXN;;N^^niX+t6UnJj=He(bZ5X5Z(kOE!7*Q}e%G5_d3h^mBp#uOHfrxdiqlMK} zu&~h(N*xCQ^&t=f$m;+Mc<{`V5~+{@?+kd$=}S1|mXUy&9C6^A!Hk>%O*>Ez(R7G< zzrf-Af1)1uN07$fg*30}ntnD8|2C*(`RATG{1+m=z(EAGYN{}A;MUrO`xRY+(b zUoxw`VvhI;n1tdz`wtM%UfEIqx(MVOsO&8yCxL9irAXxKK&nMk!sR5& z`__nMs%BUvyJ|#Uf$w%y1Ojvc9^D6^Gb!1CyShie8>a-?Qn-7eEwxhnk=-RZ1BO|U{(RRF%Go^z5OfjRQXfbG`gO=p>Xj4%6Wrfyw zS$*gNF3}AZuWPYDhOv2Nr5 zt!fmF9*xg-dvrY>q!ix8Fey%ntP>*3#eXZ6 dPl Date: Mon, 3 Aug 2020 08:18:57 +0200 Subject: [PATCH 08/62] attribute access and setting for pydantic_model uned the hood --- .gitignore | 5 ++++- orm/__pycache__/__init__.cpython-38.pyc | Bin 128 -> 0 bytes orm/__pycache__/exceptions.cpython-38.pyc | Bin 438 -> 0 bytes orm/__pycache__/fields.cpython-38.pyc | Bin 2613 -> 0 bytes orm/__pycache__/models.cpython-38.pyc | Bin 2536 -> 0 bytes scripts/test.sh | 0 tests/__pycache__/__init__.cpython-38.pyc | Bin 130 -> 0 bytes .../test_columns.cpython-38-pytest-6.0.1.pyc | Bin 3459 -> 0 bytes tests/__pycache__/test_columns.cpython-38.pyc | Bin 670 -> 0 bytes 9 files changed, 4 insertions(+), 1 deletion(-) delete mode 100644 orm/__pycache__/__init__.cpython-38.pyc delete mode 100644 orm/__pycache__/exceptions.cpython-38.pyc delete mode 100644 orm/__pycache__/fields.cpython-38.pyc delete mode 100644 orm/__pycache__/models.cpython-38.pyc mode change 100755 => 100644 scripts/test.sh delete mode 100644 tests/__pycache__/__init__.cpython-38.pyc delete mode 100644 tests/__pycache__/test_columns.cpython-38-pytest-6.0.1.pyc delete mode 100644 tests/__pycache__/test_columns.cpython-38.pyc diff --git a/.gitignore b/.gitignore index d33695b..33d394f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ p38venv -.idea \ No newline at end of file +.idea +.pytest_cache +*.pyc +*.log \ No newline at end of file diff --git a/orm/__pycache__/__init__.cpython-38.pyc b/orm/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 4d8d4b7aaa909737d54fac940d004045fd209070..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 128 zcmWIL<>g`kf=RJz@gVv!h(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2DN%POXzC_gJT uxuiHIvA8lXSvS8ZHwJ{`<1_OzOXB183My}L*yQG?l;)(`fwX@HVg>*mlN&n# diff --git a/orm/__pycache__/exceptions.cpython-38.pyc b/orm/__pycache__/exceptions.cpython-38.pyc deleted file mode 100644 index 2515eefefe82c35a623377c3459a4582803970f7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 438 zcmb79!Ab)$6ihZ-yOtId`~;!&pdTP26)b|{dXQeqLRdBtRyLb8X~CZLXNrH6tAF9i zdAqbKc=5qZ-XxF7y!mJ}V5r;ou{1b8Nc^lAVhZ?U&^1#`@i#!TSFzyGkrL2CbLdFL z4{XvaF!`%E)D-YTkY)?67-l(Uqi}fX*L894%W zuJmJBmy~wqoOP)qyYD<**QdJ#)C0C3l1?Wg?C+7{B~Q_bIj7y{H0m(5>1T)bB<$Ugb5|TelW8O!Zo##LZ0Pd?%$#%1{^5H$duVU?X+{ zUUM>LhmNKrdPEO_%Pq}Dl(P+@og*f+JUq6h190Bz2G2SEE$=z%O!bYnIg_&~$Afz(&PAl0NFj1WlV zgpSD})2!=^P;$p)myOsDkoc~LkN~OxAv~%;H6VcS)5;SpB(~G8JR7HZbQTe7H*8GI zxsK3rT0*>Kk`-E+n%d8kJiRY^L*q|?GTVfpiS@8A{>%A}d;#2=D5)TUc^r@&sv#q+>Usj0?snnU-55HI3{!TtnDVXz2YgZ6Q7ZsOLwJFNJJa7F*N zffV*K4V?gL8?KBgZraWydzWF+8YV;hdJbH#!u$+F7pF=Wm16PG4&RKxDeCD3j_df2km_=)@09i*75Dib_Tne-@&^a9c7YvnKg%%wlSE7Z7 zAz3tFGO)nAaLFz*ZfmX0@QM+`mdoH#GXxwBXO?nTl-84F1tY}T7@nE^@OtgEX<8MylmeA1mo~KGy9xmc0tu->BIUA5S#5U4!Oq%oX4X`W z)~7`3iC(xOlpOOnIPfFp%89?gq2+mJ?IcZEYu@>u_nG(O8UMJj5Ho!Dj&G)aMvVPU zowHwn&KD@=2MEa|FW7)5oKssAA`z&)!W;OB@5cNh7=%g4*|$svGCW{1RE>QviO_Gz z2>nP!viYl*G^JQ)o%kP|%R1b8Uk)d0sN1<#S*g;&j#LHu%?FuL-}F==b6R=!3uhg3 z3k}5*E?FX!r+gKtPr*AJjq|+a&GRuq1m`-55L2j%+ z%m&H^o6u-%kd>u2Hq@#d>tXJn(cDIxlzS-VFNl)uvqN67idUlYj{K=#d51#sBd)Jk z{yM8XDc)ey0JVoYlx*7AV$-Pd``)2{#J*;aZ*Pcca~hXH6_??@&=0D(-_So*&1s{G z4x=hQY*x)9oXFYk1L@1)fSvB%=hIdh_ghs^guyjU() zyvwWBGmhRRoN%eXgcb@66Rz2Inr54()LCApsa;6ZGTST^*#%up4q!8xm06nR92cK| zWC;qDbrf?OqQ@%s_y+F8r=nz)C^_!rlYROE#T(s_tSf#KyQG&d!zXdi^oFL)hPl#Y zY#TxOWJG!!QgyR&QFc5lmhalQOW}|bs3az5rEcJu>C4#n?>pb!-x%qgzRF9pk(tRb zzq6wU8&q6SndQ;M#w)Uym-)_MwE8AtM(neYhdh2R*xs#|!(%y3hw4e1F0T|jc~+R! z<?C$ewZJtDdjph-Mgtcbdn8DFY*kuEEYI5-$ zlw)KvfK#>s75PbDA@LdnF_8HPf_Whq%>7>Ye$0Il@d(oAOFV{iP>Ya9FGBwcdi1-% zW3hMj|J~JTk{4#73q#>5<6;)ZyW^~w&G-~#)y3 z+vf=gwBMBRyys{wMi#Yf;0*ETmCPiq-;-7%N{JlY4)yDmI7Iu17>o7!3i7oty2w;8 z14})LRPKm1D@NPMO>eX_($}$BzXf5#U$SDX4ARjzs_0rj(_3K5?I%>%&OPxk`JgxK zArtOmt8ttSr!MlnnuL*D^>jYB4VF1(B@&=9hZ@Ym`iyXw)Dpxb`yOYd2j+FWawO;h zT4%?6trO_CvC%dR*a|GJ-$zehr6FX`Y!Gqp=?`ekqeKkcowW1p33z=U^Bm)7MU}aBY5teY}A43mkE0a0mPZKS<#m{>3z%4fZz2 zcqY@1s4w9Y&76n73LJEW8T=*Q=Cf=8{Fl#pDQJk=3IA8LmzWC_jvyh!6K4GhhG4zV zP}R1*P6ON_Ip=?U4^y9`7-uB07uDt)ZNFkQa=6n-%y6o);^b&@?OgAtm_MVpwt!xt zQRli-|B-TcWUai@(j8hwd7$4V(IW9N39^$WgV_kL2N=DaIP5lrpG&3nd$i#6{8J8~ zz`lcG79ohQqM(hR;IZd(O|$erhIq67w-MVAoYNa_Il|G`upo)%d|5MgjMr+p=4bb4 jXnIo}rMUCXz&N?9jaF)gTP5!qBGoqdvn|@*;*I|RFZ@LF diff --git a/scripts/test.sh b/scripts/test.sh old mode 100755 new mode 100644 diff --git a/tests/__pycache__/__init__.cpython-38.pyc b/tests/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index b1f8f4d9f4595a84cd83d906b354bf5ad18de23c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 130 zcmWIL<>g`kg4wZZ@gVv!h(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o2DV%POXzC_gJT yxuiHIvA8lXSvS8ZH>M=D7|4&0&&?$ zMoo%+p*JrGVfjY&ZD+iKbtp`@7$^3<$2fumR4z}0WBSqKVx&5HeI zF(~S>9haKrpiGHyH!3|69QcUv9M3-{JTJ|hLdhr7VLnm@=rUU*D`L)_;@cpv5^ zUWWG(PH&THWee;g^HrK1x!K|(zAhG{Bub-Jazo0Ng!#OlZ=l%d&jw}=9(5jw&`t#w zr-3cNA>k~$+~W2V8h9`+pn)*a^jVSQxhGaXmU&Ds za6b1SnuV-98qGQgnQ+!K2Thx*-`TXX97aPruz{Ja*QE=LMC<0IkHev(W^@GSjsiIZ z&xjWls%Lk2fm%@A2Pk^WA|!=_uJdnoNrx=cF6Fdqb!f+0vATB0R^O)9hRt&wdy(=y z&M#WM9Cy;(M!rW@9Q7%zI~~%a@(UQbI7%#7$1&|u{kD3PyFb|Dq?^Md-pNH2#`)BP z)j}uVqmSq}P-;2SrSRCWKUwy{5^L0wr&NuD{J<7}A6tHfbz{r#df%5~3omvEqy5+v zcUkn)h(-HE2W(CTHa(?s?Il}P}2lD22El%K^kIipUI9_rLNY&CWiAQsjZd*yB?WpD9v}L2!Y==^0C0J$)t5KXri2@95 zfrQGAF?FxC8gnhmmg3ez7;8w)3QUEI0ySin_jD8ncR+TSEdtPmDHCh$I7~vk9dzvF zDedXN;;N^^niX+t6UnJj=He(bZ5X5Z(kOE!7*Q}e%G5_d3h^mBp#uOHfrxdiqlMK} zu&~h(N*xCQ^&t=f$m;+Mc<{`V5~+{@?+kd$=}S1|mXUy&9C6^A!Hk>%O*>Ez(R7G< zzrf-Af1)1uN07$fg*30}ntnD8|2C*(`RATG{1+m=z(EAGYN{}A;MUrO`xRY+(b zUoxw`VvhI;n1tdz`wtM%UfEIqx(MVOsO&8yCxL9irAXxKK&nMk!sR5& z`__nMs%BUvyJ|#Uf$w%y1Ojvc9^D6^Gb!1CyShie8>a-?Qn-7eEwxhnk=-RZ1BO|U{(RRF%Go^z5OfjRQXfbG`gO=p>Xj4%6Wrfyw zS$*gNF3}AZuWPYDhOv2Nr5 zt!fmF9*xg-dvrY>q!ix8Fey%ntP>*3#eXZ6 dPl Date: Mon, 3 Aug 2020 08:22:20 +0200 Subject: [PATCH 09/62] make script executable --- scripts/test.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 scripts/test.sh diff --git a/scripts/test.sh b/scripts/test.sh old mode 100644 new mode 100755 From f27e69f87f8c9e96b06f34b65e7ec8154bfa2d75 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 08:32:42 +0200 Subject: [PATCH 10/62] update readme badges --- README.md | 4 ++-- requirements.txt | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d0c2e2a..f3e6026 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@

- Build Status + Build Status - Coverage + Coverage Package version diff --git a/requirements.txt b/requirements.txt index 7a4f7a2..590566c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ sqlalchemy # Testing pytest -pytest-cov \ No newline at end of file +pytest-cov +codecov \ No newline at end of file From 0bd964bdc41ef9c9f7da1a2898f0b7b9f6bd793b Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 13:06:52 +0200 Subject: [PATCH 11/62] add other valid field types, better parse model fields to pydantic model with optional values --- .codecov.yml | 9 +++++- .coverage | Bin 53248 -> 53248 bytes README.md | 7 ++-- orm/fields.py | 73 +++++++++++++++++++++++++++++++++++++++++- orm/models.py | 17 +++++++--- tests/test_columns.py | 29 +++++++++++++---- 6 files changed, 120 insertions(+), 15 deletions(-) diff --git a/.codecov.yml b/.codecov.yml index 033aafd..6d50415 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -8,4 +8,11 @@ coverage: patch: yes changes: yes -comment: off \ No newline at end of file +comment: + layout: "reach, diff, flags, files" + behavior: default + require_changes: false # if true: only post the comment if coverage changes + require_base: no # [yes :: must have a base report to post] + require_head: yes # [yes :: must have a head report to post] + branches: # branch names that can post comment + - "master" \ No newline at end of file diff --git a/.coverage b/.coverage index 922eaa1c16bc5e9ad0f3557871e19de633a21510..edf7fc32cfda3c74535b9c708808ecd0f42a0c34 100644 GIT binary patch delta 175 zcmZozz}&Eac>|jQmjDC5H@_9X9KXO~K>_~B68cIChQ?L~rdFng{7lT!;$^8t#hLke zrg}zthK)vyEV50JlcV$_n2XEGHm}u>4B%VG!2gr~CI3VI|jQmjnZUFuxPO8o$J5K>-f_$pZRH3I-NdhGteq#{5jo(&A;QMa7x< zd8T?sdWMZ!j4ZNEo|A+0BQ|f+4+-Gg#lZiQ|26+({!9F)_;&$KSjKP9$Hv0QDaRI7 zSGNCm>ifOcY^*>&A8Q&v3y@Y|IrXn}`~Uf8?&UDEF@q$SnRu9hv=md>&zgJB&dP88 I+0X6(088;M6951J diff --git a/README.md b/README.md index f3e6026..2c6b6d0 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# ORM +# Async-ORM

@@ -10,6 +10,7 @@ Package version +CodeFactor

The `async-orm` package is an async ORM for Python, with support for Postgres, @@ -156,7 +157,7 @@ assert len(tracks) == 1 The following keyword arguments are supported on all field types. * `primary_key` -* `allow_null` +* `nullable` * `default` * `index` * `unique` @@ -167,7 +168,7 @@ All fields are required unless one of the following is set: * `allow_blank` - Allow empty strings to validate. Sets the default to `""`. * `default` - Set a default value for the field. -* `orm.String(max_length)` +* `orm.String(length)` * `orm.Text()` * `orm.Boolean()` * `orm.Integer()` diff --git a/orm/fields.py b/orm/fields.py index 221c9a2..1b64610 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,3 +1,7 @@ +import datetime +import decimal + +import pydantic import sqlalchemy from orm.exceptions import ModelDefinitionError @@ -66,8 +70,75 @@ class String(BaseField): class Integer(BaseField): __type__ = int + def get_column_type(self): + return sqlalchemy.Integer() + + +class Text(BaseField): + __type__ = str + + def get_column_type(self): + return sqlalchemy.Text() + + +class Float(BaseField): + __type__ = float + + def get_column_type(self): + return sqlalchemy.Float() + + +class Boolean(BaseField): + __type__ = bool + + def get_column_type(self): + return sqlalchemy.Boolean() + + +class DateTime(BaseField): + __type__ = datetime.datetime + + def get_column_type(self): + return sqlalchemy.DateTime() + + +class Date(BaseField): + __type__ = datetime.date + + def get_column_type(self): + return sqlalchemy.Date() + + +class Time(BaseField): + __type__ = datetime.time + + def get_column_type(self): + return sqlalchemy.Time() + + +class JSON(BaseField): + __type__ = pydantic.Json + + def get_column_type(self): + return sqlalchemy.JSON() + + +class BigInteger(BaseField): + __type__ = int + + def get_column_type(self): + return sqlalchemy.BigInteger() + + +class Decimal(BaseField): + __type__ = decimal.Decimal + def __init__(self, *args, **kwargs): + assert 'precision' in kwargs, 'precision is required' + assert 'length' in kwargs, 'length is required' + self.length = kwargs.pop('length') + self.precision = kwargs.pop('precision') super().__init__(*args, **kwargs) def get_column_type(self): - return sqlalchemy.Integer() + return sqlalchemy.DECIMAL(self.length, self.precision) diff --git a/orm/models.py b/orm/models.py index 605b23a..50f2eb9 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional import sqlalchemy from pydantic import create_model @@ -6,6 +6,17 @@ from pydantic import create_model from orm.fields import BaseField +def parse_pydantic_field_from_model_fields(object_dict: dict): + pydantic_fields = {field_name: ( + base_field.__type__, + ... if (not base_field.nullable and not base_field.default) else ( + base_field.default() if callable(base_field.default) else base_field.default) + ) + for field_name, base_field in object_dict.items() + if isinstance(base_field, BaseField)} + return pydantic_fields + + class ModelMetaclass(type): def __new__( mcs: type, name: str, bases: Any, attrs: dict @@ -28,9 +39,7 @@ class ModelMetaclass(type): pkname = field_name columns.append(field.get_column(field_name)) - pydantic_fields = {field_name: (base_field.__type__, base_field.default or ...) - for field_name, base_field in new_model.__dict__.items() - if isinstance(base_field, BaseField)} + pydantic_fields = parse_pydantic_field_from_model_fields(new_model.__dict__) new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns) new_model.__columns__ = columns diff --git a/tests/test_columns.py b/tests/test_columns.py index c687ace..04cc8fa 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,3 +1,6 @@ +import datetime + +import pydantic import pytest import sqlalchemy @@ -12,20 +15,34 @@ class ExampleModel(Model): __tablename__ = "example" __metadata__ = metadata test = fields.Integer(primary_key=True) - test2 = fields.String(length=250) + test_string = fields.String(length=250) + test_text = fields.Text() + test_bool = fields.Boolean(nullable=False) + test_float = fields.Float() + test_datetime = fields.DateTime(default=datetime.datetime.now) + test_date = fields.Date(default=datetime.date.today) + test_time = fields.Time(default=datetime.time) + test_json = fields.JSON(default={}) + test_bigint = fields.BigInteger(default=0) + test_decimal = fields.Decimal(length=10, precision=2) class ExampleModel2(Model): __tablename__ = "example2" __metadata__ = metadata test = fields.Integer(name='test12', primary_key=True) - test2 = fields.String('test22', length=250) + test_string = fields.String('test_string2', length=250) + + +def test_not_nullable_field_is_required(): + with pytest.raises(pydantic.error_wrappers.ValidationError): + ExampleModel(test=1, test_string='test') def test_model_attribute_access(): - example = ExampleModel(test=1, test2='test') + example = ExampleModel(test=1, test_string='test', test_bool=True) assert example.test == 1 - assert example.test2 == 'test' + assert example.test_string == 'test' example.test = 12 assert example.test == 12 @@ -35,7 +52,7 @@ def test_model_attribute_access(): def test_primary_key_access_and_setting(): - example = ExampleModel(pk=1, test2='test') + example = ExampleModel(pk=1, test_string='test', test_bool=True) assert example.pk == 1 example.pk = 2 @@ -49,4 +66,4 @@ def test_wrong_model_definition(): __tablename__ = "example3" __metadata__ = metadata test = fields.Integer(name='test12', primary_key=True) - test2 = fields.String('test22', name='test22', length=250) + test_string = fields.String('test_string2', name='test_string2', length=250) From 612f8d4604bdb8fb576f9f6b03b83fa5937dd45e Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 13:11:45 +0200 Subject: [PATCH 12/62] added test for basic default values, optional fields etc. --- .coverage | Bin 53248 -> 53248 bytes tests/test_columns.py | 8 +++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.coverage b/.coverage index edf7fc32cfda3c74535b9c708808ecd0f42a0c34..8755b3052e06a04c4b76a6414d65aa0ae7b36dcc 100644 GIT binary patch delta 67 zcmV-J0KETzpaX!Q1F$tO1u-!?F)*__FU(L5tpE@C59klz55^C^53REi5Sb5?u8%tr Z4*~=M2_OQx%m2^k{BQH-53~7?2S6=g8iW7< delta 66 zcmV-I0KNZ!paX!Q1F$tO1u!)_H8Ha~FU(L5t^g1D59tr#562I`53aKj5SkB@t&ckp Y4gv%L2_6Ev%m2^kyv>_Gv-yt)KoPbXo&W#< diff --git a/tests/test_columns.py b/tests/test_columns.py index 04cc8fa..981d9b7 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -16,7 +16,7 @@ class ExampleModel(Model): __metadata__ = metadata test = fields.Integer(primary_key=True) test_string = fields.String(length=250) - test_text = fields.Text() + test_text = fields.Text(default='') test_bool = fields.Boolean(nullable=False) test_float = fields.Float() test_datetime = fields.DateTime(default=datetime.datetime.now) @@ -43,6 +43,12 @@ def test_model_attribute_access(): example = ExampleModel(test=1, test_string='test', test_bool=True) assert example.test == 1 assert example.test_string == 'test' + assert example.test_datetime.year == datetime.datetime.now().year + assert example.test_date == datetime.date.today() + assert example.test_text == '' + assert example.test_float is None + assert example.test_bigint == 0 + assert example.test_json == {} example.test = 12 assert example.test == 12 From 0e5d73e7dc97797662c38ffb3b08e41932016e04 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 13:16:50 +0200 Subject: [PATCH 13/62] changed package name in test to cover orm catalog --- scripts/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test.sh b/scripts/test.sh index 577f0e9..11c09f6 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1,6 +1,6 @@ #!/bin/sh -e -PACKAGE="async-orm" +PACKAGE="orm" PREFIX="" if [ -d 'venv' ] ; then From 8f8e5db2f818f21051cda3ab1b767a61d633b35e Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 13:23:47 +0200 Subject: [PATCH 14/62] update column types in readme --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 2c6b6d0..a25401f 100644 --- a/README.md +++ b/README.md @@ -7,10 +7,9 @@ Coverage - - Package version + +CodeFactor -CodeFactor

The `async-orm` package is an async ORM for Python, with support for Postgres, @@ -164,10 +163,10 @@ The following keyword arguments are supported on all field types. All fields are required unless one of the following is set: -* `allow_null` - Creates a nullable column. Sets the default to `None`. -* `allow_blank` - Allow empty strings to validate. Sets the default to `""`. +* `nullable` - Creates a nullable column. Sets the default to `None`. * `default` - Set a default value for the field. +Available Model Fields: * `orm.String(length)` * `orm.Text()` * `orm.Boolean()` @@ -177,6 +176,8 @@ All fields are required unless one of the following is set: * `orm.Time()` * `orm.DateTime()` * `orm.JSON()` +* `orm.BigInteger()` +* `orm.Decimal(lenght, precision)` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases From 876f225d0b6f9ccf39ff1ae1a84c505318f76b1c Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 13:25:40 +0200 Subject: [PATCH 15/62] change badges linkt to proper repo --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a25401f..6145f7f 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # Async-ORM

- + Build Status - + Coverage From d7355b8c9bb2c5cfeb3d00d1f4dbd383602c9360 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 17:49:01 +0200 Subject: [PATCH 16/62] more checks for table and pydantic model creation --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 17 +++++++++++++++++ orm/fields.py | 9 +++++++++ orm/models.py | 34 ++++++++++++++++++++++------------ tests/test_columns.py | 29 +++++++++++++++++++++++++++-- 5 files changed, 75 insertions(+), 14 deletions(-) diff --git a/.coverage b/.coverage index 8755b3052e06a04c4b76a6414d65aa0ae7b36dcc..78c44140c6601a6f7c8fb9b99c03e62bea51dd59 100644 GIT binary patch delta 132 zcmZozz}&Ead4sV&ySbH#g_WW4W-I+O0enjt_Q@8;jvdEfs}zB&7B+UA=%X0y-MF>^2j krG%Kj{bv+n0jDH{h+RcIjGx>e^*;p7k z71*Noect^0>*;&dn}7DRI{*NOmo4Q0 diff --git a/orm/__init__.py b/orm/__init__.py index e69de29..5270355 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -0,0 +1,17 @@ +from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float +from orm.models import Model + +__all__ = [ + "Integer", + "BigInteger", + "Boolean", + "Time", + "Text", + "String", + "JSON", + "DateTime", + "Date", + "Decimal", + "Float", + "Model" +] diff --git a/orm/fields.py b/orm/fields.py index 1b64610..3f33a81 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,5 +1,6 @@ import datetime import decimal +from typing import Any import pydantic import sqlalchemy @@ -10,6 +11,10 @@ from orm.exceptions import ModelDefinitionError class BaseField: __type__ = None + def __new__(cls, *args, **kwargs): + cls.__annotations__ = {} + return super().__new__(cls) + def __init__(self, *args, **kwargs): name = kwargs.pop('name', None) args = list(args) @@ -32,6 +37,10 @@ class BaseField: self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) + self.pydantic_only = kwargs.pop('pydantic_only', False) + if self.pydantic_only and self.primary_key: + raise ModelDefinitionError('Primary key column cannot be pydantic only.') + def get_column(self, name=None) -> sqlalchemy.Column: name = self.name or name constraints = self.get_constraints() diff --git a/orm/models.py b/orm/models.py index 50f2eb9..b446885 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import sqlalchemy from pydantic import create_model @@ -33,20 +33,29 @@ class ModelMetaclass(type): pkname = None columns = [] - for field_name, field in new_model.__dict__.items(): - if isinstance(field, BaseField): + for field_name, field in attrs.items(): + if isinstance(field, BaseField) and not field.pydantic_only: if field.primary_key: pkname = field_name columns.append(field.get_column(field_name)) - pydantic_fields = parse_pydantic_field_from_model_fields(new_model.__dict__) + # sqlalchemy table creation + attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) + attrs['__columns__'] = columns + attrs['__pkname__'] = pkname - new_model.__table__ = sqlalchemy.Table(tablename, metadata, *columns) - new_model.__columns__ = columns - new_model.__pkname__ = pkname - new_model.__pydantic_fields__ = pydantic_fields - new_model.__pydantic_model__ = create_model(name, **pydantic_fields) - new_model.__fields__ = new_model.__pydantic_model__.__fields__ + # pydantic model creation + pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + pydantic_model = create_model(name, **pydantic_fields) + attrs['__pydantic_fields__'] = pydantic_fields + attrs['__pydantic_model__'] = pydantic_model + attrs['__fields__'] = pydantic_model.__fields__ + attrs['__signature__'] = pydantic_model.__signature__ + attrs['__annotations__'] = pydantic_model.__annotations__ + + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) return new_model @@ -62,9 +71,10 @@ class Model(metaclass=ModelMetaclass): def __setattr__(self, key, value): if key in self.__fields__: setattr(self.values, key, value) - super().__setattr__(key, value) + else: + super().__setattr__(key, value) - def __getattribute__(self, item): + def __getattribute__(self, item) -> Any: if item != '__fields__' and item in self.__fields__: return getattr(self.values, item) return super().__getattribute__(item) diff --git a/tests/test_columns.py b/tests/test_columns.py index 981d9b7..7719a0c 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,4 +1,5 @@ import datetime +from typing import ClassVar import pydantic import pytest @@ -27,6 +28,10 @@ class ExampleModel(Model): test_decimal = fields.Decimal(length=10, precision=2) +fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', + 'test_bigint', 'test_json'] + + class ExampleModel2(Model): __tablename__ = "example2" __metadata__ = metadata @@ -66,10 +71,30 @@ def test_primary_key_access_and_setting(): assert example.test == 2 -def test_wrong_model_definition(): +def test_pydantic_model_is_created(): + example = ExampleModel(pk=1, test_string='test', test_bool=True) + assert issubclass(example.values.__class__, pydantic.BaseModel) + assert all([field in example.values.__fields__ for field in fields_to_check]) + assert example.values.test == 1 + + +def test_sqlalchemy_table_is_created(): + example = ExampleModel(pk=1, test_string='test', test_bool=True) + assert issubclass(example.__table__.__class__, sqlalchemy.Table) + assert all([field in example.__table__.columns for field in fields_to_check]) + + +def test_double_column_name_in_model_definition(): with pytest.raises(ModelDefinitionError): class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True) test_string = fields.String('test_string2', name='test_string2', length=250) + + +def test_setting_pk_column_as_pydantic_only_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) From e0bb7e2cda0f50e9e2b7763dd95b10dc4fb32d1f Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 19:59:04 +0200 Subject: [PATCH 17/62] added basic save, update, load and delate methods --- .coverage | Bin 53248 -> 53248 bytes .gitignore | 3 +- orm/exceptions.py | 12 ++++ orm/fields.py | 14 ++-- orm/helpers.py | 27 +++++++ orm/models.py | 124 +++++++++++++++++++++++++++++---- requirements.txt | 3 +- tests/settings.py | 3 + tests/test_columns.py | 122 +++++++++++--------------------- tests/test_model_definition.py | 114 ++++++++++++++++++++++++++++++ 10 files changed, 316 insertions(+), 106 deletions(-) create mode 100644 orm/helpers.py create mode 100644 tests/settings.py create mode 100644 tests/test_model_definition.py diff --git a/.coverage b/.coverage index 78c44140c6601a6f7c8fb9b99c03e62bea51dd59..5710bb4dace601b758eec3329f156f463a9c90a0 100644 GIT binary patch delta 526 zcmZozz}&Ead4sV&yQP(>g_W_zW-I+u0%Dwe1q}RG_?Pfk@dxp%@%`hw#5a?#V6&h= z7$2)TCo@Cn`1#H~KsU;Nb^ zSTvYJ8PTlcv|ug;iQ%-4Lz_7orhttH$vXDU|KnH{ML78D8Tfzlzv6$ue~f<*|6Km5 z{PjTF()sH}I9M1twK&SA|DXSJ&gc8jo>|)4+q1K=14a4Rxn$XZv<=&?+3(N1`B!{= zzgqSDUn}q4eJA_k@Ata<>>R8>X(86MG&TtqAj^m)@8sWbPX5av)!+C(^1tDK&i{!2E>Op1p!%Ksy3D}X f;$;R#4=*z?3V4~>Km;p@U;z=#Ac6@*Faikxx)Pta delta 273 zcmZozz}&Ead4sV&ySbH#g_WW4W-I+u0s?G&{}}kM@Gs%7;t%3iiofc|M{=-F9k}y=G*)~j#W{BjejWv|4;s>{I~fJ@b3UiP2|rPWMg6E zRAGzS_j&Vg%lgTCU*Bb9V+9KFv8G9~0BIAJyqkYt=Y9V_`R44iX`656n9V+0$IQVD zloDe8_McIR3CL1oDx3a){?9r8bMMv6n=jAG$im6V$Hc|Jz`(|TgMt4y{}=wZ{LlFB z^WOlv=LrAi8}oG(fbMw&lKRH~k^c? None: name = kwargs.pop('name', None) args = list(args) if args: @@ -28,7 +24,7 @@ class BaseField: self.name = name self.primary_key = kwargs.pop('primary_key', False) - self.autoincrement = kwargs.pop('autoincrement', 'auto') + self.autoincrement = kwargs.pop('autoincrement', self.primary_key) self.nullable = kwargs.pop('nullable', not self.primary_key) self.default = kwargs.pop('default', None) @@ -41,7 +37,7 @@ class BaseField: if self.pydantic_only and self.primary_key: raise ModelDefinitionError('Primary key column cannot be pydantic only.') - def get_column(self, name=None) -> sqlalchemy.Column: + def get_column(self, name: str = None) -> sqlalchemy.Column: name = self.name or name constraints = self.get_constraints() return sqlalchemy.Column( @@ -60,7 +56,7 @@ class BaseField: def get_column_type(self) -> sqlalchemy.types.TypeEngine: raise NotImplementedError() # pragma: no cover - def get_constraints(self): + def get_constraints(self) -> Optional[List]: return [] diff --git a/orm/helpers.py b/orm/helpers.py new file mode 100644 index 0000000..6e3d254 --- /dev/null +++ b/orm/helpers.py @@ -0,0 +1,27 @@ +from typing import Union, Set, Dict # pragma no cover + + +class Excludable: # pragma no cover + + @staticmethod + def get_excluded(exclude: Union[Set, Dict, None], key: str = None): + # print(f'checking excluded for {key}', exclude) + if isinstance(exclude, dict): + if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys(): + return exclude.get(key).get('__all__') + return exclude.get(key, {}) + return exclude + + @staticmethod + def is_excluded(exclude: Union[Set, Dict, None], key: str = None): + if exclude is None: + return False + to_exclude = Excludable.get_excluded(exclude, key) + # print(f'to exclude for current key = {key}', to_exclude) + + if isinstance(to_exclude, Set): + return key in to_exclude + elif to_exclude is ...: + return True + else: + return False diff --git a/orm/models.py b/orm/models.py index b446885..19224aa 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,15 +1,21 @@ +from __future__ import annotations + +import json from typing import Any +from typing import Set, Dict +import pydantic import sqlalchemy -from pydantic import create_model +from pydantic import BaseConfig, create_model +from orm.exceptions import ModelDefinitionError from orm.fields import BaseField def parse_pydantic_field_from_model_fields(object_dict: dict): pydantic_fields = {field_name: ( base_field.__type__, - ... if (not base_field.nullable and not base_field.default) else ( + ... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else ( base_field.default() if callable(base_field.default) else base_field.default) ) for field_name, base_field in object_dict.items() @@ -33,26 +39,37 @@ class ModelMetaclass(type): pkname = None columns = [] + model_fields = {} for field_name, field in attrs.items(): - if isinstance(field, BaseField) and not field.pydantic_only: - if field.primary_key: - pkname = field_name - columns.append(field.get_column(field_name)) + if isinstance(field, BaseField): + model_fields[field_name] = field + if not field.pydantic_only: + if field.primary_key: + pkname = field_name + columns.append(field.get_column(field_name)) # sqlalchemy table creation attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__columns__'] = columns attrs['__pkname__'] = pkname + if not pkname: + raise ModelDefinitionError( + 'Table has to have a primary key.' + ) + # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model(name, **pydantic_fields) + config = type('Config', (BaseConfig,), {'orm_mode': True}) + pydantic_model = create_model(name, __config__=config, **pydantic_fields) attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_model__'] = pydantic_model attrs['__fields__'] = pydantic_model.__fields__ attrs['__signature__'] = pydantic_model.__signature__ attrs['__annotations__'] = pydantic_model.__annotations__ + attrs['__model_fields__'] = model_fields + new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) @@ -63,21 +80,36 @@ class ModelMetaclass(type): class Model(metaclass=ModelMetaclass): __abstract__ = True - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") self.values = self.__pydantic_model__(**kwargs) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: + if self.is_conversion_to_json_needed(key) and not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: # pragma no cover + pass setattr(self.values, key, value) else: super().__setattr__(key, value) - def __getattribute__(self, item) -> Any: - if item != '__fields__' and item in self.__fields__: - return getattr(self.values, item) - return super().__getattribute__(item) + def __getattribute__(self, key: str) -> Any: + if key != '__fields__' and key in self.__fields__: + item = getattr(self.values, key) + if self.is_conversion_to_json_needed(key) and isinstance(item, str): + try: + item = json.loads(item) + except TypeError: # pragma no cover + pass + return item + + return super().__getattribute__(key) + + def is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.__model_fields__.get(column_name).__type__ == pydantic.Json @property def pk(self): @@ -86,3 +118,69 @@ class Model(metaclass=ModelMetaclass): @pk.setter def pk(self, value): setattr(self.values, self.__pkname__, value) + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.__table__.primary_key.columns.values()[0] + + def dict(self) -> Dict: + return self.values.dict() + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def extract_own_model_fields(self) -> Dict: + related_names = self.extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def extract_related_names(cls) -> Set: + related_names = set() + # for name, field in cls.__fields__.items(): + # if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): + # related_names.add(name) + # elif field.sub_fields and any( + # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): + # related_names.add(name) + return related_names + + def extract_model_db_fields(self) -> Dict: + self_fields = self.extract_own_model_fields() + self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} + return self_fields + + async def save(self) -> int: + self_fields = self.extract_model_db_fields() + if self.__model_fields__.get(self.__pkname__).autoincrement: + self_fields.pop(self.__pkname__, None) + expr = self.__table__.insert() + expr = expr.values(**self_fields) + item_id = await self.__database__.execute(expr) + setattr(self, 'pk', item_id) + return item_id + + async def update(self, **kwargs: Any) -> int: + if kwargs: + new_values = {**self.dict(), **kwargs} + self.from_dict(new_values) + + self_fields = self.extract_model_db_fields() + self_fields.pop(self.__pkname__) + expr = self.__table__.update().values(**self_fields).where( + self.pk_column == getattr(self, self.__pkname__)) + result = await self.__database__.execute(expr) + return result + + async def delete(self) -> int: + expr = self.__table__.delete() + expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) + result = await self.__database__.execute(expr) + return result + + async def load(self) -> Model: + expr = self.__table__.select().where(self.pk_column == self.pk) + row = await self.__database__.fetch_one(expr) + self.from_dict(dict(row)) + return self diff --git a/requirements.txt b/requirements.txt index 590566c..f6280a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ sqlalchemy # Testing pytest pytest-cov -codecov \ No newline at end of file +codecov +pytest-asyncio \ No newline at end of file diff --git a/tests/settings.py b/tests/settings.py new file mode 100644 index 0000000..697acb0 --- /dev/null +++ b/tests/settings.py @@ -0,0 +1,3 @@ +import os + +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///test.db") diff --git a/tests/test_columns.py b/tests/test_columns.py index 7719a0c..edee116 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -1,100 +1,58 @@ import datetime -from typing import ClassVar -import pydantic +import databases import pytest import sqlalchemy -import orm.fields as fields -from orm.exceptions import ModelDefinitionError -from orm.models import Model +import orm +from tests.settings import DATABASE_URL +database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class ExampleModel(Model): +def time(): + return datetime.datetime.now().time() + + +class Example(orm.Model): __tablename__ = "example" __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) - test_text = fields.Text(default='') - test_bool = fields.Boolean(nullable=False) - test_float = fields.Float() - test_datetime = fields.DateTime(default=datetime.datetime.now) - test_date = fields.Date(default=datetime.date.today) - test_time = fields.Time(default=datetime.time) - test_json = fields.JSON(default={}) - test_bigint = fields.BigInteger(default=0) - test_decimal = fields.Decimal(length=10, precision=2) + __database__ = database + + id = orm.Integer(primary_key=True) + created = orm.DateTime(default=datetime.datetime.now) + created_day = orm.Date(default=datetime.date.today) + created_time = orm.Time(default=time) + description = orm.Text(nullable=True) + value = orm.Float(nullable=True) + data = orm.JSON(default={}) -fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', - 'test_bigint', 'test_json'] +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) -class ExampleModel2(Model): - __tablename__ = "example2" - __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True) - test_string = fields.String('test_string2', length=250) +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + example = Example() + await example.save() + await example.load() + assert example.created.year == datetime.datetime.now().year + assert example.created_day == datetime.date.today() + assert example.description is None + assert example.value is None + assert example.data == {} -def test_not_nullable_field_is_required(): - with pytest.raises(pydantic.error_wrappers.ValidationError): - ExampleModel(test=1, test_string='test') + await example.update(data={"foo": 123}, value=123.456) + await example.load() + assert example.value == 123.456 + assert example.data == {"foo": 123} - -def test_model_attribute_access(): - example = ExampleModel(test=1, test_string='test', test_bool=True) - assert example.test == 1 - assert example.test_string == 'test' - assert example.test_datetime.year == datetime.datetime.now().year - assert example.test_date == datetime.date.today() - assert example.test_text == '' - assert example.test_float is None - assert example.test_bigint == 0 - assert example.test_json == {} - - example.test = 12 - assert example.test == 12 - - example.new_attr = 12 - assert 'new_attr' in example.__dict__ - - -def test_primary_key_access_and_setting(): - example = ExampleModel(pk=1, test_string='test', test_bool=True) - assert example.pk == 1 - example.pk = 2 - - assert example.pk == 2 - assert example.test == 2 - - -def test_pydantic_model_is_created(): - example = ExampleModel(pk=1, test_string='test', test_bool=True) - assert issubclass(example.values.__class__, pydantic.BaseModel) - assert all([field in example.values.__fields__ for field in fields_to_check]) - assert example.values.test == 1 - - -def test_sqlalchemy_table_is_created(): - example = ExampleModel(pk=1, test_string='test', test_bool=True) - assert issubclass(example.__table__.__class__, sqlalchemy.Table) - assert all([field in example.__table__.columns for field in fields_to_check]) - - -def test_double_column_name_in_model_definition(): - with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String('test_string2', name='test_string2', length=250) - - -def test_setting_pk_column_as_pydantic_only_in_model_definition(): - with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) + await example.delete() diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py new file mode 100644 index 0000000..f06f141 --- /dev/null +++ b/tests/test_model_definition.py @@ -0,0 +1,114 @@ +import datetime +from typing import ClassVar + +import pydantic +import pytest +import sqlalchemy + +import orm.fields as fields +from orm.exceptions import ModelDefinitionError +from orm.models import Model + +metadata = sqlalchemy.MetaData() + + +class ExampleModel(Model): + __tablename__ = "example" + __metadata__ = metadata + test = fields.Integer(primary_key=True) + test_string = fields.String(length=250) + test_text = fields.Text(default='') + test_bool = fields.Boolean(nullable=False) + test_float = fields.Float() + test_datetime = fields.DateTime(default=datetime.datetime.now) + test_date = fields.Date(default=datetime.date.today) + test_time = fields.Time(default=datetime.time) + test_json = fields.JSON(default={}) + test_bigint = fields.BigInteger(default=0) + test_decimal = fields.Decimal(length=10, precision=2) + + +fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', + 'test_bigint', 'test_json'] + + +class ExampleModel2(Model): + __tablename__ = "example2" + __metadata__ = metadata + test = fields.Integer(name='test12', primary_key=True) + test_string = fields.String('test_string2', length=250) + + +@pytest.fixture() +def example(): + return ExampleModel(pk=1, test_string='test', test_bool=True) + + +def test_not_nullable_field_is_required(): + with pytest.raises(pydantic.error_wrappers.ValidationError): + ExampleModel(test=1, test_string='test') + + +def test_model_attribute_access(example): + assert example.test == 1 + assert example.test_string == 'test' + assert example.test_datetime.year == datetime.datetime.now().year + assert example.test_date == datetime.date.today() + assert example.test_text == '' + assert example.test_float is None + assert example.test_bigint == 0 + assert example.test_json == {} + + example.test = 12 + assert example.test == 12 + + example.new_attr = 12 + assert 'new_attr' in example.__dict__ + + +def test_primary_key_access_and_setting(example): + assert example.pk == 1 + example.pk = 2 + + assert example.pk == 2 + assert example.test == 2 + + +def test_pydantic_model_is_created(example): + assert issubclass(example.values.__class__, pydantic.BaseModel) + assert all([field in example.values.__fields__ for field in fields_to_check]) + assert example.values.test == 1 + + +def test_sqlalchemy_table_is_created(example): + assert issubclass(example.__table__.__class__, sqlalchemy.Table) + assert all([field in example.__table__.columns for field in fields_to_check]) + + +def test_double_column_name_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example3" + __metadata__ = metadata + test_string = fields.String('test_string2', name='test_string2', length=250) + + +def test_no_pk_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example3" + __metadata__ = metadata + test_string = fields.String(name='test_string2', length=250) + + +def test_setting_pk_column_as_pydantic_only_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) + + +def test_json_conversion_in_model(): + with pytest.raises(pydantic.ValidationError): + ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True) From a6f8fc6d7e612a486710fe4c48c4785d588568a2 Mon Sep 17 00:00:00 2001 From: collerek Date: Mon, 3 Aug 2020 20:05:57 +0200 Subject: [PATCH 18/62] remove unneeded future import --- .coverage | Bin 53248 -> 53248 bytes orm/models.py | 6 ++---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.coverage b/.coverage index 5710bb4dace601b758eec3329f156f463a9c90a0..0c9f6db6adb3236f1925954b21a689c4260c5301 100644 GIT binary patch delta 53 zcmV-50LuS>paX!Q1F$kL2r@7_Ff}?fF|#@^YG4@d<~u*T|J?Df$=;`aYInDL2YUa# L-}eEtypKdcNzWMf delta 53 zcmV-50LuS>paX!Q1F$kL2r)T2H8?smII}u0YG4?;oA1o*{&UB_CVQXysomY~9q9e{ Le%}YPypKdcV)hzw diff --git a/orm/models.py b/orm/models.py index 19224aa..096415e 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,7 +1,5 @@ -from __future__ import annotations - import json -from typing import Any +from typing import Any, Type from typing import Set, Dict import pydantic @@ -179,7 +177,7 @@ class Model(metaclass=ModelMetaclass): result = await self.__database__.execute(expr) return result - async def load(self) -> Model: + async def load(self) -> 'Model': expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) From 345fd227d1076958426cd4a32c327de601350834 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 4 Aug 2020 18:44:17 +0200 Subject: [PATCH 19/62] sloppy work on passing all of the test and reimplementing most of the features from encode --- .coverage | Bin 53248 -> 53248 bytes README.md | 4 + orm/__init__.py | 4 +- orm/exceptions.py | 6 +- orm/fields.py | 63 +++++++- orm/models.py | 319 ++++++++++++++++++++++++++++++++++--- orm/relations.py | 45 ++++++ tests/test_foreign_keys.py | 231 +++++++++++++++++++++++++++ tests/test_models.py | 0 9 files changed, 648 insertions(+), 24 deletions(-) create mode 100644 orm/relations.py create mode 100644 tests/test_foreign_keys.py create mode 100644 tests/test_models.py diff --git a/.coverage b/.coverage index 0c9f6db6adb3236f1925954b21a689c4260c5301..ba7b5f4a6b067d03e366db7864d285fedeb1619e 100644 GIT binary patch delta 742 zcmZ9K(MwZN9LMiGyLY?y?w+5fMNL5|aJ5h?CEa>hJ=jAJE|MNZ(j+z+&aF1hOltNQ z2;B!Cl;(qxPZEtaB5PE>I8cch6n35Zuu9rSE*xW%+quh$^>ohf`~7{-`TXEFI)z53 z(DO45TvN^27IRah`CN0r+z=_P(g_wG9k7T)m1>t9fj#rB5-4>L9kh%wu!*3^?4bLWuqw(u=LuFyY>$U<3@% zE^OizoQF?%2A*OsHe;S}zZb+iT_J?c>vDfVgk7BqNJEvLCA)W5I%Ps^T0okf$; z9iY{Ua+iVD7*U5}|4Ym4h4g{e?r@Eshe`^vum$T7h9Jy<4Q_x5PCzAr-NinJSF2de zmsD`%G;(*xWq)=u9gD|J7K=r5oF(dIHW^jm1TX|zdGdKo-|or%U|<-FWzQ4 zV(qG%SGQxcl9b&2p7Nzq8N@T>t)t8pzsnbD(9)zhMBDB5YBjlCRJ#H`DUp;ivbebB zJU_8$4}5X%Or(;tJ}pvO(UzR4f~a_{}v{I+$0YJMasB!Z+9iiNvF@1Pd@v!qWhVtfhd+HVTL=pCCq;>;3>4 Cc?atN delta 483 zcmZozz}&Ead4qvIm$8D8ft7)&m5JeIEB#XfVw`*h4E$I4m+)8d2l1=%{o}jDH2nT;X1OHF{SNsq7kMZx}pUXd$zaD65I)A+g2MZ&o7Dw6i|MP#&`F#J`GfR7W zdv-Q+l*d$nh zEF+e@oBQ9D{kNaF_hydS?6YZ`Z@vKrA4m=JxBra1OhA?l)2_37?CtNrXJut%;pCkB zub&O*x>F4NzxluLzvX|%f1m#b|3&^&K-X>J-+XGmm4X;2|7DQ$H~x?OZ}^|{KjOa& PRB{=pU?>0N%jb0g#WkQw diff --git a/README.md b/README.md index 6145f7f..5d99739 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,7 @@ The following keyword arguments are supported on all field types. * `primary_key` * `nullable` * `default` +* `server_default` * `index` * `unique` @@ -165,6 +166,9 @@ All fields are required unless one of the following is set: * `nullable` - Creates a nullable column. Sets the default to `None`. * `default` - Set a default value for the field. +* `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). +* `primary key` with `autoincrement` - When a column is set to primary key and autoincrement is set on this column. +Autoincrement is set by default on int primary keys. Available Model Fields: * `orm.String(length)` diff --git a/orm/__init__.py b/orm/__init__.py index 5270355..9c652ae 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,4 +1,5 @@ -from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float +from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float, \ + ForeignKey from orm.models import Model __all__ = [ @@ -13,5 +14,6 @@ __all__ = [ "Date", "Decimal", "Float", + "ForeignKey", "Model" ] diff --git a/orm/exceptions.py b/orm/exceptions.py index 7321d99..1a8c6d0 100644 --- a/orm/exceptions.py +++ b/orm/exceptions.py @@ -10,7 +10,11 @@ class ModelNotSet(AsyncOrmException): pass -class MultipleResults(AsyncOrmException): +class NoMatch(AsyncOrmException): + pass + + +class MultipleMatches(AsyncOrmException): pass diff --git a/orm/fields.py b/orm/fields.py index b42a901..4393bd1 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -6,6 +6,7 @@ import pydantic import sqlalchemy from orm.exceptions import ModelDefinitionError +from orm.relations import Relationship class BaseField: @@ -24,7 +25,7 @@ class BaseField: self.name = name self.primary_key = kwargs.pop('primary_key', False) - self.autoincrement = kwargs.pop('autoincrement', self.primary_key) + self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int) self.nullable = kwargs.pop('nullable', not self.primary_key) self.default = kwargs.pop('default', None) @@ -37,11 +38,30 @@ class BaseField: if self.pydantic_only and self.primary_key: raise ModelDefinitionError('Primary key column cannot be pydantic only.') + @property + def is_required(self): + return not self.nullable and not self.has_default and not self.is_auto_primary_key + + @property + def default_value(self): + default = self.default if self.default is not None else self.server_default + return default() if callable(default) else default + + @property + def has_default(self): + return self.default is not None or self.server_default is not None + + @property + def is_auto_primary_key(self): + if self.primary_key: + return self.autoincrement + return False + def get_column(self, name: str = None) -> sqlalchemy.Column: - name = self.name or name + self.name = self.name or name constraints = self.get_constraints() return sqlalchemy.Column( - name, + self.name, self.get_column_type(), *constraints, primary_key=self.primary_key, @@ -59,6 +79,9 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] + def expand_relationship(self, value, parent): + return value + class String(BaseField): __type__ = str @@ -147,3 +170,37 @@ class Decimal(BaseField): def get_column_type(self): return sqlalchemy.DECIMAL(self.length, self.precision) + + +class ForeignKey(BaseField): + def __init__(self, to, related_name: str = None, nullable: bool = False): + super().__init__(nullable=nullable) + self.related_name = related_name + self.to = to + + @property + def __type__(self): + return self.to.__pydantic_model__ + + def get_constraints(self): + fk_string = self.to.__tablename__ + "." + self.to.__pkname__ + return [sqlalchemy.schema.ForeignKey(fk_string)] + + def get_column_type(self): + to_column = self.to.__model_fields__[self.to.__pkname__] + return to_column.get_column_type() + + def expand_relationship(self, value, child): + if isinstance(value, self.to): + model = value + else: + model = self.to(**{self.to.__pkname__: value}) + + child_model_name = self.related_name or child.__class__.__name__.lower() + 's' + model._orm_relationship_manager.add( + Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) + model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name, + type_=child.__pydantic_model__, + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__) + return model diff --git a/orm/models.py b/orm/models.py index 096415e..e19474c 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,26 +1,250 @@ +import copy +import inspect import json -from typing import Any, Type +import uuid +from typing import Any, List, Type from typing import Set, Dict import pydantic import sqlalchemy from pydantic import BaseConfig, create_model -from orm.exceptions import ModelDefinitionError +from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch from orm.fields import BaseField +from orm.relations import RelationshipManager def parse_pydantic_field_from_model_fields(object_dict: dict): pydantic_fields = {field_name: ( base_field.__type__, - ... if (not base_field.nullable and not base_field.default and not base_field.primary_key) else ( - base_field.default() if callable(base_field.default) else base_field.default) + ... if base_field.is_required else base_field.default_value ) for field_name, base_field in object_dict.items() if isinstance(base_field, BaseField)} return pydantic_fields +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} + + +class QuerySet: + ESCAPE_CHARACTERS = ['%', '_'] + + def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, + limit_count: int = None, offset: int = None): + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + + def __get__(self, instance, owner): + return self.__class__(model_cls=owner) + + @property + def database(self): + return self.model_cls.__database__ + + @property + def table(self): + return self.model_cls.__table__ + + def build_select_expression(self): + tables = [self.table] + select_from = self.table + + for item in self._select_related: + model_cls = self.model_cls + select_from = self.table + for part in item.split("__"): + model_cls = model_cls.__model_fields__[part].to + select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) + tables.append(model_cls.__table__) + + expr = sqlalchemy.sql.select(tables) + expr = expr.select_from(select_from) + + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + # print(expr.compile(compile_kwargs={"literal_binds": True})) + return expr + + def filter(self, **kwargs): + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.__pkname__ + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + if "__" in key: + parts = key.split("__") + + # Determine if we should treat the final part as a + # filter operator or as a related field. + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + model_cls = self.model_cls + if related_parts: + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + for part in related_parts: + model_cls = model_cls.__model_fields__[part].to + + column = model_cls.__table__.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + + # Map the operation code onto SQLAlchemy's ColumnElement + # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement + op_attr = FILTER_OPERATORS[op] + has_escaped_character = False + + if op in ["contains", "icontains"]: + has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS + if c in value) + if has_escaped_character: + # enable escape modifier + for char in self.ESCAPE_CHARACTERS: + value = value.replace(char, f'\\{char}') + value = f"%{value}%" + + if isinstance(value, Model): + value = value.pk + + clause = getattr(column, op_attr)(value) + clause.modifiers['escape'] = '\\' if has_escaped_character else None + filter_clauses.append(clause) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + def select_related(self, related): + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + # async def exists(self) -> bool: + # expr = self.build_select_expression() + # expr = sqlalchemy.exists(expr).select() + # return await self.database.fetch_val(expr) + + def limit(self, limit_count: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset + ) + + def offset(self, offset: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset + ) + + async def get(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + return [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + async def create(self, **kwargs): + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.__pkname__ + pk = self.model_cls.__model_fields__[pkname] + if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls.extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + instance.pk = await self.database.execute(expr) + return instance + + class ModelMetaclass(type): def __new__( mcs: type, name: str, bases: Any, attrs: dict @@ -52,9 +276,7 @@ class ModelMetaclass(type): attrs['__pkname__'] = pkname if not pkname: - raise ModelDefinitionError( - 'Table has to have a primary key.' - ) + raise ModelDefinitionError('Table has to have a primary key.') # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) @@ -62,9 +284,9 @@ class ModelMetaclass(type): pydantic_model = create_model(name, __config__=config, **pydantic_fields) attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_model__'] = pydantic_model - attrs['__fields__'] = pydantic_model.__fields__ - attrs['__signature__'] = pydantic_model.__signature__ - attrs['__annotations__'] = pydantic_model.__annotations__ + attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) + attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) + attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) attrs['__model_fields__'] = model_fields @@ -78,9 +300,17 @@ class ModelMetaclass(type): class Model(metaclass=ModelMetaclass): __abstract__ = True - def __init__(self, *args, **kwargs) -> None: + objects = QuerySet() + + def __init__(self, **kwargs) -> None: + self._orm_id = uuid.uuid4().hex + self._orm_saved = False + self._orm_relationship_manager = RelationshipManager(self) + self._orm_observers = [] + if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} self.values = self.__pydantic_model__(**kwargs) def __setattr__(self, key: str, value: Any) -> None: @@ -90,14 +320,19 @@ class Model(metaclass=ModelMetaclass): value = json.dumps(value) except TypeError: # pragma no cover pass + value = self.__model_fields__[key].expand_relationship(value, self) setattr(self.values, key, value) else: super().__setattr__(key, value) def __getattribute__(self, key: str) -> Any: if key != '__fields__' and key in self.__fields__: - item = getattr(self.values, key) - if self.is_conversion_to_json_needed(key) and isinstance(item, str): + if key in self._orm_relationship_manager: + parent_item = self._orm_relationship_manager.get(key) + return parent_item + + item = getattr(self.values, key, None) + if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): try: item = json.loads(item) except TypeError: # pragma no cover @@ -106,6 +341,45 @@ class Model(metaclass=ModelMetaclass): return super().__getattribute__(key) + def __repr__(self): # pragma no cover + return self.values.__repr__() + + # def attach(self, observer: 'Model'): + # if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]): + # self._orm_observers.append(observer) + # + # def detach(self, observer: 'Model'): + # for ind, obs in enumerate(self._orm_observers): + # if obs._orm_id == observer._orm_id: + # del self._orm_observers[ind] + # break + # + def notify(self): + for obs in self._orm_observers: # pragma no cover + obs.orm_update(self) + + def orm_update(self, subject: 'Model') -> None: # pragma no cover + print('should be updated here') + + @classmethod + def from_row(cls, row, select_related: List = None) -> 'Model': + item = {} + select_related = select_related or [] + for related in select_related: + if "__" in related: + first_part, remainder = related.split("__", 1) + model_cls = cls.__model_fields__[first_part].to + item[first_part] = model_cls.from_row(row, select_related=[remainder]) + else: + model_cls = cls.__model_fields__[related].to + item[related] = model_cls.from_row(row) + + for column in cls.__table__.columns: + if column.name not in item: + item[column.name] = row[column] + + return cls(**item) + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -136,17 +410,20 @@ class Model(metaclass=ModelMetaclass): @classmethod def extract_related_names(cls) -> Set: related_names = set() - # for name, field in cls.__fields__.items(): - # if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): - # related_names.add(name) - # elif field.sub_fields and any( - # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): - # related_names.add(name) + for name, field in cls.__fields__.items(): + if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): + related_names.add(name) + # elif field.sub_fields and any( + # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): + # related_names.add(name) return related_names def extract_model_db_fields(self) -> Dict: self_fields = self.extract_own_model_fields() self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} + for field in self.extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__) return self_fields async def save(self) -> int: @@ -157,6 +434,7 @@ class Model(metaclass=ModelMetaclass): expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) setattr(self, 'pk', item_id) + self.notify() return item_id async def update(self, **kwargs: Any) -> int: @@ -169,16 +447,19 @@ class Model(metaclass=ModelMetaclass): expr = self.__table__.update().values(**self_fields).where( self.pk_column == getattr(self, self.__pkname__)) result = await self.__database__.execute(expr) + self.notify() return result async def delete(self) -> int: expr = self.__table__.delete() expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) result = await self.__database__.execute(expr) + self.notify() return result async def load(self) -> 'Model': expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) + self.notify() return self diff --git a/orm/relations.py b/orm/relations.py new file mode 100644 index 0000000..1bd70cf --- /dev/null +++ b/orm/relations.py @@ -0,0 +1,45 @@ +from typing import Dict, Union, List + +from sqlalchemy import text + + +class Relationship: + + def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'): + self.fk_side = fk_side + self.child = child + self.parent = parent + self.name = name + + +class RelationshipManager: + + def __init__(self, model: 'Model'): + self._orm_id: str = model._orm_id + self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict() + + def __contains__(self, item): + return item in self._relations + + def add_related(self, relation: Relationship): + if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id: + new_relation = Relationship(name=relation.parent.__class__.__name__.lower(), + child=relation.parent, + parent=relation.child, + fk_side='parent') + relation.child._orm_relationship_manager.add(new_relation) + + def add(self, relation: Relationship): + if relation.name in self._relations: + self._relations[relation.name].append(relation) + else: + self._relations[relation.name] = [relation] + self.add_related(relation) + + def get(self, name: str): + for rel, relations in self._relations.items(): + if rel == name: + if relations and relations[0].fk_side == 'parent': + return relations[0].child + else: + return [rela.child for rela in relations] diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py new file mode 100644 index 0000000..d3d95bf --- /dev/null +++ b/tests/test_foreign_keys.py @@ -0,0 +1,231 @@ +import databases +import pytest +import sqlalchemy + +import orm +from orm.exceptions import NoMatch, MultipleMatches +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album) + title = orm.String(length=100) + position = orm.Integer() + + +class Organisation(orm.Model): + __tablename__ = "org" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + ident = orm.String(length=100) + + +class Team(orm.Model): + __tablename__ = "team" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + org = orm.ForeignKey(Organisation) + name = orm.String(length=100) + + +class Member(orm.Model): + __tablename__ = "member" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + team = orm.ForeignKey(Team) + email = orm.String(length=100) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + track = await Track.objects.get(title="The Bird") + assert track.album.pk == album.pk + assert track.album.name is None + await track.album.load() + assert track.album.name == "Malibu" + + assert len(album.tracks) == 3 + assert album.tracks[1].title == "Heart don't stand a chance" + + album1 = await Album.objects.get(name='Malibu') + assert album1.pk == 1 + assert album1.tracks is None + + +@pytest.mark.asyncio +async def test_select_related(): + async with database: + album = Album(name="Malibu") + await album.save() + track1 = Track(album=album, title="The Bird", position=1) + track2 = Track(album=album, title="Heart don't stand a chance", position=2) + track3 = Track(album=album, title="The Waters", position=3) + await track1.save() + await track2.save() + await track3.save() + + fantasies = Album(name="Fantasies") + await fantasies.save() + track4 = Track(album=fantasies, title="Help I'm Alive", position=1) + track5 = Track(album=fantasies, title="Sick Muse", position=2) + track6 = Track(album=fantasies, title="Satellite Mind", position=3) + await track4.save() + await track5.save() + await track6.save() + + track = await Track.objects.select_related("album").get(title="The Bird") + assert track.album.name == "Malibu" + + tracks = await Track.objects.select_related("album").all() + assert len(tracks) == 6 + + +@pytest.mark.asyncio +async def test_fk_filter(): + async with database: + malibu = Album(name="Malibu%") + await malibu.save() + await Track.objects.create(album=malibu, title="The Bird", position=1) + await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) + await Track.objects.create(album=malibu, title="The Waters", position=3) + + fantasies = await Album.objects.create(name="Fantasies") + await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) + await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) + + tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="fan").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Fantasies" + + tracks = await Track.objects.filter(album__name__contains="Malibu%").all() + assert len(tracks) == 3 + + tracks = await Track.objects.filter(album=malibu).select_related("album").all() + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + tracks = await Track.objects.select_related("album").all(album=malibu) + assert len(tracks) == 3 + for track in tracks: + assert track.album.name == "Malibu%" + + +@pytest.mark.asyncio +async def test_multiple_fk(): + async with database: + acme = await Organisation.objects.create(ident="ACME Ltd") + red_team = await Team.objects.create(org=acme, name="Red Team") + blue_team = await Team.objects.create(org=acme, name="Blue Team") + await Member.objects.create(team=red_team, email="a@example.org") + await Member.objects.create(team=red_team, email="b@example.org") + await Member.objects.create(team=blue_team, email="c@example.org") + await Member.objects.create(team=blue_team, email="d@example.org") + + other = await Organisation.objects.create(ident="Other ltd") + team = await Team.objects.create(org=other, name="Green Team") + await Member.objects.create(team=team, email="e@example.org") + + members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all() + assert len(members) == 4 + for member in members: + assert member.team.org.ident == "ACME Ltd" + + +@pytest.mark.asyncio +async def test_pk_filter(): + async with database: + fantasies = await Album.objects.create(name="Test") + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + tracks = await Track.objects.select_related("album").filter(pk=1).all() + assert len(tracks) == 1 + + tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all() + assert len(tracks) == 1 + + +@pytest.mark.asyncio +async def test_limit_and_offset(): + async with database: + fantasies = await Album.objects.create(name="Limitless") + await Track.objects.create(id=None, album=fantasies, title="Sample", position=1) + await Track.objects.create(album=fantasies, title="Sample2", position=2) + await Track.objects.create(album=fantasies, title="Sample3", position=3) + + tracks = await Track.objects.limit(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample" + + tracks = await Track.objects.limit(1).offset(1).all() + assert len(tracks) == 1 + assert tracks[0].title == "Sample2" + + +@pytest.mark.asyncio +async def test_get_exceptions(): + async with database: + fantasies = await Album.objects.create(name="Test") + + with pytest.raises(NoMatch): + await Album.objects.get(name="Test2") + + await Track.objects.create(album=fantasies, title="Test1", position=1) + await Track.objects.create(album=fantasies, title="Test2", position=2) + await Track.objects.create(album=fantasies, title="Test3", position=3) + with pytest.raises(MultipleMatches): + await Track.objects.select_related("album").get(album=fantasies) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..e69de29 From eb99f28431e0c86fc277b200c6731676777662bf Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 4 Aug 2020 21:37:25 +0200 Subject: [PATCH 20/62] added hack to pass as pydantic model in fastapi, tests for fastapi --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 1 + orm/helpers.py | 3 +- orm/models.py | 41 ++++++-- orm/relations.py | 5 +- requirements.txt | 3 +- tests/test_fastapi_usage.py | 42 ++++++++ tests/test_models.py | 200 ++++++++++++++++++++++++++++++++++++ 8 files changed, 282 insertions(+), 13 deletions(-) create mode 100644 tests/test_fastapi_usage.py diff --git a/.coverage b/.coverage index ba7b5f4a6b067d03e366db7864d285fedeb1619e..997d62288b92e3b8a186eb775b7686e374522c00 100644 GIT binary patch delta 667 zcmZozz}&Ead4rKYhmoO`v6+>r@n$RiQv!0lysH`bukbJ7ui_8lSL6H5cbRWFUm>3j zA2aV2-qo801v+_I)p?m2LMQo4P4eSr;bCS-oa8SydA~pRWCK4g7H(#SQUsq>I6fsc zEi*5(Br`uxub|SHi$h7stxH4OaK{ObI{yleQb@crSx%C~}lDPIwvEguW-RiNv-cqc!K)nMh-V@_m* z*<;|x$;xBFT#6)s?k;W<=3KZm$X#5f%=s_@40myAGe<)uf$oa4QIz4~`^3Qilm8X} z1OEN|TlrV;C-Fz|`|@k@3-f&fy6O_2ojeZXO{N%_UuYrKv6fY$qxOp_1v65wgl(5t(C=} z*x5LM96k;%7j{-g&PERQfA98uw*R($_wU=em+yYeJ8twX?|X6mz2Bdozx#G~xlQ`H z^Y84<|NCD1Zuh!+8-u~m{Ofn#yeoU}UjM(ke!o078^~xjulafP+vPb}fgB;$w6rvN z79h)kC2w>6-M8;1pXrxozV)Af@~eI+86N(14E(?Ozwp21f5v~G{|5g>{!{#i`S$=l zzYgf(K7MwPiHtlvo43yQSCHf7-w#s$jsGM68~*3~kNEEbwOr;u$A6Ol2>*VdhPnK# N%)nH|JNd|Y8vw&}><$0` delta 464 zcmZozz}&Ead4rKYhoOa)iLsTD`DQEqQvxzPybl@pukbJ7ui_8lSL6H5cbRWFUlE@r z9~bY#&4L04d0EwYm>EJRJNk)DKIY5M!p+Q(2;?YE{^iTX!o|!`3g&3~i7;|bw(~P) z)#6}ghz5$uY(C*%WuiA%r6Zz{V$&?8#fChCm%PH2p5pn;wqc|fBw%opYK0= zW@&G4&#uG?6m{dAY|t+o!Nvhn$-(8w&dSKy$i!Y(e($IJ@7s6(zTNinUB&if$?EOD zciKNVKmT_5{7-W%?RWpb^Cx@LzVDIugBcS(-CuY2?z_J)-q*i>|DBne4QP}Eo7eoj z`tACYxAq$`zxvNU`BlFZ(79U~_9W4gQP#r}z)^@8RDHbnbNi&0FXD pE6DKh9|0-)#{ZH34gYigNBnnz3NQ1Y<3GuN1gLBY|Kua*Z2 bool: - # expr = self.build_select_expression() - # expr = sqlalchemy.exists(expr).select() - # return await self.database.fetch_val(expr) + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) def limit(self, limit_count: int): return self.__class__( @@ -196,6 +202,14 @@ class QuerySet: offset=offset ) + async def first(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + async def get(self, **kwargs): if kwargs: return await self.filter(**kwargs).get() @@ -287,7 +301,6 @@ class ModelMetaclass(type): attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) - attrs['__model_fields__'] = model_fields new_model = super().__new__( # type: ignore @@ -297,7 +310,7 @@ class ModelMetaclass(type): return new_model -class Model(metaclass=ModelMetaclass): +class Model(tuple, metaclass=ModelMetaclass): __abstract__ = True objects = QuerySet() @@ -338,9 +351,11 @@ class Model(metaclass=ModelMetaclass): except TypeError: # pragma no cover pass return item - return super().__getattribute__(key) + def __eq__(self, other): + return self.values.dict() == other.values.dict() + def __repr__(self): # pragma no cover return self.values.__repr__() @@ -380,6 +395,18 @@ class Model(metaclass=ModelMetaclass): return cls(**item) + @classmethod + def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover + return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value) + + @classmethod + def __get_validators__(cls): # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def schema(cls, by_alias: bool = True): # pragma no cover + return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias) + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json diff --git a/orm/relations.py b/orm/relations.py index 1bd70cf..8e5f3dd 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,6 +1,6 @@ from typing import Dict, Union, List -from sqlalchemy import text +from orm.exceptions import RelationshipNotFound class Relationship: @@ -41,5 +41,4 @@ class RelationshipManager: if rel == name: if relations and relations[0].fk_side == 'parent': return relations[0].child - else: - return [rela.child for rela in relations] + return [rela.child for rela in relations] diff --git a/requirements.txt b/requirements.txt index f6280a7..db5ae51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ sqlalchemy pytest pytest-cov codecov -pytest-asyncio \ No newline at end of file +pytest-asyncio +fastapi \ No newline at end of file diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py new file mode 100644 index 0000000..ae647a0 --- /dev/null +++ b/tests/test_fastapi_usage.py @@ -0,0 +1,42 @@ +import json +from typing import Optional + +import databases +import pydantic +import sqlalchemy +from fastapi import FastAPI +from fastapi.testclient import TestClient + +app = FastAPI() + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class Item(orm.Model): + __tablename__ = "users" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + return item + + +client = TestClient(app) + + +def test_read_main(): + response = client.post("/items/", json={'name': 'test', 'id': 1}) + print(response.json()) + assert response.status_code == 200 + assert response.json() == {'name': 'test', 'id': 1} + item = Item(**response.json()) + assert item.id == 1 diff --git a/tests/test_models.py b/tests/test_models.py index e69de29..cf12c3e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -0,0 +1,200 @@ +import databases +import pytest +import sqlalchemy + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class User(orm.Model): + __tablename__ = "users" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Product(orm.Model): + __tablename__ = "product" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + rating = orm.Integer(minimum=1, maximum=5) + in_stock = orm.Boolean(default=False) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +def test_model_class(): + assert list(User.__model_fields__.keys()) == ["id", "name"] + assert isinstance(User.__model_fields__["id"], orm.Integer) + assert User.__model_fields__["id"].primary_key is True + assert isinstance(User.__model_fields__["name"], orm.String) + assert User.__model_fields__["name"].length == 100 + assert isinstance(User.__table__, sqlalchemy.Table) + + +def test_model_pk(): + user = User(pk=1) + assert user.pk == 1 + assert user.id == 1 + + +@pytest.mark.asyncio +async def test_model_crud(): + async with database: + users = await User.objects.all() + assert users == [] + + user = await User.objects.create(name="Tom") + users = await User.objects.all() + assert user.name == "Tom" + assert user.pk is not None + assert users == [user] + + lookup = await User.objects.get() + assert lookup == user + + await user.update(name="Jane") + users = await User.objects.all() + assert user.name == "Jane" + assert user.pk is not None + assert users == [user] + + await user.delete() + users = await User.objects.all() + assert users == [] + + +@pytest.mark.asyncio +async def test_model_get(): + async with database: + with pytest.raises(orm.NoMatch): + await User.objects.get() + + user = await User.objects.create(name="Tom") + lookup = await User.objects.get() + assert lookup == user + + user = await User.objects.create(name="Jane") + with pytest.raises(orm.MultipleMatches): + await User.objects.get() + + same_user = await User.objects.get(pk=user.id) + assert same_user.id == user.id + assert same_user.pk == user.pk + + +@pytest.mark.asyncio +async def test_model_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + user = await User.objects.get(name="Lucy") + assert user.name == "Lucy" + + with pytest.raises(orm.NoMatch): + await User.objects.get(name="Jim") + + await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) + await Product.objects.create(name="Dress", rating=4) + await Product.objects.create(name="Coat", rating=3, in_stock=True) + + product = await Product.objects.get(name__iexact="t-shirt", rating=5) + assert product.pk is not None + assert product.name == "T-Shirt" + assert product.rating == 5 + + products = await Product.objects.all(rating__gte=2, in_stock=True) + assert len(products) == 2 + + products = await Product.objects.all(name__icontains="T") + assert len(products) == 2 + + # Test escaping % character from icontains, contains, and iexact + await Product.objects.create(name="100%-Cotton", rating=3) + await Product.objects.create(name="Cotton-100%-Egyptian", rating=3) + await Product.objects.create(name="Cotton-100%", rating=3) + products = Product.objects.filter(name__iexact="100%-cotton") + assert await products.count() == 1 + + products = Product.objects.filter(name__contains="%") + assert await products.count() == 3 + + products = Product.objects.filter(name__icontains="%") + assert await products.count() == 3 + + +@pytest.mark.asyncio +async def test_model_exists(): + async with database: + await User.objects.create(name="Tom") + assert await User.objects.filter(name="Tom").exists() is True + assert await User.objects.filter(name="Jane").exists() is False + + +@pytest.mark.asyncio +async def test_model_count(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert await User.objects.count() == 3 + assert await User.objects.filter(name__icontains="T").count() == 1 + + +@pytest.mark.asyncio +async def test_model_limit(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + await User.objects.create(name="Lucy") + + assert len(await User.objects.limit(2).all()) == 2 + + +@pytest.mark.asyncio +async def test_model_limit_with_filter(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + await User.objects.create(name="Tom") + + assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2 + + +@pytest.mark.asyncio +async def test_offset(): + async with database: + await User.objects.create(name="Tom") + await User.objects.create(name="Jane") + + users = await User.objects.offset(1).limit(1).all() + assert users[0].name == 'Jane' + + +@pytest.mark.asyncio +async def test_model_first(): + async with database: + tom = await User.objects.create(name="Tom") + jane = await User.objects.create(name="Jane") + + assert await User.objects.first() == tom + assert await User.objects.first(name="Jane") == jane + assert await User.objects.filter(name="Jane").first() == jane + assert await User.objects.filter(name="Lucy").first() is None From 6fa7c65b8d1ec941d186bf7322b22715f89dcd2a Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 4 Aug 2020 21:46:13 +0200 Subject: [PATCH 21/62] fix parsing related models from dictionaries --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 2 ++ orm/models.py | 2 +- tests/test_fastapi_usage.py | 15 ++++++++++++--- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/.coverage b/.coverage index 997d62288b92e3b8a186eb775b7686e374522c00..09d9c48be0bcd1f392132dd5d9f0975ec6f294c3 100644 GIT binary patch delta 115 zcmV-(0F3{DpaX!Q1F$tO1vE7}H8Qh0FU(LH@Bk0_59$x#555ny52p`d4_OaO4<`>1 z4)C)P5YP^j@{U^$3<(4Q2^tCZwR!W-$q$oSk7Xht1Ox#ILIiHJf8Fiw-S_{K&E9M_ Vo6Tmk*=+Wk@9w_-0khzbAV7AKFp>ZO delta 113 zcmV-%0FM8FpaX!Q1F$tO1v54}H8Zn1FU(LH@&FI{59$x#55Et!52z1f4_XgR4=E25 z4)U`R5Yi5l@Qzyz3kd`P2^b0XwRv;$2a{QkWg#B~1OW*^1a7l`-R None: + def __init__(self, *args, **kwargs) -> None: self._orm_id = uuid.uuid4().hex self._orm_saved = False self._orm_relationship_manager = RelationshipManager(self) diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index ae647a0..67a4c5f 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -16,6 +16,15 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class Category(orm.Model): + __tablename__ = "cateries" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + class Item(orm.Model): __tablename__ = "users" __metadata__ = metadata @@ -23,6 +32,7 @@ class Item(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) + category = orm.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) @@ -34,9 +44,8 @@ client = TestClient(app) def test_read_main(): - response = client.post("/items/", json={'name': 'test', 'id': 1}) - print(response.json()) + response = client.post("/items/", json={'name': 'test', 'id': 1, 'category': {'name': 'test cat'}}) assert response.status_code == 200 - assert response.json() == {'name': 'test', 'id': 1} + assert response.json() == {'category': {'id': None, 'name': 'test cat'}, 'id': 1, 'name': 'test'} item = Item(**response.json()) assert item.id == 1 From a371c489596033043989e5b1a21c85190600f976 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 4 Aug 2020 21:48:37 +0200 Subject: [PATCH 22/62] switch hack to list instead of tuple --- .coverage | Bin 53248 -> 53248 bytes orm/models.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/.coverage b/.coverage index 09d9c48be0bcd1f392132dd5d9f0975ec6f294c3..e2e4542df14cf7fb1474b518ee080ef920a875a1 100644 GIT binary patch delta 19 acmZozz}&Ead4rigi-nbeZ~V(Pk_CV+#O5p$51B diff --git a/orm/models.py b/orm/models.py index 05bf654..ecc0937 100644 --- a/orm/models.py +++ b/orm/models.py @@ -310,7 +310,7 @@ class ModelMetaclass(type): return new_model -class Model(tuple, metaclass=ModelMetaclass): +class Model(list, metaclass=ModelMetaclass): __abstract__ = True objects = QuerySet() From 475dafb6c9776c8e01cd8d2b42a31c09bc57a602 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 5 Aug 2020 18:32:13 +0200 Subject: [PATCH 23/62] rebuild the registry of relationships --- .coverage | Bin 53248 -> 53248 bytes orm/exceptions.py | 2 +- orm/fields.py | 45 ++++++++++++++----- orm/models.py | 72 +++++++++++++++++++----------- orm/relations.py | 72 +++++++++++++++--------------- tests/test_fastapi_usage.py | 8 +--- tests/test_foreign_keys.py | 9 +++- tests/test_same_table_joins.py | 77 +++++++++++++++++++++++++++++++++ 8 files changed, 204 insertions(+), 81 deletions(-) create mode 100644 tests/test_same_table_joins.py diff --git a/.coverage b/.coverage index e2e4542df14cf7fb1474b518ee080ef920a875a1..8cfa156b16a6247a46fef9e8643a14aac1b27338 100644 GIT binary patch delta 590 zcmZvaUr19?9LMkN?A~q4Ip5jCh#usM<;n<+m|L?y41#=7p?TDCb?vkq3fZS z9`xaY4^5a4#d-+IE)iL=hee>UN}N#0b|pxmbSowsvF18kuf6>~zjMCd4+l6uI^FKF+cj6)ibY{`Oh5{ca0i#s!uHu3i?N&REF&~UBXmq_u=G&L#)#0T zx_hkZ`)(#VC=tdvH6b{7V`$*k;MI|#f#K2X0qs*a5vAisB6w?;t-ThhMJ9r$Mp<6e zDiNW-M(*AOmj?w*zy=QuvIs{G+N-K0*3xq`3zYN+V?(0)FE=Ht`l3hT%gWy3h>DcGw0NzG4la2GOk0 zKz=JZk;&(CPW9)sLvF6)+b;FwOyktudNZF7nKKU^xw${Gi32m$)f=lzsB?BP^g3F~XQCHEFLUR=oe55MjP6IS&FclPu1VqL zsm^3F8z@fh730NXNi<2k^`PYU`@kO=<-RE>Xh{{=Kb(FFhi delta 473 zcmX|+Ur1A77>Cd4?412+-}jqcNQBH0v87%(MY4t^izq0jP!tR?@n+rF!h|4qaqO}n zg}zG(g^C5r2u|ij;w~11g0tAM5<4bcDAYVEi!)^5>1!9>{hpVH2i}oMX=GA*6K?jj z1zKAoZD%9lb6TW%(dQ>*N`Q`+7{+x3X@wSOnhw$qDiNJbX*IsbL=F*2xXGejwd`q( zGCgJ=2|Ay32jvDu65`dn+2xZPWl89%QnWkAyVC+v2;l~qraJwL75W;l=zV&gI>-v| zJ|S$V_w)knC)f;^JXTX_9)Xq(%F3fjTeU6|j$A*uXsAB8`5;agE1o=q^uv zqO3UDuAKBZME7OZFy>diZaz=At(k|V`~nmgC+hNCa>}Cnk;{2t&Xqc?+4q&Ty9Hw} z`LWeXmWR~!N@>fqjC9AH{Ic2EZpXfqFZ@ou6TzO zsm|8BWDdvVL?RJz@O9CV)Kp`;)-=jCTUsADGrpA2RtK!{9drIbWE*3r>b%GZsNfH_ zv56n}hR;|;4s)16hW|1OEi-W|;3vrb*DUOzh%Nkt!PNrtSilElx%eC`yBz%o>oc(- diff --git a/orm/exceptions.py b/orm/exceptions.py index 1a8c6d0..cb2100e 100644 --- a/orm/exceptions.py +++ b/orm/exceptions.py @@ -18,5 +18,5 @@ class MultipleMatches(AsyncOrmException): pass -class RelationshipNotFound(AsyncOrmException): +class RelationshipInstanceError(AsyncOrmException): pass diff --git a/orm/fields.py b/orm/fields.py index 5a28ecf..58fb4da 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -2,11 +2,11 @@ import datetime import decimal from typing import Optional, List -import pydantic import sqlalchemy +from pydantic import Json +from pydantic.fields import ModelField -from orm.exceptions import ModelDefinitionError -from orm.relations import Relationship +from orm.exceptions import ModelDefinitionError, RelationshipInstanceError class BaseField: @@ -79,7 +79,7 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] - def expand_relationship(self, value, parent): + def expand_relationship(self, value, child): return value @@ -145,7 +145,7 @@ class Time(BaseField): class JSON(BaseField): - __type__ = pydantic.Json + __type__ = Json def get_column_type(self): return sqlalchemy.JSON() @@ -173,8 +173,9 @@ class Decimal(BaseField): class ForeignKey(BaseField): - def __init__(self, to, related_name: str = None, nullable: bool = False): + def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False): super().__init__(nullable=nullable) + self.virtual = virtual self.related_name = related_name self.to = to @@ -191,6 +192,9 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship(self, value, child): + if not isinstance(value, (self.to, dict, int, str)): + raise RelationshipInstanceError( + 'Relationship model can be build only from orm.Model, dict and integer or string (pk).') if isinstance(value, self.to): model = value elif isinstance(value, dict): @@ -199,10 +203,27 @@ class ForeignKey(BaseField): model = self.to(**{self.to.__pkname__: value}) child_model_name = self.related_name or child.__class__.__name__.lower() + 's' - model._orm_relationship_manager.add( - Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) - model.__fields__[child_model_name] = pydantic.fields.ModelField(name=child_model_name, - type_=child.__pydantic_model__, - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__) + model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), + child.__class__.__name__.lower(), + model, child, virtual=self.virtual) + + if child_model_name not in model.__fields__: + model.__fields__[child_model_name] = ModelField(name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__) + model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) + return model + + # def register_relationship(self): + # child_model_name = self.related_name or child.__class__.__name__.lower() + 's' + # if not child_model_name in model._orm_relationship_manager: + # model._orm_relationship_manager.add( + # Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) + # model.__fields__[child_model_name] = ModelField(name=child_model_name, + # type_=Optional[child.__pydantic_model__], + # model_config=child.__pydantic_model__.__config__, + # class_validators=child.__pydantic_model__.__validators__) + # model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) + # breakpoint() diff --git a/orm/models.py b/orm/models.py index ecc0937..0d93ee1 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,18 +2,19 @@ import copy import inspect import json import uuid -from abc import ABCMeta -from typing import Any, List, Type +from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar from typing import Set, Dict import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model +from pydantic import BaseModel, BaseConfig, create_model -from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch -from orm.fields import BaseField +from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches +from orm.fields import BaseField, ForeignKey from orm.relations import RelationshipManager +relationship_manager = RelationshipManager() + def parse_pydantic_field_from_model_fields(object_dict: dict): pydantic_fields = {field_name: ( @@ -25,6 +26,24 @@ def parse_pydantic_field_from_model_fields(object_dict: dict): return pydantic_fields +def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): + pkname = None + columns: List[sqlalchemy.Column] = [] + model_fields: Dict[str, BaseField] = {} + + for field_name, field in object_dict.items(): + if isinstance(field, BaseField): + model_fields[field_name] = field + if not field.pydantic_only: + if field.primary_key: + pkname = field_name + if isinstance(field, ForeignKey): + reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' + relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name) + columns.append(field.get_column(field_name)) + return pkname, columns, model_fields + + FILTER_OPERATORS = { "exact": "__eq__", "iexact": "ilike", @@ -272,19 +291,9 @@ class ModelMetaclass(type): tablename = attrs["__tablename__"] metadata = attrs["__metadata__"] - pkname = None - - columns = [] - model_fields = {} - for field_name, field in attrs.items(): - if isinstance(field, BaseField): - model_fields[field_name] = field - if not field.pydantic_only: - if field.primary_key: - pkname = field_name - columns.append(field.get_column(field_name)) # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs) attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__columns__'] = columns attrs['__pkname__'] = pkname @@ -311,18 +320,28 @@ class ModelMetaclass(type): class Model(list, metaclass=ModelMetaclass): + # Model inherits from list in order to be treated as request.Body parameter in fastapi routes, + # inheriting from pydantic.BaseModel causes metaclass conflicts __abstract__ = True + if TYPE_CHECKING: # pragma no cover + __model_fields__: Dict[str, TypeVar[BaseField]] + __table__: sqlalchemy.Table + __fields__: Dict[str, pydantic.fields.ModelField] + __pydantic_model__: Type[BaseModel] + __pkname__: str objects = QuerySet() def __init__(self, *args, **kwargs) -> None: - self._orm_id = uuid.uuid4().hex - self._orm_saved = False - self._orm_relationship_manager = RelationshipManager(self) - self._orm_observers = [] + self._orm_id: str = uuid.uuid4().hex + self._orm_saved: bool = False + self._orm_relationship_manager: RelationshipManager = relationship_manager + self._orm_observers: List['Model'] = [] + self.values: Optional[BaseModel] = None if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") + # breakpoint() kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} self.values = self.__pydantic_model__(**kwargs) @@ -340,9 +359,9 @@ class Model(list, metaclass=ModelMetaclass): def __getattribute__(self, key: str) -> Any: if key != '__fields__' and key in self.__fields__: - if key in self._orm_relationship_manager: - parent_item = self._orm_relationship_manager.get(key) - return parent_item + relation_key = self.__class__.__name__.title() + '_' + key + if self._orm_relationship_manager.contains(relation_key, self): + return self._orm_relationship_manager.get(relation_key, self) item = getattr(self.values, key, None) if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): @@ -393,11 +412,12 @@ class Model(list, metaclass=ModelMetaclass): if column.name not in item: item[column.name] = row[column] + # breakpoint() return cls(**item) @classmethod - def validate(cls: Type['Model'], value: Any) -> 'Model': # pragma no cover - return cls.__pydantic_model__.validate(cls.__pydantic_model__.__class__, value) + def validate(cls, value: Any) -> 'BaseModel': # pragma no cover + return cls.__pydantic_model__.validate(value=value) @classmethod def __get_validators__(cls): # pragma no cover @@ -405,7 +425,7 @@ class Model(list, metaclass=ModelMetaclass): @classmethod def schema(cls, by_alias: bool = True): # pragma no cover - return cls.__pydantic_model__.schame(cls.__pydantic_model__, by_alias=by_alias) + return cls.__pydantic_model__.schema(by_alias=by_alias) def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json diff --git a/orm/relations.py b/orm/relations.py index 8e5f3dd..583f158 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,44 +1,46 @@ -from typing import Dict, Union, List +from typing import TYPE_CHECKING -from orm.exceptions import RelationshipNotFound - - -class Relationship: - - def __init__(self, name: str, parent: 'Model', child: 'Model', fk_side: str = 'child'): - self.fk_side = fk_side - self.child = child - self.parent = parent - self.name = name +if TYPE_CHECKING: # pragma no cover + from orm.models import Model class RelationshipManager: - def __init__(self, model: 'Model'): - self._orm_id: str = model._orm_id - self._relations: Dict[str, Union[Relationship, List[Relationship]]] = dict() + def __init__(self): + self._relations = dict() - def __contains__(self, item): - return item in self._relations + def add_relation_type(self, relations_key, reverse_key): + print(relations_key, reverse_key) + if relations_key not in self._relations: + self._relations[relations_key] = {'type': 'primary'} + if reverse_key not in self._relations: + self._relations[reverse_key] = {'type': 'reverse'} - def add_related(self, relation: Relationship): - if relation.fk_side == 'child' and relation.parent._orm_id == self._orm_id: - new_relation = Relationship(name=relation.parent.__class__.__name__.lower(), - child=relation.parent, - parent=relation.child, - fk_side='parent') - relation.child._orm_relationship_manager.add(new_relation) + def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): + parent_id = parent._orm_id + child_id = child._orm_id + if virtual: + child_name, parent_name = parent_name, child_name + child_id, parent_id = parent_id, child_id + child, parent = parent, child + self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append( + child) + self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) - def add(self, relation: Relationship): - if relation.name in self._relations: - self._relations[relation.name].append(relation) - else: - self._relations[relation.name] = [relation] - self.add_related(relation) + def contains(self, relations_key: str, object: 'Model'): + if relations_key in self._relations: + return object._orm_id in self._relations[relations_key] + return False - def get(self, name: str): - for rel, relations in self._relations.items(): - if rel == name: - if relations and relations[0].fk_side == 'parent': - return relations[0].child - return [rela.child for rela in relations] + def get(self, relations_key: str, object: 'Model'): + if relations_key in self._relations: + if object._orm_id in self._relations[relations_key]: + if self._relations[relations_key]['type'] == 'primary': + return self._relations[relations_key][object._orm_id][0] + return self._relations[relations_key][object._orm_id] + + def __str__(self): # pragma no cover + return ''.join(self._relations[rel].__str__() for rel in self._relations) + + def __repr__(self): # pragma no cover + return self.__str__() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 67a4c5f..00c0672 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -1,17 +1,13 @@ -import json -from typing import Optional - import databases -import pydantic import sqlalchemy from fastapi import FastAPI from fastapi.testclient import TestClient -app = FastAPI() - import orm from tests.settings import DATABASE_URL +app = FastAPI() + database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index d3d95bf..c222cfa 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -3,7 +3,7 @@ import pytest import sqlalchemy import orm -from orm.exceptions import NoMatch, MultipleMatches +from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -229,3 +229,10 @@ async def test_get_exceptions(): await Track.objects.create(album=fantasies, title="Test3", position=3) with pytest.raises(MultipleMatches): await Track.objects.select_related("album").get(album=fantasies) + + +@pytest.mark.asyncio +async def test_wrong_model_passed_as_fk(): + with pytest.raises(RelationshipInstanceError): + org = await Organisation.objects.create(ident="ACME Ltd") + await Track.objects.create(album=org, title="Test1", position=1) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py new file mode 100644 index 0000000..cadbfc8 --- /dev/null +++ b/tests/test_same_table_joins.py @@ -0,0 +1,77 @@ +import databases +import pytest +import sqlalchemy + +import orm +from tests.settings import DATABASE_URL + +database = databases.Database(DATABASE_URL, force_rollback=True) +metadata = sqlalchemy.MetaData() + + +class SchoolClass(orm.Model): + __tablename__ = "schoolclasses" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Category(orm.Model): + __tablename__ = "cateogories" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Student(orm.Model): + __tablename__ = "students" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + schoolclass = orm.ForeignKey(SchoolClass) + category = orm.ForeignKey(Category, nullable=True) + + +class Teacher(orm.Model): + __tablename__ = "teachers" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + schoolclass = orm.ForeignKey(SchoolClass) + category = orm.ForeignKey(Category, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(): + async with database: + class1 = await SchoolClass.objects.create(name="Math") + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' + + # related fields of main model are only populated by pk + # but you can load them anytime + assert classes[0].students[0].schoolclass.name is None + await classes[0].students[0].schoolclass.load() + assert classes[0].students[0].schoolclass.name == 'Math' From 6efb56a2a0c63afcc2e610bf3153e5874f133320 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 05:37:10 +0200 Subject: [PATCH 24/62] changed relationshipt to wekrefs --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 12 -- orm/models.py | 300 +++------------------------------ orm/queryset.py | 242 ++++++++++++++++++++++++++ orm/relations.py | 44 ++++- tests/test_fastapi_usage.py | 4 +- tests/test_same_table_joins.py | 18 +- 7 files changed, 324 insertions(+), 296 deletions(-) create mode 100644 orm/queryset.py diff --git a/.coverage b/.coverage index 8cfa156b16a6247a46fef9e8643a14aac1b27338..ad89c9e7f4286dd936db7dc6b3d5e305f71bfba6 100644 GIT binary patch delta 702 zcmZ9KUr19?9LMkN-0R)F+xh+GpCrL@{vrPaDmc-ET)o&#yjT(dwa z2p{f4k)nGjB!viuiJ7D5B@7#ZLM5pL?IGAkEs!;Boom!XPv?BTzwbGG4xBL`8S{}R z^bG+Bv5 zFO-YMO5)Uax+sduG18_}wwYNBUg_)6kIE&aRS*9d8;VKGKK*BWQO>mh1UZc3SKP*B z^y6*p#Aa;3TFgQMpCQP6=iqEA2)tYXeX-E0do%VeTAe(SLs@i^vgm}(I~Cm+4lKns zB96dn(CLYPnTYZkCRQjhUB8>0j!(w6LZRa}n@!3$vea(WO2%`%3fmr1Ht$zOCwEh1 z1|TbFV5`O8IT&k5jCyxGa)HExGt$*~PGaNJUifJ^a%16gmU|(wXp8Q}UBih;^HhaY zn8I33DW0A0Z==b@-QLaJiB4ZM><{$3@h)v&8eVXFy~*XR9o`_aVw%|9-JPoi#_KGA zJ`NN31H%}?Eqssfa0OrCbDYP=ID5L7dm57l#kvJR6RqNZ6O7 I+OuHYZ`F+ILI3~& delta 614 zcmY+9Ur1A77{<@x`_BHH?fc$oMpCI0Nkc&mTh2eVE~4z>9H9`}g+`!5BRAW!VZx>t z-E^VQ2MRejyC~L$RCb8S41>BbMIxFIh@Ng*6r~pvhpc9wi!Qo*p5OcM@Vp~&G!jQo z>g!~e^O(E2p}yJW(wplR9ZpJu0wm!++=2^GNptigP0;JKl_HWPQzWPtI8-84Q6zM* zX!}9v?YTD46AaZ33^fuY46?^cIh$09L_3>SYO(Dvs@!+EyT{k(@4L|#@bhScjD-Gt z)U7(O6$yd7=~oLx3yZ2X*o*}4UOih-?Rri%iF{W`fMKYDizGyMX%bfGLwG=g^d!Y( zg~a%-cEZkg9>vs+2H9v=3A?KlV^WK*v*2$Kx?q$egyx%sE~bwQ^Z( zUBB138T)0ja6>J!&AbQAceb-e_Q*MtBw9zQDH0y~`dizbnO|Eo{8Jla;g*S+@JlV5 zUe`R~=bNX#G=(B30=cQ~xUuaubZO*Wv}JjD!=H(4W#%)PtZ0#VYoX-x`N}a@XYj_; z=o(uXZ!IN|5FiU%@DoyCz$f?si|_{K;2AuE1po3ju(&4a)7m8`qA<(MuS+>o+&)KdTe diff --git a/orm/fields.py b/orm/fields.py index 58fb4da..6f9c221 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -215,15 +215,3 @@ class ForeignKey(BaseField): model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) return model - - # def register_relationship(self): - # child_model_name = self.related_name or child.__class__.__name__.lower() + 's' - # if not child_model_name in model._orm_relationship_manager: - # model._orm_relationship_manager.add( - # Relationship(name=child_model_name, child=child, parent=model, fk_side='child')) - # model.__fields__[child_model_name] = ModelField(name=child_model_name, - # type_=Optional[child.__pydantic_model__], - # model_config=child.__pydantic_model__.__config__, - # class_validators=child.__pydantic_model__.__validators__) - # model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) - # breakpoint() diff --git a/orm/models.py b/orm/models.py index 0d93ee1..508ba00 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,21 +2,22 @@ import copy import inspect import json import uuid -from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar +from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple from typing import Set, Dict import pydantic import sqlalchemy from pydantic import BaseModel, BaseConfig, create_model -from orm.exceptions import ModelDefinitionError, NoMatch, MultipleMatches +import orm.queryset as qry +from orm.exceptions import ModelDefinitionError from orm.fields import BaseField, ForeignKey from orm.relations import RelationshipManager relationship_manager = RelationshipManager() -def parse_pydantic_field_from_model_fields(object_dict: dict): +def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: pydantic_fields = {field_name: ( base_field.__type__, ... if base_field.is_required else base_field.default_value @@ -26,8 +27,10 @@ def parse_pydantic_field_from_model_fields(object_dict: dict): return pydantic_fields -def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): - pkname = None +def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str], + List[sqlalchemy.Column], + Dict[str, BaseField]]: + pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] model_fields: Dict[str, BaseField] = {} @@ -39,243 +42,17 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict): pkname = field_name if isinstance(field, ForeignKey): reverse_name = field.related_name or field.to.__name__.title() + '_' + name.lower() + 's' - relationship_manager.add_relation_type(name + '_' + field.to.__name__.lower(), reverse_name) + relation_name = name + '_' + field.to.__name__.lower() + relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) columns.append(field.get_column(field_name)) return pkname, columns, model_fields -FILTER_OPERATORS = { - "exact": "__eq__", - "iexact": "ilike", - "contains": "like", - "icontains": "ilike", - "in": "in_", - "gt": "__gt__", - "gte": "__ge__", - "lt": "__lt__", - "lte": "__le__", -} +def get_pydantic_base_orm_config() -> Type[BaseConfig]: + class Config(BaseConfig): + orm_mode = True - -class QuerySet: - ESCAPE_CHARACTERS = ['%', '_'] - - def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, - limit_count: int = None, offset: int = None): - self.model_cls = model_cls - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self._select_related = [] if select_related is None else select_related - self.limit_count = limit_count - self.query_offset = offset - - def __get__(self, instance, owner): - return self.__class__(model_cls=owner) - - @property - def database(self): - return self.model_cls.__database__ - - @property - def table(self): - return self.model_cls.__table__ - - def build_select_expression(self): - tables = [self.table] - select_from = self.table - - for item in self._select_related: - model_cls = self.model_cls - select_from = self.table - for part in item.split("__"): - model_cls = model_cls.__model_fields__[part].to - select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) - tables.append(model_cls.__table__) - - expr = sqlalchemy.sql.select(tables) - expr = expr.select_from(select_from) - - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - # print(expr.compile(compile_kwargs={"literal_binds": True})) - return expr - - def filter(self, **kwargs): - filter_clauses = self.filter_clauses - select_related = list(self._select_related) - - if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ - kwargs[pk_name] = kwargs.pop("pk") - - for key, value in kwargs.items(): - if "__" in key: - parts = key.split("__") - - # Determine if we should treat the final part as a - # filter operator or as a related field. - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - model_cls = self.model_cls - if related_parts: - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - for part in related_parts: - model_cls = model_cls.__model_fields__[part].to - - column = model_cls.__table__.columns[field_name] - - else: - op = "exact" - column = self.table.columns[key] - - # Map the operation code onto SQLAlchemy's ColumnElement - # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement - op_attr = FILTER_OPERATORS[op] - has_escaped_character = False - - if op in ["contains", "icontains"]: - has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS - if c in value) - if has_escaped_character: - # enable escape modifier - for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f'\\{char}') - value = f"%{value}%" - - if isinstance(value, Model): - value = value.pk - - clause = getattr(column, op_attr)(value) - clause.modifiers['escape'] = '\\' if has_escaped_character else None - filter_clauses.append(clause) - - return self.__class__( - model_cls=self.model_cls, - filter_clauses=filter_clauses, - select_related=select_related, - limit_count=self.limit_count, - offset=self.query_offset - ) - - def select_related(self, related): - if not isinstance(related, (list, tuple)): - related = [related] - - related = list(self._select_related) + related - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=related, - limit_count=self.limit_count, - offset=self.query_offset - ) - - async def exists(self) -> bool: - expr = self.build_select_expression() - expr = sqlalchemy.exists(expr).select() - return await self.database.fetch_val(expr) - - async def count(self) -> int: - expr = self.build_select_expression().alias("subquery_for_count") - expr = sqlalchemy.func.count().select().select_from(expr) - return await self.database.fetch_val(expr) - - def limit(self, limit_count: int): - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=limit_count, - offset=self.query_offset - ) - - def offset(self, offset: int): - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=offset - ) - - async def first(self, **kwargs): - if kwargs: - return await self.filter(**kwargs).first() - - rows = await self.limit(1).all() - if rows: - return rows[0] - - async def get(self, **kwargs): - if kwargs: - return await self.filter(**kwargs).get() - - expr = self.build_select_expression().limit(2) - rows = await self.database.fetch_all(expr) - - if not rows: - raise NoMatch() - if len(rows) > 1: - raise MultipleMatches() - return self.model_cls.from_row(rows[0], select_related=self._select_related) - - async def all(self, **kwargs): - if kwargs: - return await self.filter(**kwargs).all() - - expr = self.build_select_expression() - rows = await self.database.fetch_all(expr) - return [ - self.model_cls.from_row(row, select_related=self._select_related) - for row in rows - ] - - async def create(self, **kwargs): - - new_kwargs = dict(**kwargs) - - # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] - if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): - del new_kwargs[pkname] - - # substitute related models with their pk - for field in self.model_cls.extract_related_names(): - if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) - - # Build the insert expression. - expr = self.table.insert() - expr = expr.values(**new_kwargs) - - # Execute the insert, and return a new model instance. - instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) - return instance + return Config class ModelMetaclass(type): @@ -293,7 +70,7 @@ class ModelMetaclass(type): metadata = attrs["__metadata__"] # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs) + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename) attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) attrs['__columns__'] = columns attrs['__pkname__'] = pkname @@ -303,8 +80,7 @@ class ModelMetaclass(type): # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - config = type('Config', (BaseConfig,), {'orm_mode': True}) - pydantic_model = create_model(name, __config__=config, **pydantic_fields) + pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields) attrs['__pydantic_fields__'] = pydantic_fields attrs['__pydantic_model__'] = pydantic_model attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) @@ -330,21 +106,22 @@ class Model(list, metaclass=ModelMetaclass): __pydantic_model__: Type[BaseModel] __pkname__: str - objects = QuerySet() + objects = qry.QuerySet() def __init__(self, *args, **kwargs) -> None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False self._orm_relationship_manager: RelationshipManager = relationship_manager - self._orm_observers: List['Model'] = [] self.values: Optional[BaseModel] = None if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") - # breakpoint() kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} self.values = self.__pydantic_model__(**kwargs) + def __del__(self): + self._orm_relationship_manager.deregister(self) + def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: if self.is_conversion_to_json_needed(key) and not isinstance(value, str): @@ -378,23 +155,6 @@ class Model(list, metaclass=ModelMetaclass): def __repr__(self): # pragma no cover return self.values.__repr__() - # def attach(self, observer: 'Model'): - # if all([obs._orm_id != observer._orm_id for obs in self._orm_observers]): - # self._orm_observers.append(observer) - # - # def detach(self, observer: 'Model'): - # for ind, obs in enumerate(self._orm_observers): - # if obs._orm_id == observer._orm_id: - # del self._orm_observers[ind] - # break - # - def notify(self): - for obs in self._orm_observers: # pragma no cover - obs.orm_update(self) - - def orm_update(self, subject: 'Model') -> None: # pragma no cover - print('should be updated here') - @classmethod def from_row(cls, row, select_related: List = None) -> 'Model': item = {} @@ -412,20 +172,19 @@ class Model(list, metaclass=ModelMetaclass): if column.name not in item: item[column.name] = row[column] - # breakpoint() return cls(**item) - @classmethod - def validate(cls, value: Any) -> 'BaseModel': # pragma no cover - return cls.__pydantic_model__.validate(value=value) + # @classmethod + # def validate(cls, value: Any) -> 'BaseModel': # pragma no cover + # return cls.__pydantic_model__.validate(value=value) @classmethod def __get_validators__(cls): # pragma no cover yield cls.__pydantic_model__.validate - @classmethod - def schema(cls, by_alias: bool = True): # pragma no cover - return cls.__pydantic_model__.schema(by_alias=by_alias) + # @classmethod + # def schema(cls, by_alias: bool = True): # pragma no cover + # return cls.__pydantic_model__.schema(by_alias=by_alias) def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -460,9 +219,6 @@ class Model(list, metaclass=ModelMetaclass): for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): related_names.add(name) - # elif field.sub_fields and any( - # [inspect.isclass(f.type_) and issubclass(f.type_, pydantic.BaseModel) for f in field.sub_fields]): - # related_names.add(name) return related_names def extract_model_db_fields(self) -> Dict: @@ -481,7 +237,6 @@ class Model(list, metaclass=ModelMetaclass): expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) setattr(self, 'pk', item_id) - self.notify() return item_id async def update(self, **kwargs: Any) -> int: @@ -494,19 +249,16 @@ class Model(list, metaclass=ModelMetaclass): expr = self.__table__.update().values(**self_fields).where( self.pk_column == getattr(self, self.__pkname__)) result = await self.__database__.execute(expr) - self.notify() return result async def delete(self) -> int: expr = self.__table__.delete() expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) result = await self.__database__.execute(expr) - self.notify() return result async def load(self) -> 'Model': expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) - self.notify() return self diff --git a/orm/queryset.py b/orm/queryset.py new file mode 100644 index 0000000..d0d6b2c --- /dev/null +++ b/orm/queryset.py @@ -0,0 +1,242 @@ +from typing import List, TYPE_CHECKING + +import sqlalchemy + +import orm +from orm.exceptions import NoMatch, MultipleMatches + +if TYPE_CHECKING: # pragma no cover + from orm.models import Model + +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} + + +class QuerySet: + ESCAPE_CHARACTERS = ['%', '_'] + + def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None, + limit_count: int = None, offset: int = None): + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + + def __get__(self, instance, owner): + return self.__class__(model_cls=owner) + + @property + def database(self): + return self.model_cls.__database__ + + @property + def table(self): + return self.model_cls.__table__ + + def build_select_expression(self): + tables = [self.table] + select_from = self.table + + for item in self._select_related: + model_cls = self.model_cls + select_from = self.table + for part in item.split("__"): + model_cls = model_cls.__model_fields__[part].to + select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) + tables.append(model_cls.__table__) + + expr = sqlalchemy.sql.select(tables) + expr = expr.select_from(select_from) + + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + print(expr.compile(compile_kwargs={"literal_binds": True})) + return expr + + def filter(self, **kwargs): + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.__pkname__ + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + if "__" in key: + parts = key.split("__") + + # Determine if we should treat the final part as a + # filter operator or as a related field. + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + model_cls = self.model_cls + if related_parts: + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + for part in related_parts: + model_cls = model_cls.__model_fields__[part].to + + column = model_cls.__table__.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + + # Map the operation code onto SQLAlchemy's ColumnElement + # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement + op_attr = FILTER_OPERATORS[op] + has_escaped_character = False + + if op in ["contains", "icontains"]: + has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS + if c in value) + if has_escaped_character: + # enable escape modifier + for char in self.ESCAPE_CHARACTERS: + value = value.replace(char, f'\\{char}') + value = f"%{value}%" + + if isinstance(value, orm.Model): + value = value.pk + + clause = getattr(column, op_attr)(value) + clause.modifiers['escape'] = '\\' if has_escaped_character else None + filter_clauses.append(clause) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + def select_related(self, related): + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset + ) + + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) + + def limit(self, limit_count: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset + ) + + def offset(self, offset: int): + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset + ) + + async def first(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + + async def get(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs): + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + return [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + async def create(self, **kwargs): + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.__pkname__ + pk = self.model_cls.__model_fields__[pkname] + if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls.extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + instance.pk = await self.database.execute(expr) + return instance diff --git a/orm/relations.py b/orm/relations.py index 583f158..31bb520 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -1,20 +1,49 @@ +import pprint +import string +import uuid +from random import choices from typing import TYPE_CHECKING +from weakref import proxy + +from sqlalchemy import text + +from orm.fields import ForeignKey if TYPE_CHECKING: # pragma no cover from orm.models import Model +def get_table_alias(): + return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] + + +def get_relation_config(relation_type: str, table_name: str, field: ForeignKey): + alias = get_table_alias() + config = {'type': relation_type, + 'table_alias': alias, + 'source_table': table_name if relation_type == 'primary' else field.to.__tablename__, + 'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name + } + return config + + class RelationshipManager: def __init__(self): self._relations = dict() - def add_relation_type(self, relations_key, reverse_key): + def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str): print(relations_key, reverse_key) if relations_key not in self._relations: - self._relations[relations_key] = {'type': 'primary'} + self._relations[relations_key] = get_relation_config('primary', table_name, field) if reverse_key not in self._relations: - self._relations[reverse_key] = {'type': 'reverse'} + self._relations[reverse_key] = get_relation_config('reverse', table_name, field) + + def deregister(self, model: 'Model'): + for rel_type in self._relations.keys(): + if model.__class__.__name__.lower() in rel_type.lower(): + if model._orm_id in self._relations[rel_type]: + del self._relations[rel_type][model._orm_id] def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): parent_id = parent._orm_id @@ -22,9 +51,10 @@ class RelationshipManager: if virtual: child_name, parent_name = parent_name, child_name child_id, parent_id = parent_id, child_id - child, parent = parent, child - self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append( - child) + child, parent = parent, proxy(child) + else: + child = proxy(child) + self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child) self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) def contains(self, relations_key: str, object: 'Model'): @@ -40,7 +70,7 @@ class RelationshipManager: return self._relations[relations_key][object._orm_id] def __str__(self): # pragma no cover - return ''.join(self._relations[rel].__str__() for rel in self._relations) + return pprint.pformat(self._relations, indent=4, width=1) def __repr__(self): # pragma no cover return self.__str__() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 00c0672..8889064 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -13,7 +13,7 @@ metadata = sqlalchemy.MetaData() class Category(orm.Model): - __tablename__ = "cateries" + __tablename__ = "categories" __metadata__ = metadata __database__ = database @@ -22,7 +22,7 @@ class Category(orm.Model): class Item(orm.Model): - __tablename__ = "users" + __tablename__ = "items" __metadata__ = metadata __database__ = database diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index cadbfc8..155e6a4 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -19,7 +19,7 @@ class SchoolClass(orm.Model): class Category(orm.Model): - __tablename__ = "cateogories" + __tablename__ = "categories" __metadata__ = metadata __database__ = database @@ -75,3 +75,19 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert classes[0].students[0].schoolclass.name is None await classes[0].students[0].schoolclass.load() assert classes[0].students[0].schoolclass.name == 'Math' + + +@pytest.mark.asyncio +async def test_right_tables_join(): + async with database: + class1 = await SchoolClass.objects.create(name="Math") + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' + breakpoint() + assert classes[0].teachers[0].category.name == 'Domestic' From 62475a1949ba8aebf634961943116f6aabf1129c Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 13:20:16 +0200 Subject: [PATCH 25/62] change queryset to work with column and table aliases --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 8 ++- orm/models.py | 42 +++++++++++--- orm/queryset.py | 101 ++++++++++++++++++++++++++++++--- orm/relations.py | 30 +++++++++- tests/test_same_table_joins.py | 38 ++++++++++++- 6 files changed, 194 insertions(+), 25 deletions(-) diff --git a/.coverage b/.coverage index ad89c9e7f4286dd936db7dc6b3d5e305f71bfba6..955ea24ea0b5b017d6968d3bd41402d31daa70d6 100644 GIT binary patch delta 280 zcmV+z0q6dJpaX!Q1F!}l3I_lWLJu(y6%Plq5fJwdlMpW=ATcvKF)=zcIS&E@8y9wE za&u{KZZ$44E-`^S0R$a_Ym-thSOhI_WMZ?WFJVv~yZ{gR59$x#54{hw4~Gv|4?_<- z4=xW04)YGu4!pAw5SR`Y6b=Ld2`Ub5^X8p5?|IM1e*qW(fB}=@j&Wxa2m}EMCJ25N z^Go~v9*+0c`+W&62Lu5LUI*IpKlX2TxB0$S+wXSUw!6o>9e?N8yY1`e{=E0+&)x02 zPklCk@|-y;1_S{KRtA3e^3L!7x%_P2{A-c{y|<69@7Z3qe|N#b?(UPSj~g3h|GL}V eyYK%eo4wg=Hk-|6v)Sx7zq|YKdcU*ak03y~Tzbp^ delta 255 zcmV)74`dHh4@(a|4;c>w z4&n~avk?%h4iga$1OW*o4sP@2&3oSS@$WeQ9+T3JaaRur1OW*k2!0jwOZ$C)_kIZ^ z2Lu5LN(XBBAN$+g{q45x?s&Ij8GE<;xj*my`Ez&s?o*yK2PXyu0SQnBe)sau@Bg{{ zY~K8vWT5Zuv%Y7`f4ksdcay4*8ym8J-R None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False - self._orm_relationship_manager: RelationshipManager = relationship_manager self.values: Optional[BaseModel] = None if "pk" in kwargs: @@ -129,8 +131,12 @@ class Model(list, metaclass=ModelMetaclass): value = json.dumps(value) except TypeError: # pragma no cover pass + value = self.__model_fields__[key].expand_relationship(value, self) - setattr(self.values, key, value) + + relation_key = self.__class__.__name__.title() + '_' + key + if not self._orm_relationship_manager.contains(relation_key, self): + setattr(self.values, key, value) else: super().__setattr__(key, value) @@ -152,25 +158,36 @@ class Model(list, metaclass=ModelMetaclass): def __eq__(self, other): return self.values.dict() == other.values.dict() + def __same__(self, other): + assert self.__class__ == other.__class__ + return self._orm_id == other._orm_id or ( + self.values is not None and other.values is not None and self.pk == other.pk) + def __repr__(self): # pragma no cover return self.values.__repr__() @classmethod - def from_row(cls, row, select_related: List = None) -> 'Model': + def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model': + item = {} select_related = select_related or [] + + table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name) + previous_table = cls.__table__.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) model_cls = cls.__model_fields__[first_part].to - item[first_part] = model_cls.from_row(row, select_related=[remainder]) + child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table) + item[first_part] = child else: model_cls = cls.__model_fields__[related].to - item[related] = model_cls.from_row(row) + child = model_cls.from_row(row, previous_table=previous_table) + item[related] = child for column in cls.__table__.columns: if column.name not in item: - item[column.name] = row[column] + item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}'] return cls(**item) @@ -202,7 +219,14 @@ class Model(list, metaclass=ModelMetaclass): return self.__table__.primary_key.columns.values()[0] def dict(self) -> Dict: - return self.values.dict() + dict_instance = self.values.dict() + for field in self.extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = nested_model.dict() if nested_model is not None else {} + return dict_instance def from_dict(self, value_dict: Dict) -> None: for key, value in value_dict.items(): diff --git a/orm/queryset.py b/orm/queryset.py index d0d6b2c..1f01d7d 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,6 +1,7 @@ -from typing import List, TYPE_CHECKING +from typing import List, TYPE_CHECKING, Type import sqlalchemy +from sqlalchemy import text import orm from orm.exceptions import NoMatch, MultipleMatches @@ -24,13 +25,14 @@ FILTER_OPERATORS = { class QuerySet: ESCAPE_CHARACTERS = ['%', '_'] - def __init__(self, model_cls: 'Model' = None, filter_clauses: List = None, select_related: List = None, + def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, limit_count: int = None, offset: int = None): self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset + self.aliases_dict = dict() def __get__(self, instance, owner): return self.__class__(model_cls=owner) @@ -43,19 +45,56 @@ class QuerySet: def table(self): return self.model_cls.__table__ + def prefixed_columns(self, alias, table): + return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}') + for column in table.columns] + + def prefixed_table_name(self, alias, name): + return text(f'{name} {alias}_{name}') + + def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key): + return text(f'{alias}_{to_table}.{to_key}=' + f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def build_select_expression(self): tables = [self.table] + columns = list(self.table.columns) + order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] select_from = self.table for item in self._select_related: + previous_alias = '' + from_table = self.table.name + prev_model = self.model_cls model_cls = self.model_cls - select_from = self.table - for part in item.split("__"): - model_cls = model_cls.__model_fields__[part].to - select_from = sqlalchemy.sql.join(select_from, model_cls.__table__) - tables.append(model_cls.__table__) - expr = sqlalchemy.sql.select(tables) + for part in item.split("__"): + + model_cls = model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table) + + if prev_model.__model_fields__[part].virtual: + # TODO: change the key lookup + to_key = prev_model.__name__.lower() + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) + target_table = self.prefixed_table_name(alias, to_table) + select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) + tables.append(model_cls.__table__) + order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + + expr = sqlalchemy.sql.select(columns) expr = expr.select_from(select_from) if self.filter_clauses: @@ -71,6 +110,9 @@ class QuerySet: if self.query_offset: expr = expr.offset(self.query_offset) + for order in order_bys: + expr = expr.order_by(order) + print(expr.compile(compile_kwargs={"literal_binds": True})) return expr @@ -83,6 +125,7 @@ class QuerySet: kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): + table_prefix = '' if "__" in key: parts = key.split("__") @@ -106,14 +149,22 @@ class QuerySet: # Walk the relationships to the actual model class # against which the comparison is being made. + previous_table = model_cls.__tablename__ for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table, + current_table) model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + print(table_prefix) + table = model_cls.__table__ column = model_cls.__table__.columns[field_name] else: op = "exact" column = self.table.columns[key] + table = self.table # Map the operation code onto SQLAlchemy's ColumnElement # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement @@ -134,6 +185,13 @@ class QuerySet: clause = getattr(column, op_attr)(value) clause.modifiers['escape'] = '\\' if has_escaped_character else None + + clause_text = str(clause.compile(compile_kwargs={"literal_binds": True})) + alias = f'{table_prefix}_' if table_prefix else '' + aliased_name = f'{alias}{table.name}.{column.name}' + clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name) + clause = text(clause_text) + filter_clauses.append(clause) return self.__class__( @@ -212,11 +270,36 @@ class QuerySet: expr = self.build_select_expression() rows = await self.database.fetch_all(expr) - return [ + result_rows = [ self.model_cls.from_row(row, select_related=self._select_related) for row in rows ] + result_rows = self.merge_result_rows(result_rows) + + return result_rows + + @classmethod + def merge_result_rows(cls, result_rows): + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: 'Model', other: 'Model'): + for field in one.__model_fields__.keys(): + print(field, one.dict(), other.dict()) + if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), orm.models.Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field))) + return other + async def create(self, **kwargs): new_kwargs = dict(**kwargs) diff --git a/orm/relations.py b/orm/relations.py index 31bb520..0888284 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -2,7 +2,7 @@ import pprint import string import uuid from random import choices -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List from weakref import proxy from sqlalchemy import text @@ -40,6 +40,7 @@ class RelationshipManager: self._relations[reverse_key] = get_relation_config('reverse', table_name, field) def deregister(self, model: 'Model'): + print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: @@ -54,8 +55,25 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - self._relations[parent_name.title() + '_' + child_name + 's'].setdefault(parent_id, []).append(child) - self._relations[child_name.title() + '_' + parent_name].setdefault(child_id, []).append(parent) + print( + f'setting up relationship, {parent_id}, {child_id}, ' + f'{parent.__class__.__name__}, {child.__class__.__name__}, ' + f'{parent.pk if parent.values is not None else None}, ' + f'{child.pk if child.values is not None else None}') + parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) + self.append_related_model(parents_list, child) + children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) + self.append_related_model(children_list, parent) + + def append_related_model(self, relations_list: List['Model'], model: 'Model'): + for x in relations_list: + try: + if x.__same__(model): + return + except ReferenceError: + continue + + relations_list.append(model) def contains(self, relations_key: str, object: 'Model'): if relations_key in self._relations: @@ -69,6 +87,12 @@ class RelationshipManager: return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id] + def resolve_relation_join(self, from_table: str, to_table: str) -> str: + for k, v in self._relations.items(): + if v['source_table'] == from_table and v['target_table'] == to_table: + return self._relations[k]['table_alias'] + return '' + def __str__(self): # pragma no cover return pprint.pformat(self._relations, indent=4, width=1) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 155e6a4..30b34d5 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -9,6 +9,15 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() +class Department(orm.Model): + __tablename__ = "departments" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + class SchoolClass(orm.Model): __tablename__ = "schoolclasses" __metadata__ = metadata @@ -16,6 +25,7 @@ class SchoolClass(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) + department = orm.ForeignKey(Department) class Category(orm.Model): @@ -60,7 +70,8 @@ def create_test_database(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - class1 = await SchoolClass.objects.create(name="Math") + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) @@ -80,7 +91,8 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_right_tables_join(): async with database: - class1 = await SchoolClass.objects.create(name="Math") + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") await Student.objects.create(name="Jane", category=category, schoolclass=class1) @@ -89,5 +101,25 @@ async def test_right_tables_join(): classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' - breakpoint() + assert classes[0].teachers[0].category.name == 'Domestic' + + assert classes[0].students[0].category.name is None + await classes[0].students[0].category.load() + assert classes[0].students[0].category.name == 'Foreign' + + +@pytest.mark.asyncio +async def test_multiple_reverse_related_objects(): + async with database: + department = await Department.objects.create(name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() + assert classes[0].name == 'Math' + assert classes[0].students[0].name == 'Jane' assert classes[0].teachers[0].category.name == 'Domestic' From 3929dd6d73f60d7c52b87100e01a1f995ac14362 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 15:21:37 +0200 Subject: [PATCH 26/62] all tests passes - creating dummy models if fk not nullable --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 22 +++++++++++++++++----- orm/queryset.py | 10 ++++++++-- tests/test_same_table_joins.py | 10 +++++----- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/.coverage b/.coverage index 955ea24ea0b5b017d6968d3bd41402d31daa70d6..807412ba39ac227fd469563a159bb1f53c3717b6 100644 GIT binary patch delta 190 zcmV;v073tNpaX!Q1F$MD2Q@k}Fgi6dvoSB#P#(Ad5BU%358w~G53vu04^$684>%7i z4+akK4$ls_vk?%K4i^;;1OW*w4sP@2oj31!&(D9y`R|?oJ(JRoemMS)_t*P<2`>i( z0SRCS+VVg4Z+ExpzOU7GyWO_!?(uHN-#PYf`}(;*@BR66cl+*Bzx}T~XAYBxk2WPE s1Ox#INCfT{``6v>-hKZ++3d|`v)ODmo6Tmw`T5=5zt{T)Ksw1?cK`qY delta 186 zcmV;r07d_RpaX!Q1F$MD2QxY`F*-CkvoSB#P#(Mh5BU%358w~I53>)44_6OE4?7Po z4+#$Q4$=<1vk?%O4i*#+1OW*u4sP@2oj31!&&PiO7yy6)lh2NRHy)1n*86=4E(Zhw z30?=<@;~-(cenYzR@?7(+qS#MyB&Y$*t_lP=l;C+=g-~kyH9;KfbyI^Hx=`|*0ev*3>)KoU|`sQ>@~ diff --git a/orm/fields.py b/orm/fields.py index 2723b82..69ba85f 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,14 +1,17 @@ import datetime import decimal -from typing import Optional, List +from typing import Optional, List, Type, TYPE_CHECKING -import orm import sqlalchemy from pydantic import Json from pydantic.fields import ModelField +import orm from orm.exceptions import ModelDefinitionError, RelationshipInstanceError +if TYPE_CHECKING: # pragma no cover + from orm.models import Model + class BaseField: __type__ = None @@ -173,8 +176,16 @@ class Decimal(BaseField): return sqlalchemy.DECIMAL(self.length, self.precision) +def create_dummy_instance(fk: Type['Model'], pk: int = None): + init_dict = {fk.__pkname__: pk or -1} + init_dict = {**init_dict, **{k: create_dummy_instance(v.to) + for k, v in fk.__model_fields__.items() + if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}} + return fk(**init_dict) + + class ForeignKey(BaseField): - def __init__(self, to, related_name: str = None, nullable: bool = False, virtual: bool = False): + def __init__(self, to, related_name: str = None, nullable: bool = True, virtual: bool = False): super().__init__(nullable=nullable) self.virtual = virtual self.related_name = related_name @@ -206,14 +217,15 @@ class ForeignKey(BaseField): elif isinstance(value, dict): model = self.to(**value) else: - model = self.to(**{self.to.__pkname__: value}) + model = create_dummy_instance(fk=self.to, pk=value) child_model_name = self.related_name or child.__class__.__name__.lower() + 's' model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), child.__class__.__name__.lower(), model, child, virtual=self.virtual) - if child_model_name not in model.__fields__: + if child_model_name not in model.__fields__ \ + and child.__class__.__name__.lower() not in model.__fields__: model.__fields__[child_model_name] = ModelField(name=child_model_name, type_=Optional[child.__pydantic_model__], model_config=child.__pydantic_model__.__config__, diff --git a/orm/queryset.py b/orm/queryset.py index 1f01d7d..2eba123 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -57,11 +57,17 @@ class QuerySet: f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') def build_select_expression(self): - tables = [self.table] + # tables = [self.table] columns = list(self.table.columns) order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] select_from = self.table + for key in self.model_cls.__model_fields__: + if not self.model_cls.__model_fields__[key].nullable \ + and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ + and key not in self._select_related: + self._select_related.append(key) + for item in self._select_related: previous_alias = '' from_table = self.table.name @@ -86,7 +92,7 @@ class QuerySet: on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) target_table = self.prefixed_table_name(alias, to_table) select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) - tables.append(model_cls.__table__) + # tables.append(model_cls.__table__) order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) columns.extend(self.prefixed_columns(alias, model_cls.__table__)) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 30b34d5..f98e01b 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -14,7 +14,7 @@ class Department(orm.Model): __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) + id = orm.Integer(primary_key=True, autoincrement=False) name = orm.String(length=100) @@ -25,7 +25,7 @@ class SchoolClass(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) - department = orm.ForeignKey(Department) + department = orm.ForeignKey(Department, nullable=False) class Category(orm.Model): @@ -70,7 +70,7 @@ def create_test_database(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(): async with database: - department = await Department.objects.create(name='Math Department') + department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") @@ -91,7 +91,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio async def test_right_tables_join(): async with database: - department = await Department.objects.create(name='Math Department') + department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") @@ -111,7 +111,7 @@ async def test_right_tables_join(): @pytest.mark.asyncio async def test_multiple_reverse_related_objects(): async with database: - department = await Department.objects.create(name='Math Department') + department = await Department.objects.create(id=1, name='Math Department') class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") From 8f179f763f36bc4ed6d3b72a005242afab58d2de Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 7 Aug 2020 19:34:17 +0200 Subject: [PATCH 27/62] add preloading of not nullable relations (and all chain inbetween) --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 8 +- orm/queryset.py | 152 ++++++++++++++++++++++++--------- orm/relations.py | 12 +-- tests/test_same_table_joins.py | 9 +- 5 files changed, 128 insertions(+), 53 deletions(-) diff --git a/.coverage b/.coverage index 807412ba39ac227fd469563a159bb1f53c3717b6..1cc206f5c0dcc201cc45878a1905994243a1dd21 100644 GIT binary patch delta 165 zcmV;W09yZmpaX!Q1F$MD2RS-3GdeUmvoSB#P#&%T5BU%358w~G53vu04@(a>4=)cQ z4*d@24#p0yvk?%A4j2{=1OW*y4sP@2oj31!&(D7c&wuay?*Ws~j(tM){(|?{`+W&G z2Lu5LatGSwe|+}0+ud!hZ-4K;bq^1E-+$X~+xWVByxZ|-j=kHye(ukEfBxLvzWdZ~ T_bbnt1Cxc1S^+P!;Ey0c=M+v4 delta 155 zcmV;M0A&AwpaX!Q1F$MD2Q@k}Fgi6dvoSB#P#(Ad5BU%358w~G53vu04^$684>%7i z4+akK4$ls_vk?%K4i^;;1OW*w4sP@2oj31!&(D9y`R|?oJ(I?ceL4O9j`!F5eF-lI z1OW+P2io#K_HTE$>AtVkcDvoS?e6hz$KN^jZu|PVKkxndb9ejhQ@{PMJZBD*g^yYR J3$x&lAV3tJNV@<4 diff --git a/orm/fields.py b/orm/fields.py index 69ba85f..aa40930 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -185,8 +185,8 @@ def create_dummy_instance(fk: Type['Model'], pk: int = None): class ForeignKey(BaseField): - def __init__(self, to, related_name: str = None, nullable: bool = True, virtual: bool = False): - super().__init__(nullable=nullable) + def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False): + super().__init__(nullable=nullable, name=name) self.virtual = virtual self.related_name = related_name self.to = to @@ -230,6 +230,8 @@ class ForeignKey(BaseField): type_=Optional[child.__pydantic_model__], model_config=child.__pydantic_model__.__config__, class_validators=child.__pydantic_model__.__validators__) - model.__model_fields__[child_model_name] = ForeignKey(child.__class__, virtual=True) + model.__model_fields__[child_model_name] = ForeignKey(child.__class__, + name=child_model_name, + virtual=True) return model diff --git a/orm/queryset.py b/orm/queryset.py index 2eba123..95ad8f6 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,9 +1,10 @@ -from typing import List, TYPE_CHECKING, Type +from typing import List, TYPE_CHECKING, Type, NamedTuple import sqlalchemy from sqlalchemy import text import orm +from orm import ForeignKey from orm.exceptions import NoMatch, MultipleMatches if TYPE_CHECKING: # pragma no cover @@ -22,6 +23,13 @@ FILTER_OPERATORS = { } +class JoinParameters(NamedTuple): + prev_model: Type['Model'] + previous_alias: str + from_table: str + model_cls: Type['Model'] + + class QuerySet: ESCAPE_CHARACTERS = ['%', '_'] @@ -32,7 +40,13 @@ class QuerySet: self._select_related = [] if select_related is None else select_related self.limit_count = limit_count self.query_offset = offset - self.aliases_dict = dict() + + self.auto_related = [] + self.used_aliases = [] + + self.select_from = None + self.columns = None + self.order_bys = None def __get__(self, instance, owner): return self.__class__(model_cls=owner) @@ -56,52 +70,101 @@ class QuerySet: return text(f'{alias}_{to_table}.{to_key}=' f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def build_join_parameters(self, part, join_params: JoinParameters): + model_cls = join_params.model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table) + # print(f'resolving tables alias from {join_params.from_table}, to: {to_table} -> {alias}') + if alias not in self.used_aliases: + if join_params.prev_model.__model_fields__[part].virtual: + to_key = next((v for k, v in model_cls.__model_fields__.items() + if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key, + from_key) + target_table = self.prefixed_table_name(alias, to_table) + self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause) + self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.used_aliases.append(alias) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + return JoinParameters(prev_model, previous_alias, from_table, model_cls) + + @staticmethod + def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool: + return isinstance(field, ForeignKey) and field_name not in rel_part + + def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related]) + already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested + + def extract_auto_required_relations(self, join_params: JoinParameters, + rel_part: str = '', nested: bool = False, parent_virtual: bool = False): + # print(f'checking model {join_params.prev_model}, {rel_part}') + for field_name, field in join_params.prev_model.__model_fields__.items(): + # print(f'checking_field {field_name}') + if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part): + rel_part = field_name if not rel_part else rel_part + '__' + field_name + if not field.nullable: + # print(f'field {field_name} is not nullable, appending to auto, curr rel: {rel_part}') + if rel_part not in self._select_related: + self.auto_related.append("__".join(rel_part.split("__")[:-1])) + rel_part = '' + elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part): + # print( + # f'field {field_name} is nullable, going down, curr rel: ' + # f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, ' + # f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}') + join_params = JoinParameters(field.to, join_params.previous_alias, + join_params.from_table, join_params.prev_model) + self.extract_auto_required_relations(join_params=join_params, + rel_part=rel_part, nested=True, parent_virtual=field.virtual) + else: + # print( + # f'field {field_name} is out, going down, curr rel: ' + # f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, ' + # f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}') + rel_part = '' + def build_select_expression(self): - # tables = [self.table] - columns = list(self.table.columns) - order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] - select_from = self.table + self.columns = list(self.table.columns) + self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] + self.select_from = self.table for key in self.model_cls.__model_fields__: if not self.model_cls.__model_fields__[key].nullable \ and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ and key not in self._select_related: - self._select_related.append(key) + self._select_related = [key] + self._select_related + + start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) + self.extract_auto_required_relations(start_params) + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: - previous_alias = '' - from_table = self.table.name - prev_model = self.model_cls - model_cls = self.model_cls + join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) for part in item.split("__"): + join_parameters = self.build_join_parameters(part, join_parameters) - model_cls = model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name - - alias = model_cls._orm_relationship_manager.resolve_relation_join(from_table, to_table) - - if prev_model.__model_fields__[part].virtual: - # TODO: change the key lookup - to_key = prev_model.__name__.lower() - from_key = model_cls.__pkname__ - else: - to_key = model_cls.__pkname__ - from_key = part - - on_clause = self.on_clause(from_table, to_table, previous_alias, alias, to_key, from_key) - target_table = self.prefixed_table_name(alias, to_table) - select_from = sqlalchemy.sql.outerjoin(select_from, target_table, on_clause) - # tables.append(model_cls.__table__) - order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) - columns.extend(self.prefixed_columns(alias, model_cls.__table__)) - - previous_alias = alias - from_table = to_table - prev_model = model_cls - - expr = sqlalchemy.sql.select(columns) - expr = expr.select_from(select_from) + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) if self.filter_clauses: if len(self.filter_clauses) == 1: @@ -116,10 +179,17 @@ class QuerySet: if self.query_offset: expr = expr.offset(self.query_offset) - for order in order_bys: + for order in self.order_bys: expr = expr.order_by(order) - print(expr.compile(compile_kwargs={"literal_binds": True})) + # print(expr.compile(compile_kwargs={"literal_binds": True})) + + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] + return expr def filter(self, **kwargs): @@ -163,7 +233,7 @@ class QuerySet: model_cls = model_cls.__model_fields__[part].to previous_table = current_table - print(table_prefix) + # print(table_prefix) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] @@ -298,7 +368,7 @@ class QuerySet: @classmethod def merge_two_instances(cls, one: 'Model', other: 'Model'): for field in one.__model_fields__.keys(): - print(field, one.dict(), other.dict()) + # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), orm.models.Model): diff --git a/orm/relations.py b/orm/relations.py index 0888284..a1a6625 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -40,7 +40,7 @@ class RelationshipManager: self._relations[reverse_key] = get_relation_config('reverse', table_name, field) def deregister(self, model: 'Model'): - print(f'deregistering {model.__class__.__name__}, {model._orm_id}') + # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: @@ -55,11 +55,11 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - print( - f'setting up relationship, {parent_id}, {child_id}, ' - f'{parent.__class__.__name__}, {child.__class__.__name__}, ' - f'{parent.pk if parent.values is not None else None}, ' - f'{child.pk if child.values is not None else None}') + # print( + # f'setting up relationship, {parent_id}, {child_id}, ' + # f'{parent.__class__.__name__}, {child.__class__.__name__}, ' + # f'{parent.pk if parent.values is not None else None}, ' + # f'{child.pk if child.values is not None else None}') parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) self.append_related_model(parents_list, child) children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index f98e01b..ca4e7b7 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -77,15 +77,18 @@ async def test_model_multiple_instances_of_same_table_in_schema(): await Student.objects.create(name="Jane", category=category, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers', 'students']).all() + classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' # related fields of main model are only populated by pk + # unless there is a required foreign key somewhere along the way + # since department is required for schoolclass it was pre loaded (again) # but you can load them anytime - assert classes[0].students[0].schoolclass.name is None - await classes[0].students[0].schoolclass.load() assert classes[0].students[0].schoolclass.name == 'Math' + assert classes[0].students[0].schoolclass.department.name is None + await classes[0].students[0].schoolclass.department.load() + assert classes[0].students[0].schoolclass.department.name == 'Math Department' @pytest.mark.asyncio From 39e44b1985602ac05288c3eafa65e0cead661209 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 06:24:22 +0200 Subject: [PATCH 28/62] add dialect to compilation of sqlalchemy clauses --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 1 + orm/models.py | 4 ++++ orm/queryset.py | 3 ++- 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.coverage b/.coverage index 1cc206f5c0dcc201cc45878a1905994243a1dd21..c907f3e5af7db5ef5c792c69402959e4bfe2988f 100644 GIT binary patch delta 95 zcmV-l0HFVXpaX!Q1F$DA200)wHaar1EiXo3K7Re&pZEU!xx0P$so(Bbo-+q31_S{K zRtA3W<(=REbNSi4`PU=^(0lvn`kw7&`*#-{?CyTM_xJbT`+Xmiv5z Date: Sun, 9 Aug 2020 06:51:12 +0200 Subject: [PATCH 29/62] some style corrections --- .flake8 | 4 ++++ orm/__init__.py | 12 ++++++++---- orm/fields.py | 50 +++++++++++++++++++++++------------------------- orm/queryset.py | 12 ------------ orm/relations.py | 7 ------- requirements.txt | 8 +++++++- 6 files changed, 43 insertions(+), 50 deletions(-) create mode 100644 .flake8 diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..c2bfa9d --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +ignore = ANN101 +max-complexity = 10 +exclude = p38venv,.pytest_cache diff --git a/orm/__init__.py b/orm/__init__.py index 1ceb7b0..c947954 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,7 +1,7 @@ -from orm.fields import Integer, BigInteger, Boolean, Time, Text, String, JSON, DateTime, Date, Decimal, Float, \ - ForeignKey +from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch +from orm.fields import BigInteger, Boolean, Date, DateTime, Decimal, Float, ForeignKey, Integer, JSON, String, Text, \ + Time from orm.models import Model -from orm.exceptions import ModelDefinitionError, MultipleMatches, NoMatch, ModelNotSet __version__ = "0.0.1" __all__ = [ @@ -17,5 +17,9 @@ __all__ = [ "Decimal", "Float", "ForeignKey", - "Model" + "Model", + "ModelDefinitionError", + "ModelNotSet", + "MultipleMatches", + "NoMatch" ] diff --git a/orm/fields.py b/orm/fields.py index aa40930..80814f0 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,9 +1,9 @@ import datetime import decimal -from typing import Optional, List, Type, TYPE_CHECKING +from typing import List, Optional, TYPE_CHECKING, Type, Any, Union import sqlalchemy -from pydantic import Json +from pydantic import Json, BaseModel from pydantic.fields import ModelField import orm @@ -22,9 +22,7 @@ class BaseField: if args: if isinstance(args[0], str): if name is not None: - raise ModelDefinitionError( - 'Column name cannot be passed positionally and as a keyword.' - ) + raise ModelDefinitionError('Column name cannot be passed positionally and as a keyword.') name = args.pop(0) self.name = name @@ -43,20 +41,20 @@ class BaseField: raise ModelDefinitionError('Primary key column cannot be pydantic only.') @property - def is_required(self): + def is_required(self) -> bool: return not self.nullable and not self.has_default and not self.is_auto_primary_key @property - def default_value(self): + def default_value(self) -> Any: default = self.default if self.default is not None else self.server_default return default() if callable(default) else default @property - def has_default(self): + def has_default(self) -> bool: return self.default is not None or self.server_default is not None @property - def is_auto_primary_key(self): + def is_auto_primary_key(self) -> bool: if self.primary_key: return self.autoincrement return False @@ -83,7 +81,7 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] - def expand_relationship(self, value, child): + def expand_relationship(self, value, child) -> Any: return value @@ -95,70 +93,70 @@ class String(BaseField): self.length = kwargs.pop('length') super().__init__(*args, **kwargs) - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.String(self.length) class Integer(BaseField): __type__ = int - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.Integer() class Text(BaseField): __type__ = str - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.Text() class Float(BaseField): __type__ = float - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.Float() class Boolean(BaseField): __type__ = bool - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.Boolean() class DateTime(BaseField): __type__ = datetime.datetime - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.DateTime() class Date(BaseField): __type__ = datetime.date - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.Date() class Time(BaseField): __type__ = datetime.time - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.Time() class JSON(BaseField): __type__ = Json - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.JSON() class BigInteger(BaseField): __type__ = int - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.BigInteger() @@ -172,11 +170,11 @@ class Decimal(BaseField): self.precision = kwargs.pop('precision') super().__init__(*args, **kwargs) - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.DECIMAL(self.length, self.precision) -def create_dummy_instance(fk: Type['Model'], pk: int = None): +def create_dummy_instance(fk: Type['Model'], pk: int = None) -> 'Model': init_dict = {fk.__pkname__: pk or -1} init_dict = {**init_dict, **{k: create_dummy_instance(v.to) for k, v in fk.__model_fields__.items() @@ -192,18 +190,18 @@ class ForeignKey(BaseField): self.to = to @property - def __type__(self): + def __type__(self) -> Type[BaseModel]: return self.to.__pydantic_model__ - def get_constraints(self): + def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: fk_string = self.to.__tablename__ + "." + self.to.__pkname__ return [sqlalchemy.schema.ForeignKey(fk_string)] - def get_column_type(self): + def get_column_type(self) -> sqlalchemy.Column: to_column = self.to.__model_fields__[self.to.__pkname__] return to_column.get_column_type() - def expand_relationship(self, value, child): + def expand_relationship(self, value, child) -> Union['Model', List['Model']]: if not isinstance(value, (self.to, dict, int, str, list)) or ( isinstance(value, orm.models.Model) and not isinstance(value, self.to)): raise RelationshipInstanceError( diff --git a/orm/queryset.py b/orm/queryset.py index 63ba69a..36160e8 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -75,7 +75,6 @@ class QuerySet: to_table = model_cls.__table__.name alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table) - # print(f'resolving tables alias from {join_params.from_table}, to: {to_table} -> {alias}') if alias not in self.used_aliases: if join_params.prev_model.__model_fields__[part].virtual: to_key = next((v for k, v in model_cls.__model_fields__.items() @@ -110,30 +109,19 @@ class QuerySet: def extract_auto_required_relations(self, join_params: JoinParameters, rel_part: str = '', nested: bool = False, parent_virtual: bool = False): - # print(f'checking model {join_params.prev_model}, {rel_part}') for field_name, field in join_params.prev_model.__model_fields__.items(): - # print(f'checking_field {field_name}') if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part): rel_part = field_name if not rel_part else rel_part + '__' + field_name if not field.nullable: - # print(f'field {field_name} is not nullable, appending to auto, curr rel: {rel_part}') if rel_part not in self._select_related: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = '' elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part): - # print( - # f'field {field_name} is nullable, going down, curr rel: ' - # f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, ' - # f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}') join_params = JoinParameters(field.to, join_params.previous_alias, join_params.from_table, join_params.prev_model) self.extract_auto_required_relations(join_params=join_params, rel_part=rel_part, nested=True, parent_virtual=field.virtual) else: - # print( - # f'field {field_name} is out, going down, curr rel: ' - # f'{rel_part}, nested:{nested}, virtual:{field.virtual}, virtual_par:{parent_virtual}, ' - # f'injoin: {"__".join(rel_part.split("__")[:-1]) in self._select_related}') rel_part = '' def build_select_expression(self): diff --git a/orm/relations.py b/orm/relations.py index a1a6625..a9a6971 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -5,8 +5,6 @@ from random import choices from typing import TYPE_CHECKING, List from weakref import proxy -from sqlalchemy import text - from orm.fields import ForeignKey if TYPE_CHECKING: # pragma no cover @@ -55,11 +53,6 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - # print( - # f'setting up relationship, {parent_id}, {child_id}, ' - # f'{parent.__class__.__name__}, {child.__class__.__name__}, ' - # f'{parent.pk if parent.values is not None else None}, ' - # f'{child.pk if child.values is not None else None}') parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) self.append_related_model(parents_list, child) children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) diff --git a/requirements.txt b/requirements.txt index db5ae51..fde4a73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,10 @@ pytest pytest-cov codecov pytest-asyncio -fastapi \ No newline at end of file +fastapi +flake8 +flake8-black +flake8-bugbear +flake8-import-order +flake8-bandit +flake8-annotations \ No newline at end of file From 241628b1d94398bc270cbe0f01eb715a7f8be6f7 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 07:51:06 +0200 Subject: [PATCH 30/62] liniting and applying black --- .coverage | Bin 53248 -> 53248 bytes .flake8 | 3 +- orm/__init__.py | 18 ++- orm/fields.py | 133 +++++++++++------ orm/helpers.py | 26 ---- orm/models.py | 165 +++++++++++++-------- orm/queryset.py | 261 +++++++++++++++++++++++---------- orm/relations.py | 80 ++++++---- tests/test_model_definition.py | 16 ++ 9 files changed, 455 insertions(+), 247 deletions(-) delete mode 100644 orm/helpers.py diff --git a/.coverage b/.coverage index c907f3e5af7db5ef5c792c69402959e4bfe2988f..dc4c4fa39b8c2f40a9c341058db4ebd6e14f7779 100644 GIT binary patch delta 402 zcmZozz}&Ead4rZdySbIIv6YG8W<&jZ0*ro}1sufqnbcS(`O9*rWv1q&6zdgKYOzeN z_mkw!FUpNctw>HSD9Oyvn_T8UmD!(PeKTKxB>&|9adKR6<*bwI;-o}TRD(=m(PR#v z+!!amd1`#2g4|>V{-6A>`0w#w;a|mH!EeGZ$1ldu$@hfs4Bu+L$(scQ^7$sO>$4OQ z<6&Xs)Z;0e{(t_@IiK%8duC~GKlx{$UA+hgP(q8N?&zP~o%dgU{r~Ux`tQGmoY+|z zIU8Bncb)$KW6l5DZ{OW~{4Mh?~}|M@q6_UX+x&;Fg2!cZ0e`^B$)H_dL| z{`skzpZ~>c7FH!|LU{%o=w|)GskT9*|hi1 n=D(>g|9*DwzdPoP94tVULM&-%X`Xw9y90s{e delta 341 zcmV-b0jmChpaX!Q1F$MD2R1q~GdeUmvoSB;5CKxN5I`0W0xAWQP){lXCIpgC9R+4- zWo%@Vc2AW9Cl6Dz4p12nlTTkO0R@x)ULFl*X=Q9=b1ras1StbolYw3rvz1?QARev& z5BU%358w~G53vu04@(a>4=)cQ4*d@24#p0yvk?%A4wI9PIT92M1OW*u40f0QpZT2g z-{$5yli-d%e-a1;0SP7ueiide`}O{U_t*P<2{;D?0SR&k+U0+I_P5*JZLV*B@4j^p z4|?B!+iu(Vx_i9a@n?>`+kXArpZEU!xx0P$so(Bbo-+q31_S{KRtA3W<(=REbNSi4 z`PU=^(0lvn`kw7&`*#-{?CyTM_xJbT`+Xk>1q1;JG7<%5W@aM<1OW+11nw65*WK>k neg8k%?9FDg*=#nO&1S#(`Q6>W*ZVIB0|WsH5(C<^=#Lmc2E>+& diff --git a/.flake8 b/.flake8 index c2bfa9d..ec05a50 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] -ignore = ANN101 +ignore = ANN101, ANN102, W503 max-complexity = 10 +max-line-length = 88 exclude = p38venv,.pytest_cache diff --git a/orm/__init__.py b/orm/__init__.py index c947954..39adee8 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -1,6 +1,18 @@ from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch -from orm.fields import BigInteger, Boolean, Date, DateTime, Decimal, Float, ForeignKey, Integer, JSON, String, Text, \ - Time +from orm.fields import ( + BigInteger, + Boolean, + Date, + DateTime, + Decimal, + Float, + ForeignKey, + Integer, + JSON, + String, + Text, + Time, +) from orm.models import Model __version__ = "0.0.1" @@ -21,5 +33,5 @@ __all__ = [ "ModelDefinitionError", "ModelNotSet", "MultipleMatches", - "NoMatch" + "NoMatch", ] diff --git a/orm/fields.py b/orm/fields.py index 80814f0..9c9f3f7 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,14 +1,15 @@ import datetime import decimal -from typing import List, Optional, TYPE_CHECKING, Type, Any, Union - -import sqlalchemy -from pydantic import Json, BaseModel -from pydantic.fields import ModelField +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import orm from orm.exceptions import ModelDefinitionError, RelationshipInstanceError +from pydantic import BaseModel, Json +from pydantic.fields import ModelField + +import sqlalchemy + if TYPE_CHECKING: # pragma no cover from orm.models import Model @@ -16,33 +17,39 @@ if TYPE_CHECKING: # pragma no cover class BaseField: __type__ = None - def __init__(self, *args, **kwargs) -> None: - name = kwargs.pop('name', None) + def __init__(self, *args: Any, **kwargs: Any) -> None: + name = kwargs.pop("name", None) args = list(args) if args: if isinstance(args[0], str): if name is not None: - raise ModelDefinitionError('Column name cannot be passed positionally and as a keyword.') + raise ModelDefinitionError( + "Column name cannot be passed positionally and as a keyword." + ) name = args.pop(0) self.name = name - self.primary_key = kwargs.pop('primary_key', False) - self.autoincrement = kwargs.pop('autoincrement', self.primary_key and self.__type__ == int) + self.primary_key = kwargs.pop("primary_key", False) + self.autoincrement = kwargs.pop( + "autoincrement", self.primary_key and self.__type__ == int + ) - self.nullable = kwargs.pop('nullable', not self.primary_key) - self.default = kwargs.pop('default', None) - self.server_default = kwargs.pop('server_default', None) + self.nullable = kwargs.pop("nullable", not self.primary_key) + self.default = kwargs.pop("default", None) + self.server_default = kwargs.pop("server_default", None) - self.index = kwargs.pop('index', None) - self.unique = kwargs.pop('unique', None) + self.index = kwargs.pop("index", None) + self.unique = kwargs.pop("unique", None) - self.pydantic_only = kwargs.pop('pydantic_only', False) + self.pydantic_only = kwargs.pop("pydantic_only", False) if self.pydantic_only and self.primary_key: - raise ModelDefinitionError('Primary key column cannot be pydantic only.') + raise ModelDefinitionError("Primary key column cannot be pydantic only.") @property def is_required(self) -> bool: - return not self.nullable and not self.has_default and not self.is_auto_primary_key + return ( + not self.nullable and not self.has_default and not self.is_auto_primary_key + ) @property def default_value(self) -> Any: @@ -81,16 +88,19 @@ class BaseField: def get_constraints(self) -> Optional[List]: return [] - def expand_relationship(self, value, child) -> Any: + def expand_relationship(self, value: Any, child: "Model") -> Any: return value class String(BaseField): __type__ = str - def __init__(self, *args, **kwargs): - assert 'length' in kwargs, 'length is required' - self.length = kwargs.pop('length') + def __init__(self, *args: Any, **kwargs: Any) -> None: + if "length" not in kwargs: + raise ModelDefinitionError( + "Param length is required for String model field." + ) + self.length = kwargs.pop("length") super().__init__(*args, **kwargs) def get_column_type(self) -> sqlalchemy.Column: @@ -163,27 +173,41 @@ class BigInteger(BaseField): class Decimal(BaseField): __type__ = decimal.Decimal - def __init__(self, *args, **kwargs): - assert 'precision' in kwargs, 'precision is required' - assert 'length' in kwargs, 'length is required' - self.length = kwargs.pop('length') - self.precision = kwargs.pop('precision') + def __init__(self, *args: Any, **kwargs: Any) -> None: + if "length" not in kwargs or "precision" not in kwargs: + raise ModelDefinitionError( + "Params length and precision are required for Decimal model field." + ) + self.length = kwargs.pop("length") + self.precision = kwargs.pop("precision") super().__init__(*args, **kwargs) def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.DECIMAL(self.length, self.precision) -def create_dummy_instance(fk: Type['Model'], pk: int = None) -> 'Model': +def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": init_dict = {fk.__pkname__: pk or -1} - init_dict = {**init_dict, **{k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual}} + init_dict = { + **init_dict, + **{ + k: create_dummy_instance(v.to) + for k, v in fk.__model_fields__.items() + if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + }, + } return fk(**init_dict) class ForeignKey(BaseField): - def __init__(self, to, name: str = None, related_name: str = None, nullable: bool = True, virtual: bool = False): + def __init__( + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, + ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual self.related_name = related_name @@ -201,11 +225,16 @@ class ForeignKey(BaseField): to_column = self.to.__model_fields__[self.to.__pkname__] return to_column.get_column_type() - def expand_relationship(self, value, child) -> Union['Model', List['Model']]: + def expand_relationship( + self, value: Any, child: "Model" + ) -> Union["Model", List["Model"]]: if not isinstance(value, (self.to, dict, int, str, list)) or ( - isinstance(value, orm.models.Model) and not isinstance(value, self.to)): + isinstance(value, orm.models.Model) and not isinstance(value, self.to) + ): raise RelationshipInstanceError( - 'Relationship model can be build only from orm.Model, dict and integer or string (pk).') + "Relationship model can be build only from orm.Model, " + "dict and integer or string (pk)." + ) if isinstance(value, list) and not isinstance(value, self.to): model = [self.expand_relationship(val, child) for val in value] return model @@ -217,19 +246,27 @@ class ForeignKey(BaseField): else: model = create_dummy_instance(fk=self.to, pk=value) - child_model_name = self.related_name or child.__class__.__name__.lower() + 's' - model._orm_relationship_manager.add_relation(model.__class__.__name__.lower(), - child.__class__.__name__.lower(), - model, child, virtual=self.virtual) + child_model_name = self.related_name or child.__class__.__name__.lower() + "s" + model._orm_relationship_manager.add_relation( + model.__class__.__name__.lower(), + child.__class__.__name__.lower(), + model, + child, + virtual=self.virtual, + ) - if child_model_name not in model.__fields__ \ - and child.__class__.__name__.lower() not in model.__fields__: - model.__fields__[child_model_name] = ModelField(name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__) - model.__model_fields__[child_model_name] = ForeignKey(child.__class__, - name=child_model_name, - virtual=True) + if ( + child_model_name not in model.__fields__ + and child.__class__.__name__.lower() not in model.__fields__ + ): + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child.__class__, name=child_model_name, virtual=True + ) return model diff --git a/orm/helpers.py b/orm/helpers.py deleted file mode 100644 index f8d7cfb..0000000 --- a/orm/helpers.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Union, Set, Dict # pragma no cover - - -class Excludable: # pragma no cover - - @staticmethod - def get_excluded(exclude: Union[Set, Dict, None], key: str = None): - # print(f'checking excluded for {key}', exclude) - if isinstance(exclude, dict): - if isinstance(exclude.get(key, {}), dict) and '__all__' in exclude.get(key, {}).keys(): - return exclude.get(key).get('__all__') - return exclude.get(key, {}) - return exclude - - @staticmethod - def is_excluded(exclude: Union[Set, Dict, None], key: str = None): - if exclude is None: - return False - to_exclude = Excludable.get_excluded(exclude, key) - # print(f'to exclude for current key = {key}', to_exclude) - - if isinstance(to_exclude, Set): - return key in to_exclude - elif to_exclude is ...: - return True - return False diff --git a/orm/models.py b/orm/models.py index b3f7d78..68e32f4 100644 --- a/orm/models.py +++ b/orm/models.py @@ -2,35 +2,39 @@ import copy import inspect import json import uuid -from typing import Any, List, Type, TYPE_CHECKING, Optional, TypeVar, Tuple -from typing import Set, Dict +from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar +from typing import Callable, Dict, Set import databases -import pydantic -import sqlalchemy -from pydantic import BaseModel, BaseConfig, create_model import orm.queryset as qry from orm.exceptions import ModelDefinitionError from orm.fields import BaseField, ForeignKey from orm.relations import RelationshipManager +import pydantic +from pydantic import BaseConfig, BaseModel, create_model + +import sqlalchemy + relationship_manager = RelationshipManager() def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = {field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value - ) + pydantic_fields = { + field_name: ( + base_field.__type__, + ... if base_field.is_required else base_field.default_value, + ) for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField)} + if isinstance(base_field, BaseField) + } return pydantic_fields -def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename: str) -> Tuple[Optional[str], - List[sqlalchemy.Column], - Dict[str, BaseField]]: +def sqlalchemy_columns_from_model_fields( + name: str, object_dict: Dict, tablename: str +) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] model_fields: Dict[str, BaseField] = {} @@ -42,9 +46,16 @@ def sqlalchemy_columns_from_model_fields(name: str, object_dict: Dict, tablename if field.primary_key: pkname = field_name if isinstance(field, ForeignKey): - reverse_name = field.related_name or field.to.__name__.lower().title() + '_' + name.lower() + 's' - relation_name = name.lower().title() + '_' + field.to.__name__.lower() - relationship_manager.add_relation_type(relation_name, reverse_name, field, tablename) + reverse_name = ( + field.related_name + or field.to.__name__.lower().title() + "_" + name.lower() + "s" + ) + relation_name = ( + name.lower().title() + "_" + field.to.__name__.lower() + ) + relationship_manager.add_relation_type( + relation_name, reverse_name, field, tablename + ) columns.append(field.get_column(field_name)) return pkname, columns, model_fields @@ -57,9 +68,7 @@ def get_pydantic_base_orm_config() -> Type[BaseConfig]: class ModelMetaclass(type): - def __new__( - mcs: type, name: str, bases: Any, attrs: dict - ) -> type: + def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) @@ -71,25 +80,29 @@ class ModelMetaclass(type): metadata = attrs["__metadata__"] # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields(name, attrs, tablename) - attrs['__table__'] = sqlalchemy.Table(tablename, metadata, *columns) - attrs['__columns__'] = columns - attrs['__pkname__'] = pkname + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, tablename + ) + attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) + attrs["__columns__"] = columns + attrs["__pkname__"] = pkname if not pkname: - raise ModelDefinitionError('Table has to have a primary key.') + raise ModelDefinitionError("Table has to have a primary key.") # pydantic model creation pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model(name, __config__=get_pydantic_base_orm_config(), **pydantic_fields) - attrs['__pydantic_fields__'] = pydantic_fields - attrs['__pydantic_model__'] = pydantic_model - attrs['__fields__'] = copy.deepcopy(pydantic_model.__fields__) - attrs['__signature__'] = copy.deepcopy(pydantic_model.__signature__) - attrs['__annotations__'] = copy.deepcopy(pydantic_model.__annotations__) - attrs['__model_fields__'] = model_fields + pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **pydantic_fields + ) + attrs["__pydantic_fields__"] = pydantic_fields + attrs["__pydantic_model__"] = pydantic_model + attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) + attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) + attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + attrs["__model_fields__"] = model_fields - attrs['_orm_relationship_manager'] = relationship_manager + attrs["_orm_relationship_manager"] = relationship_manager new_model = super().__new__( # type: ignore mcs, name, bases, attrs @@ -99,7 +112,8 @@ class ModelMetaclass(type): class Model(list, metaclass=ModelMetaclass): - # Model inherits from list in order to be treated as request.Body parameter in fastapi routes, + # Model inherits from list in order to be treated as + # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts __abstract__ = True if TYPE_CHECKING: # pragma no cover @@ -115,17 +129,20 @@ class Model(list, metaclass=ModelMetaclass): objects = qry.QuerySet() - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False self.values: Optional[BaseModel] = None if "pk" in kwargs: kwargs[self.__pkname__] = kwargs.pop("pk") - kwargs = {k: self.__model_fields__[k].expand_relationship(v, self) for k, v in kwargs.items()} + kwargs = { + k: self.__model_fields__[k].expand_relationship(v, self) + for k, v in kwargs.items() + } self.values = self.__pydantic_model__(**kwargs) - def __del__(self): + def __del__(self) -> None: self._orm_relationship_manager.deregister(self) def __setattr__(self, key: str, value: Any) -> None: @@ -138,20 +155,24 @@ class Model(list, metaclass=ModelMetaclass): value = self.__model_fields__[key].expand_relationship(value, self) - relation_key = self.__class__.__name__.title() + '_' + key + relation_key = self.__class__.__name__.title() + "_" + key if not self._orm_relationship_manager.contains(relation_key, self): setattr(self.values, key, value) else: super().__setattr__(key, value) def __getattribute__(self, key: str) -> Any: - if key != '__fields__' and key in self.__fields__: - relation_key = self.__class__.__name__.title() + '_' + key + if key != "__fields__" and key in self.__fields__: + relation_key = self.__class__.__name__.title() + "_" + key if self._orm_relationship_manager.contains(relation_key, self): return self._orm_relationship_manager.get(relation_key, self) item = getattr(self.values, key, None) - if item is not None and self.is_conversion_to_json_needed(key) and isinstance(item, str): + if ( + item is not None + and self.is_conversion_to_json_needed(key) + and isinstance(item, str) + ): try: item = json.loads(item) except TypeError: # pragma no cover @@ -159,30 +180,41 @@ class Model(list, metaclass=ModelMetaclass): return item return super().__getattribute__(key) - def __eq__(self, other): + def __eq__(self, other: "Model") -> bool: return self.values.dict() == other.values.dict() - def __same__(self, other): - assert self.__class__ == other.__class__ + def __same__(self, other: "Model") -> bool: + if self.__class__ != other.__class__: + return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk) + self.values is not None and other.values is not None and self.pk == other.pk + ) - def __repr__(self): # pragma no cover + def __repr__(self) -> str: # pragma no cover return self.values.__repr__() @classmethod - def from_row(cls, row, select_related: List = None, previous_table: str = None) -> 'Model': + def from_row( + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, + ) -> "Model": item = {} select_related = select_related or [] - table_prefix = cls._orm_relationship_manager.resolve_relation_join(previous_table, cls.__table__.name) + table_prefix = cls._orm_relationship_manager.resolve_relation_join( + previous_table, cls.__table__.name + ) previous_table = cls.__table__.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) model_cls = cls.__model_fields__[first_part].to - child = model_cls.from_row(row, select_related=[remainder], previous_table=previous_table) + child = model_cls.from_row( + row, select_related=[remainder], previous_table=previous_table + ) item[first_part] = child else: model_cls = cls.__model_fields__[related].to @@ -191,7 +223,9 @@ class Model(list, metaclass=ModelMetaclass): for column in cls.__table__.columns: if column.name not in item: - item[column.name] = row[f'{table_prefix + "_" if table_prefix else ""}{column.name}'] + item[column.name] = row[ + f'{table_prefix + "_" if table_prefix else ""}{column.name}' + ] return cls(**item) @@ -200,7 +234,7 @@ class Model(list, metaclass=ModelMetaclass): # return cls.__pydantic_model__.validate(value=value) @classmethod - def __get_validators__(cls): # pragma no cover + def __get_validators__(cls) -> Callable: # pragma no cover yield cls.__pydantic_model__.validate # @classmethod @@ -211,11 +245,11 @@ class Model(list, metaclass=ModelMetaclass): return self.__model_fields__.get(column_name).__type__ == pydantic.Json @property - def pk(self): + def pk(self) -> str: return getattr(self.values, self.__pkname__) @pk.setter - def pk(self, value): + def pk(self, value: Any) -> None: setattr(self.values, self.__pkname__, value) @property @@ -229,7 +263,9 @@ class Model(list, metaclass=ModelMetaclass): if isinstance(nested_model, list): dict_instance[field] = [x.dict() for x in nested_model] else: - dict_instance[field] = nested_model.dict() if nested_model is not None else {} + dict_instance[field] = ( + nested_model.dict() if nested_model is not None else {} + ) return dict_instance def from_dict(self, value_dict: Dict) -> None: @@ -245,16 +281,22 @@ class Model(list, metaclass=ModelMetaclass): def extract_related_names(cls) -> Set: related_names = set() for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass(field.type_, pydantic.BaseModel): + if inspect.isclass(field.type_) and issubclass( + field.type_, pydantic.BaseModel + ): related_names.add(name) return related_names def extract_model_db_fields(self) -> Dict: self_fields = self.extract_own_model_fields() - self_fields = {k: v for k, v in self_fields.items() if k in self.__table__.columns} + self_fields = { + k: v for k, v in self_fields.items() if k in self.__table__.columns + } for field in self.extract_related_names(): if getattr(self, field) is not None: - self_fields[field] = getattr(getattr(self, field), self.__model_fields__[field].to.__pkname__) + self_fields[field] = getattr( + getattr(self, field), self.__model_fields__[field].to.__pkname__ + ) return self_fields async def save(self) -> int: @@ -264,7 +306,7 @@ class Model(list, metaclass=ModelMetaclass): expr = self.__table__.insert() expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) - setattr(self, 'pk', item_id) + self.pk = item_id return item_id async def update(self, **kwargs: Any) -> int: @@ -274,8 +316,11 @@ class Model(list, metaclass=ModelMetaclass): self_fields = self.extract_model_db_fields() self_fields.pop(self.__pkname__) - expr = self.__table__.update().values(**self_fields).where( - self.pk_column == getattr(self, self.__pkname__)) + expr = ( + self.__table__.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) + ) result = await self.__database__.execute(expr) return result @@ -285,7 +330,7 @@ class Model(list, metaclass=ModelMetaclass): result = await self.__database__.execute(expr) return result - async def load(self) -> 'Model': + async def load(self) -> "Model": expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) diff --git a/orm/queryset.py b/orm/queryset.py index 36160e8..39a187c 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,11 +1,14 @@ -from typing import List, TYPE_CHECKING, Type, NamedTuple +from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union -import sqlalchemy -from sqlalchemy import text +import databases import orm from orm import ForeignKey -from orm.exceptions import NoMatch, MultipleMatches +from orm.exceptions import MultipleMatches, NoMatch +from orm.fields import BaseField + +import sqlalchemy +from sqlalchemy import text if TYPE_CHECKING: # pragma no cover from orm.models import Model @@ -24,17 +27,23 @@ FILTER_OPERATORS = { class JoinParameters(NamedTuple): - prev_model: Type['Model'] + prev_model: Type["Model"] previous_alias: str from_table: str - model_cls: Type['Model'] + model_cls: Type["Model"] class QuerySet: - ESCAPE_CHARACTERS = ['%', '_'] + ESCAPE_CHARACTERS = ["%", "_"] - def __init__(self, model_cls: Type['Model'] = None, filter_clauses: List = None, select_related: List = None, - limit_count: int = None, offset: int = None): + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses self._select_related = [] if select_related is None else select_related @@ -48,47 +57,77 @@ class QuerySet: self.columns = None self.order_bys = None - def __get__(self, instance, owner): + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": return self.__class__(model_cls=owner) @property - def database(self): + def database(self) -> databases.Database: return self.model_cls.__database__ @property - def table(self): + def table(self) -> sqlalchemy.Table: return self.model_cls.__table__ - def prefixed_columns(self, alias, table): - return [text(f'{alias}_{table.name}.{column.name} as {alias}_{column.name}') - for column in table.columns] + def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]: + return [ + text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") + for column in table.columns + ] - def prefixed_table_name(self, alias, name): - return text(f'{name} {alias}_{name}') + def prefixed_table_name(self, alias: str, name: str) -> text: + return text(f"{name} {alias}_{name}") - def on_clause(self, from_table, to_table, previous_alias, alias, to_key, from_key): - return text(f'{alias}_{to_table}.{to_key}=' - f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}') + def on_clause( + self, + from_table: str, + to_table: str, + previous_alias: str, + alias: str, + to_key: str, + from_key: str, + ) -> text: + return text( + f"{alias}_{to_table}.{to_key}=" + f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' + ) - def build_join_parameters(self, part, join_params: JoinParameters): + def build_join_parameters( + self, part: str, join_params: JoinParameters + ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to to_table = model_cls.__table__.name - alias = model_cls._orm_relationship_manager.resolve_relation_join(join_params.from_table, to_table) + alias = model_cls._orm_relationship_manager.resolve_relation_join( + join_params.from_table, to_table + ) if alias not in self.used_aliases: if join_params.prev_model.__model_fields__[part].virtual: - to_key = next((v for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model), None).name + to_key = next( + ( + v + for k, v in model_cls.__model_fields__.items() + if isinstance(v, ForeignKey) and v.to == join_params.prev_model + ), + None, + ).name from_key = model_cls.__pkname__ else: to_key = model_cls.__pkname__ from_key = part - on_clause = self.on_clause(join_params.from_table, to_table, join_params.previous_alias, alias, to_key, - from_key) + on_clause = self.on_clause( + join_params.from_table, + to_table, + join_params.previous_alias, + alias, + to_key, + from_key, + ) target_table = self.prefixed_table_name(alias, to_table) - self.select_from = sqlalchemy.sql.outerjoin(self.select_from, target_table, on_clause) - self.order_bys.append(text(f'{alias}_{to_table}.{model_cls.__pkname__}')) + self.select_from = sqlalchemy.sql.outerjoin( + self.select_from, target_table, on_clause + ) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) self.used_aliases.append(alias) @@ -98,44 +137,76 @@ class QuerySet: return JoinParameters(prev_model, previous_alias, from_table, model_cls) @staticmethod - def field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part) -> bool: + def field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part - def field_qualifies_to_deeper_search(self, field, parent_virtual, nested, rel_part) -> bool: + def field_qualifies_to_deeper_search( + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any([x.startswith(prev_part_of_related) for x in self._select_related]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) already_checked = any([x.startswith(rel_part) for x in self.auto_related]) - return ((field.virtual and parent_virtual) or (partial_match and not already_checked)) or not nested + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested - def extract_auto_required_relations(self, join_params: JoinParameters, - rel_part: str = '', nested: bool = False, parent_virtual: bool = False): + def extract_auto_required_relations( + self, + join_params: JoinParameters, + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, + ) -> None: for field_name, field in join_params.prev_model.__model_fields__.items(): - if self.field_is_a_foreign_key_and_no_circular_reference(field, field_name, rel_part): - rel_part = field_name if not rel_part else rel_part + '__' + field_name + if self.field_is_a_foreign_key_and_no_circular_reference( + field, field_name, rel_part + ): + rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: if rel_part not in self._select_related: self.auto_related.append("__".join(rel_part.split("__")[:-1])) - rel_part = '' - elif self.field_qualifies_to_deeper_search(field, parent_virtual, nested, rel_part): - join_params = JoinParameters(field.to, join_params.previous_alias, - join_params.from_table, join_params.prev_model) - self.extract_auto_required_relations(join_params=join_params, - rel_part=rel_part, nested=True, parent_virtual=field.virtual) + rel_part = "" + elif self.field_qualifies_to_deeper_search( + field, parent_virtual, nested, rel_part + ): + join_params = JoinParameters( + field.to, + join_params.previous_alias, + join_params.from_table, + join_params.prev_model, + ) + self.extract_auto_required_relations( + join_params=join_params, + rel_part=rel_part, + nested=True, + parent_virtual=field.virtual, + ) else: - rel_part = '' + rel_part = "" - def build_select_expression(self): + def build_select_expression(self) -> sqlalchemy.sql.select: self.columns = list(self.table.columns) - self.order_bys = [text(f'{self.table.name}.{self.model_cls.__pkname__}')] + self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] self.select_from = self.table for key in self.model_cls.__model_fields__: - if not self.model_cls.__model_fields__[key].nullable \ - and isinstance(self.model_cls.__model_fields__[key], orm.fields.ForeignKey) \ - and key not in self._select_related: + if ( + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], orm.fields.ForeignKey + ) + and key not in self._select_related + ): self._select_related = [key] + self._select_related - start_params = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) self.extract_auto_required_relations(start_params) if self.auto_related: new_joins = [] @@ -146,7 +217,9 @@ class QuerySet: self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: - join_parameters = JoinParameters(self.model_cls, '', self.table.name, self.model_cls) + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) for part in item.split("__"): join_parameters = self.build_join_parameters(part, join_parameters) @@ -180,7 +253,7 @@ class QuerySet: return expr - def filter(self, **kwargs): + def filter(self, **kwargs: Any) -> "QuerySet": filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -189,7 +262,7 @@ class QuerySet: kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): - table_prefix = '' + table_prefix = "" if "__" in key: parts = key.split("__") @@ -215,9 +288,13 @@ class QuerySet: # against which the comparison is being made. previous_table = model_cls.__tablename__ for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - table_prefix = model_cls._orm_relationship_manager.resolve_relation_join(previous_table, - current_table) + current_table = model_cls.__model_fields__[ + part + ].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join( + previous_table, current_table + ) model_cls = model_cls.__model_fields__[part].to previous_table = current_table @@ -236,25 +313,32 @@ class QuerySet: has_escaped_character = False if op in ["contains", "icontains"]: - has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS - if c in value) + has_escaped_character = any( + c for c in self.ESCAPE_CHARACTERS if c in value + ) if has_escaped_character: # enable escape modifier for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f'\\{char}') + value = value.replace(char, f"\\{char}") value = f"%{value}%" if isinstance(value, orm.Model): value = value.pk clause = getattr(column, op_attr)(value) - clause.modifiers['escape'] = '\\' if has_escaped_character else None + clause.modifiers["escape"] = "\\" if has_escaped_character else None - clause_text = str(clause.compile(dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True})) - alias = f'{table_prefix}_' if table_prefix else '' - aliased_name = f'{alias}{table.name}.{column.name}' - clause_text = clause_text.replace(f'{table.name}.{column.name}', aliased_name) + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace( + f"{table.name}.{column.name}", aliased_name + ) clause = text(clause_text) filter_clauses.append(clause) @@ -264,10 +348,10 @@ class QuerySet: filter_clauses=filter_clauses, select_related=select_related, limit_count=self.limit_count, - offset=self.query_offset + offset=self.query_offset, ) - def select_related(self, related): + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": if not isinstance(related, (list, tuple)): related = [related] @@ -277,7 +361,7 @@ class QuerySet: filter_clauses=self.filter_clauses, select_related=related, limit_count=self.limit_count, - offset=self.query_offset + offset=self.query_offset, ) async def exists(self) -> bool: @@ -290,25 +374,25 @@ class QuerySet: expr = sqlalchemy.func.count().select().select_from(expr) return await self.database.fetch_val(expr) - def limit(self, limit_count: int): + def limit(self, limit_count: int) -> "QuerySet": return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, select_related=self._select_related, limit_count=limit_count, - offset=self.query_offset + offset=self.query_offset, ) - def offset(self, offset: int): + def offset(self, offset: int) -> "QuerySet": return self.__class__( model_cls=self.model_cls, filter_clauses=self.filter_clauses, select_related=self._select_related, limit_count=self.limit_count, - offset=offset + offset=offset, ) - async def first(self, **kwargs): + async def first(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).first() @@ -316,7 +400,7 @@ class QuerySet: if rows: return rows[0] - async def get(self, **kwargs): + async def get(self, **kwargs: Any) -> "Model": if kwargs: return await self.filter(**kwargs).get() @@ -329,7 +413,7 @@ class QuerySet: raise MultipleMatches() return self.model_cls.from_row(rows[0], select_related=self._select_related) - async def all(self, **kwargs): + async def all(self, **kwargs: Any) -> List["Model"]: if kwargs: return await self.filter(**kwargs).all() @@ -345,7 +429,7 @@ class QuerySet: return result_rows @classmethod - def merge_result_rows(cls, result_rows): + def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]: merged_rows = [] for index, model in enumerate(result_rows): if index > 0 and model.pk == result_rows[index - 1].pk: @@ -355,30 +439,45 @@ class QuerySet: return merged_rows @classmethod - def merge_two_instances(cls, one: 'Model', other: 'Model'): + def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) - if isinstance(getattr(one, field), list) and not isinstance(getattr(one, field), orm.models.Model): + if isinstance(getattr(one, field), list) and not isinstance( + getattr(one, field), orm.models.Model + ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), orm.models.Model): if getattr(one, field).pk == getattr(other, field).pk: - setattr(other, field, cls.merge_two_instances(getattr(one, field), getattr(other, field))) + setattr( + other, + field, + cls.merge_two_instances( + getattr(one, field), getattr(other, field) + ), + ) return other - async def create(self, **kwargs): + async def create(self, **kwargs: Any) -> "Model": new_kwargs = dict(**kwargs) # Remove primary key when None to prevent not null constraint in postgresql. pkname = self.model_cls.__pkname__ pk = self.model_cls.__model_fields__[pkname] - if pkname in new_kwargs and new_kwargs.get(pkname) is None and (pk.nullable or pk.autoincrement): + if ( + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) + ): del new_kwargs[pkname] # substitute related models with their pk for field in self.model_cls.extract_related_names(): if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr(new_kwargs.get(field), self.model_cls.__model_fields__[field].to.__pkname__) + new_kwargs[field] = getattr( + new_kwargs.get(field), + self.model_cls.__model_fields__[field].to.__pkname__, + ) # Build the insert expression. expr = self.table.insert() diff --git a/orm/relations.py b/orm/relations.py index a9a6971..b5741e1 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -2,7 +2,7 @@ import pprint import string import uuid from random import choices -from typing import TYPE_CHECKING, List +from typing import Dict, List, TYPE_CHECKING, Union from weakref import proxy from orm.fields import ForeignKey @@ -11,40 +11,58 @@ if TYPE_CHECKING: # pragma no cover from orm.models import Model -def get_table_alias(): - return ''.join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] +def get_table_alias() -> str: + return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] -def get_relation_config(relation_type: str, table_name: str, field: ForeignKey): +def get_relation_config( + relation_type: str, table_name: str, field: ForeignKey +) -> Dict[str, str]: alias = get_table_alias() - config = {'type': relation_type, - 'table_alias': alias, - 'source_table': table_name if relation_type == 'primary' else field.to.__tablename__, - 'target_table': field.to.__tablename__ if relation_type == 'primary' else table_name - } + config = { + "type": relation_type, + "table_alias": alias, + "source_table": table_name + if relation_type == "primary" + else field.to.__tablename__, + "target_table": field.to.__tablename__ + if relation_type == "primary" + else table_name, + } return config class RelationshipManager: - - def __init__(self): + def __init__(self) -> None: self._relations = dict() - def add_relation_type(self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str): - print(relations_key, reverse_key) + def add_relation_type( + self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + ) -> None: if relations_key not in self._relations: - self._relations[relations_key] = get_relation_config('primary', table_name, field) + self._relations[relations_key] = get_relation_config( + "primary", table_name, field + ) if reverse_key not in self._relations: - self._relations[reverse_key] = get_relation_config('reverse', table_name, field) + self._relations[reverse_key] = get_relation_config( + "reverse", table_name, field + ) - def deregister(self, model: 'Model'): + def deregister(self, model: "Model") -> None: # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: del self._relations[rel_type][model._orm_id] - def add_relation(self, parent_name: str, child_name: str, parent: 'Model', child: 'Model', virtual: bool = False): + def add_relation( + self, + parent_name: str, + child_name: str, + parent: "Model", + child: "Model", + virtual: bool = False, + ) -> None: parent_id = parent._orm_id child_id = child._orm_id if virtual: @@ -53,12 +71,18 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - parents_list = self._relations[parent_name.lower().title() + '_' + child_name + 's'].setdefault(parent_id, []) + parents_list = self._relations[ + parent_name.lower().title() + "_" + child_name + "s" + ].setdefault(parent_id, []) self.append_related_model(parents_list, child) - children_list = self._relations[child_name.lower().title() + '_' + parent_name].setdefault(child_id, []) + children_list = self._relations[ + child_name.lower().title() + "_" + parent_name + ].setdefault(child_id, []) self.append_related_model(children_list, parent) - def append_related_model(self, relations_list: List['Model'], model: 'Model'): + def append_related_model( + self, relations_list: List["Model"], model: "Model" + ) -> None: for x in relations_list: try: if x.__same__(model): @@ -68,26 +92,26 @@ class RelationshipManager: relations_list.append(model) - def contains(self, relations_key: str, object: 'Model'): + def contains(self, relations_key: str, object: "Model") -> bool: if relations_key in self._relations: return object._orm_id in self._relations[relations_key] return False - def get(self, relations_key: str, object: 'Model'): + def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]: if relations_key in self._relations: if object._orm_id in self._relations[relations_key]: - if self._relations[relations_key]['type'] == 'primary': + if self._relations[relations_key]["type"] == "primary": return self._relations[relations_key][object._orm_id][0] return self._relations[relations_key][object._orm_id] def resolve_relation_join(self, from_table: str, to_table: str) -> str: for k, v in self._relations.items(): - if v['source_table'] == from_table and v['target_table'] == to_table: - return self._relations[k]['table_alias'] - return '' + if v["source_table"] == from_table and v["target_table"] == to_table: + return self._relations[k]["table_alias"] + return "" - def __str__(self): # pragma no cover + def __str__(self) -> str: # pragma no cover return pprint.pformat(self._relations, indent=4, width=1) - def __repr__(self): # pragma no cover + def __repr__(self) -> str: # pragma no cover return self.__str__() diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index f06f141..1b17457 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -109,6 +109,22 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) +def test_decimal_error_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.Decimal(name='test12', primary_key=True) + + +def test_string_error_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example4" + __metadata__ = metadata + test = fields.String(name='test12', primary_key=True) + + def test_json_conversion_in_model(): with pytest.raises(pydantic.ValidationError): ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True) From fa00f7b011ae733c86ed87230c8b16c97e6ee2fd Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 07:53:06 +0200 Subject: [PATCH 31/62] fix coverage --- .coverage | Bin 53248 -> 53248 bytes .flake8 | 2 +- orm/models.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.coverage b/.coverage index dc4c4fa39b8c2f40a9c341058db4ebd6e14f7779..6b4bedb1fcbc2347fae84cb3ec780f161553b1dd 100644 GIT binary patch delta 260 zcmV+f0sH=dpaX!Q1F$SF1vN4{F*UO>FVheM4h?aW2S6RO5fDucvvEz60tyZdasUr| z4^Iy#vk*`@4UwQL12_Y9lYw3xvyxqOA|8AuN4-^jv4&n~V z4yg{2vk?$&4wI9Pp?^dN1px_x2nXsd|Kqd%Z@0T={l~Yz+TY##eqaCozEAI~{X5%j z8((*icRSwA9J_9NUj5vk_x}93yM6bm-){d@e)*UKFa`tx31S9*y_IkI{coGgv)Rr4 zn#2I~zV^}eeY0({xBYcbaIm}kt$V+{*MIM~?*k|W1OW+B1+%e_Q!oP#54w{Fz#X#@ K5S9;R=FVheM4-H6@2S6PH4M?*Q5F!l)Cl3PvRg-^D z8VYD-Y;a|Ab1rasvvEz60t^ogNdOOf4^Iy#4N0>OP&o|)C<9fKfnFC9Z*py1Xk~10 zWpZ;aaCr|n19g*uULCWNU34NIlK>C-59$xz57ZB;4|oqW4;~K|4+#$B4$Ka#4wJJH z5N{5Xk&dB%LI(u_34sI$y37CgtpD5X?%Ct^SNpqrTkq@N-}mW#wSQ;3ZR6|i@ovZ4 z%(3gX=he^sdGF7kyW4l4`t9~l<(H2+Fa`tx31S9*y_IkI{coGgv)Rr4nq&ZaU;F6# zzS%a}+y1&IIN06&*1g}}>%aHg_W>vc1OW+B1#T;|vyV|Q0}l_ilLx>Z0uQyb5fG0L N1p^NQ01cD6&qj92hF$;w diff --git a/.flake8 b/.flake8 index ec05a50..9976335 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] ignore = ANN101, ANN102, W503 -max-complexity = 10 +max-complexity = 8 max-line-length = 88 exclude = p38venv,.pytest_cache diff --git a/orm/models.py b/orm/models.py index 68e32f4..e0aaa25 100644 --- a/orm/models.py +++ b/orm/models.py @@ -184,7 +184,7 @@ class Model(list, metaclass=ModelMetaclass): return self.values.dict() == other.values.dict() def __same__(self, other: "Model") -> bool: - if self.__class__ != other.__class__: + if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( self.values is not None and other.values is not None and self.pk == other.pk From 22c4a0619c8ad834b81d56553b10b0185754cb7b Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 08:59:36 +0200 Subject: [PATCH 32/62] fix some code smells --- .coverage | Bin 53248 -> 53248 bytes README.md | 50 +++++++++--------- orm/queryset.py | 92 +++++++++++++++++++-------------- orm/relations.py | 44 +++++++++------- requirements.txt | 7 ++- tests/test_same_table_joins.py | 49 +++++++----------- 6 files changed, 127 insertions(+), 115 deletions(-) diff --git a/.coverage b/.coverage index 6b4bedb1fcbc2347fae84cb3ec780f161553b1dd..0de5a489869598a0ebded3ec7d3ded637646d989 100644 GIT binary patch delta 170 zcmV;b09F5hpaX!Q1F$MD2RJ%4IXW;fvoSB^P!BN=8V?c=1rFd2$PT9tk+TsHZ4MR` z4g>)SDh_V*=AAe1dC$-P=LgSs{*$SWb6yn)1OW*w2!6%91j!^2M*#6$_}Xxk+TsHZ4MX~ z4g>)SE)H(<=AAe1dC$*(2hV@+{O sqlalchemy.Table: return self.model_cls.__table__ - def prefixed_columns(self, alias: str, table: sqlalchemy.Table) -> List[text]: + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: return [ text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") for column in table.columns ] - def prefixed_table_name(self, alias: str, name: str) -> text: + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") def on_clause( @@ -91,7 +93,7 @@ class QuerySet: f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' ) - def build_join_parameters( + def _build_join_parameters( self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to @@ -137,12 +139,12 @@ class QuerySet: return JoinParameters(prev_model, previous_alias, from_table, model_cls) @staticmethod - def field_is_a_foreign_key_and_no_circular_reference( + def _field_is_a_foreign_key_and_no_circular_reference( field: BaseField, field_name: str, rel_part: str ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part - def field_qualifies_to_deeper_search( + def _field_qualifies_to_deeper_search( self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) @@ -155,7 +157,7 @@ class QuerySet: or (partial_match and not already_checked) ) or not nested - def extract_auto_required_relations( + def _extract_auto_required_relations( self, join_params: JoinParameters, rel_part: str = "", @@ -163,7 +165,7 @@ class QuerySet: parent_virtual: bool = False, ) -> None: for field_name, field in join_params.prev_model.__model_fields__.items(): - if self.field_is_a_foreign_key_and_no_circular_reference( + if self._field_is_a_foreign_key_and_no_circular_reference( field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name @@ -171,7 +173,7 @@ class QuerySet: if rel_part not in self._select_related: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = "" - elif self.field_qualifies_to_deeper_search( + elif self._field_qualifies_to_deeper_search( field, parent_virtual, nested, rel_part ): join_params = JoinParameters( @@ -180,7 +182,7 @@ class QuerySet: join_params.from_table, join_params.prev_model, ) - self.extract_auto_required_relations( + self._extract_auto_required_relations( join_params=join_params, rel_part=rel_part, nested=True, @@ -189,6 +191,41 @@ class QuerySet: else: rel_part = "" + def _include_auto_related_models(self) -> None: + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + for order in self.order_bys: + expr = expr.order_by(order) + return expr + + def _reset_query_parameters(self) -> None: + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] + def build_select_expression(self) -> sqlalchemy.sql.select: self.columns = list(self.table.columns) self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] @@ -207,14 +244,9 @@ class QuerySet: start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls ) - self.extract_auto_required_relations(start_params) - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related - self._select_related.sort(key=lambda item: (-len(item), item)) + self._extract_auto_required_relations(start_params) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) for item in self._select_related: join_parameters = JoinParameters( @@ -222,34 +254,15 @@ class QuerySet: ) for part in item.split("__"): - join_parameters = self.build_join_parameters(part, join_parameters) + join_parameters = self._build_join_parameters(part, join_parameters) expr = sqlalchemy.sql.select(self.columns) expr = expr.select_from(self.select_from) - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - for order in self.order_bys: - expr = expr.order_by(order) + expr = self._apply_expression_modifiers(expr) # print(expr.compile(compile_kwargs={"literal_binds": True})) - - self.select_from = None - self.columns = None - self.order_bys = None - self.auto_related = [] - self.used_aliases = [] + self._reset_query_parameters() return expr @@ -298,7 +311,6 @@ class QuerySet: model_cls = model_cls.__model_fields__[part].to previous_table = current_table - # print(table_prefix) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] diff --git a/orm/relations.py b/orm/relations.py index b5741e1..3232c8c 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -71,43 +71,47 @@ class RelationshipManager: child, parent = parent, proxy(child) else: child = proxy(child) - parents_list = self._relations[ - parent_name.lower().title() + "_" + child_name + "s" - ].setdefault(parent_id, []) + + parent_relation_name = parent_name.lower().title() + "_" + child_name + "s" + parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) self.append_related_model(parents_list, child) - children_list = self._relations[ - child_name.lower().title() + "_" + parent_name - ].setdefault(child_id, []) + + child_relation_name = child_name.lower().title() + "_" + parent_name + children_list = self._relations[child_relation_name].setdefault(child_id, []) self.append_related_model(children_list, parent) - def append_related_model( - self, relations_list: List["Model"], model: "Model" - ) -> None: - for x in relations_list: + @staticmethod + def append_related_model(relations_list: List["Model"], model: "Model") -> None: + for relation_child in relations_list: try: - if x.__same__(model): + if relation_child.__same__(model): return except ReferenceError: continue relations_list.append(model) - def contains(self, relations_key: str, object: "Model") -> bool: + def contains(self, relations_key: str, instance: "Model") -> bool: if relations_key in self._relations: - return object._orm_id in self._relations[relations_key] + return instance._orm_id in self._relations[relations_key] return False - def get(self, relations_key: str, object: "Model") -> Union["Model", List["Model"]]: + def get( + self, relations_key: str, instance: "Model" + ) -> Union["Model", List["Model"]]: if relations_key in self._relations: - if object._orm_id in self._relations[relations_key]: + if instance._orm_id in self._relations[relations_key]: if self._relations[relations_key]["type"] == "primary": - return self._relations[relations_key][object._orm_id][0] - return self._relations[relations_key][object._orm_id] + return self._relations[relations_key][instance._orm_id][0] + return self._relations[relations_key][instance._orm_id] def resolve_relation_join(self, from_table: str, to_table: str) -> str: - for k, v in self._relations.items(): - if v["source_table"] == from_table and v["target_table"] == to_table: - return self._relations[k]["table_alias"] + for relation_name, relation in self._relations.items(): + if ( + relation["source_table"] == from_table + and relation["target_table"] == to_table + ): + return self._relations[relation_name]["table_alias"] return "" def __str__(self) -> str: # pragma no cover diff --git a/requirements.txt b/requirements.txt index fde4a73..807e704 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,9 @@ flake8-black flake8-bugbear flake8-import-order flake8-bandit -flake8-annotations \ No newline at end of file +flake8-annotations +flake8-builtins +flake8-variables-names +flake8-cognitive-complexity +flake8-functions +flake8-expression-complexity \ No newline at end of file diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index ca4e7b7..ea85409 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -67,16 +67,24 @@ def create_test_database(): metadata.drop_all(engine) -@pytest.mark.asyncio -async def test_model_multiple_instances_of_same_table_in_schema(): - async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) +@pytest.fixture() +async def init_relation(): + department = await Department.objects.create(id=1, name='Math Department') + class1 = await SchoolClass.objects.create(name="Math", department=department) + category = await Category.objects.create(name="Foreign") + category2 = await Category.objects.create(name="Domestic") + await Student.objects.create(name="Jane", category=category, schoolclass=class1) + await Student.objects.create(name="Jack", category=category2, schoolclass=class1) + await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) + yield + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.drop_all(engine) + metadata.create_all(engine) + +@pytest.mark.asyncio +async def test_model_multiple_instances_of_same_table_in_schema(init_relation): + async with database: classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' assert classes[0].students[0].name == 'Jane' @@ -92,18 +100,9 @@ async def test_model_multiple_instances_of_same_table_in_schema(): @pytest.mark.asyncio -async def test_right_tables_join(): +async def test_right_tables_join(init_relation): async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' assert classes[0].teachers[0].category.name == 'Domestic' assert classes[0].students[0].category.name is None @@ -112,17 +111,9 @@ async def test_right_tables_join(): @pytest.mark.asyncio -async def test_multiple_reverse_related_objects(): +async def test_multiple_reverse_related_objects(init_relation): async with database: - department = await Department.objects.create(id=1, name='Math Department') - class1 = await SchoolClass.objects.create(name="Math", department=department) - category = await Category.objects.create(name="Foreign") - category2 = await Category.objects.create(name="Domestic") - await Student.objects.create(name="Jane", category=category, schoolclass=class1) - await Student.objects.create(name="Jack", category=category, schoolclass=class1) - await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' + assert classes[0].students[1].name == 'Jack' assert classes[0].teachers[0].category.name == 'Domestic' From fb5d03d64c4047ce76f3131117116aae863e58a9 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 10:58:36 +0200 Subject: [PATCH 33/62] fix some complexity issues --- .coverage | Bin 53248 -> 53248 bytes .flake8 | 2 +- orm/exceptions.py | 4 + orm/models.py | 21 ++-- orm/queryset.py | 215 ++++++++++++++++++++------------- tests/test_fastapi_usage.py | 10 +- tests/test_foreign_keys.py | 30 ++++- tests/test_model_definition.py | 52 +++++--- tests/test_models.py | 12 +- tests/test_same_table_joins.py | 32 +++-- 10 files changed, 244 insertions(+), 134 deletions(-) diff --git a/.coverage b/.coverage index 0de5a489869598a0ebded3ec7d3ded637646d989..6db7a66c4bd5b360314f4d547cfd9b10619f1d46 100644 GIT binary patch delta 379 zcmV->0fhd5paX!Q1F$JC2r)1^H8?sjF|#o*)leRR01x>O>JQ!z(+{Z+b`L2J5)TRv z{0`I(x(=8QfwK`1T@HU04g>)SEDmn-=AAe1dC$-P=Nmj<=Q|(`1OW*`4esX6dw%{r z-?#64cjvd=+kBhP&u@PD<(uyT7z_jf2{H`2%m2^+_?+|K=H@wj_UzfS=MN?e1OW+7 z3-0F4d)~bB&d>jSJOB4@^RIXA+}nHpcjx?f@^|O$^8pMB1Oa~u8VcT;H}9O>Fa`+( z0SOKX3Ka+h0SPP!e#QK4x!>va|GN9v`?oCtO9ur334s;|0-(42kI(wQ-R?g3`}SA+ z?%wzN`uF$g@3-xK@2lO}Qg@Dj{d>Om)b;PWd*45E?Vo<`&wGFV+}*zW)Ni+cD!+Wp z0Wk&y0SROVe!VJ{Z~6WA=JITIbH64r0KKn$bba4!o9u0W-4h)A+ui-vz2DyJzxUhs Z0h6kaDhmk&1OW*Y1ZHMt2D9FeB0vNCz&HQ^ delta 342 zcmV-c0jd6gpaX!Q1F$JC2rxK0H90ykFtaf))leRh01x>O>JQ!z)DNl;cn>iT8V?c= z1rFd2$PT9tk+TsHZ4Q4F4g>)SDh_V*=AAe1dC$-P=LgSs{vQnl0SQ12Zu90nKYyO@ z+jqXZ^V{xizRk~Xe);8_?*SGJ1OW*y40f0QpZT2g-{$5yKR-V|KL;ZV1OW+13-0F4 zd)~bB&d>jSJOB4@^Y6X)_MZRU^WXg4dHZ||3IqWO7z*B+H}3+R{F4fgAAdy$1px_x z3J2;f|Kqd%Z@0V8^&j8lJwz1M&5x9 str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + def is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -256,7 +263,7 @@ class Model(list, metaclass=ModelMetaclass): def pk_column(self) -> sqlalchemy.Column: return self.__table__.primary_key.columns.values()[0] - def dict(self) -> Dict: + def dict(self) -> Dict: # noqa: A003 dict_instance = self.values.dict() for field in self.extract_related_names(): nested_model = getattr(self, field) diff --git a/orm/queryset.py b/orm/queryset.py index d66d50a..1b5a5dc 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -1,10 +1,20 @@ -from typing import Any, List, NamedTuple, TYPE_CHECKING, Tuple, Type, Union +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + TYPE_CHECKING, + Tuple, + Type, + Union, +) import databases import orm from orm import ForeignKey -from orm.exceptions import MultipleMatches, NoMatch +from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError from orm.fields import BaseField import sqlalchemy @@ -80,18 +90,11 @@ class QuerySet: return text(f"{name} {alias}_{name}") def on_clause( - self, - from_table: str, - to_table: str, - previous_alias: str, - alias: str, - to_key: str, - from_key: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: - return text( - f"{alias}_{to_table}.{to_key}=" - f'{previous_alias + "_" if previous_alias else ""}{from_table}.{from_key}' - ) + left_part = f"{alias}_{to_clause}" + right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" + return text(f"{left_part}={right_part}") def _build_join_parameters( self, part: str, join_params: JoinParameters @@ -118,12 +121,10 @@ class QuerySet: from_key = part on_clause = self.on_clause( - join_params.from_table, - to_table, - join_params.previous_alias, - alias, - to_key, - from_key, + previous_alias=join_params.previous_alias, + alias=alias, + from_clause=f"{join_params.from_table}.{from_key}", + to_clause=f"{to_table}.{to_key}", ) target_table = self.prefixed_table_name(alias, to_table) self.select_from = sqlalchemy.sql.outerjoin( @@ -159,12 +160,12 @@ class QuerySet: def _extract_auto_required_relations( self, - join_params: JoinParameters, + prev_model: Type["Model"], rel_part: str = "", nested: bool = False, parent_virtual: bool = False, ) -> None: - for field_name, field in join_params.prev_model.__model_fields__.items(): + for field_name, field in prev_model.__model_fields__.items(): if self._field_is_a_foreign_key_and_no_circular_reference( field, field_name, rel_part ): @@ -176,14 +177,8 @@ class QuerySet: elif self._field_qualifies_to_deeper_search( field, parent_virtual, nested, rel_part ): - join_params = JoinParameters( - field.to, - join_params.previous_alias, - join_params.from_table, - join_params.prev_model, - ) self._extract_auto_required_relations( - join_params=join_params, + prev_model=field.to, rel_part=rel_part, nested=True, parent_virtual=field.virtual, @@ -244,7 +239,7 @@ class QuerySet: start_params = JoinParameters( self.model_cls, "", self.table.name, self.model_cls ) - self._extract_auto_required_relations(start_params) + self._extract_auto_required_relations(prev_model=start_params.prev_model) self._include_auto_related_models() self._select_related.sort(key=lambda item: (-len(item), item)) @@ -266,7 +261,90 @@ class QuerySet: return expr - def filter(self, **kwargs: Any) -> "QuerySet": + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.__tablename__ + for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + def _escape_characters_in_clause( + self, op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op in ["contains", "icontains"]: + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in self.ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -279,37 +357,21 @@ class QuerySet: if "__" in key: parts = key.split("__") - # Determine if we should treat the final part as a - # filter operator or as a related field. - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] + ( + op, + field_name, + related_parts, + ) = self._extract_operator_field_and_related(parts) model_cls = self.model_cls if related_parts: - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_table = model_cls.__tablename__ - for part in related_parts: - current_table = model_cls.__model_fields__[ - part - ].to.__tablename__ - manager = model_cls._orm_relationship_manager - table_prefix = manager.resolve_relation_join( - previous_table, current_table - ) - model_cls = model_cls.__model_fields__[part].to - previous_table = current_table + ( + select_related, + table_prefix, + model_cls, + ) = self._determine_filter_target_table( + related_parts, select_related + ) table = model_cls.__table__ column = model_cls.__table__.columns[field_name] @@ -319,39 +381,20 @@ class QuerySet: column = self.table.columns[key] table = self.table - # Map the operation code onto SQLAlchemy's ColumnElement - # https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement - op_attr = FILTER_OPERATORS[op] - has_escaped_character = False - - if op in ["contains", "icontains"]: - has_escaped_character = any( - c for c in self.ESCAPE_CHARACTERS if c in value - ) - if has_escaped_character: - # enable escape modifier - for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" + value, has_escaped_character = self._escape_characters_in_clause(op, value) if isinstance(value, orm.Model): value = value.pk + op_attr = FILTER_OPERATORS[op] clause = getattr(column, op_attr)(value) - clause.modifiers["escape"] = "\\" if has_escaped_character else None - - clause_text = str( - clause.compile( - dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace( - f"{table.name}.{column.name}", aliased_name - ) - clause = text(clause_text) filter_clauses.append(clause) @@ -425,7 +468,7 @@ class QuerySet: raise MultipleMatches() return self.model_cls.from_row(rows[0], select_related=self._select_related) - async def all(self, **kwargs: Any) -> List["Model"]: + async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 if kwargs: return await self.filter(**kwargs).all() diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 8889064..1c3d5dd 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -40,8 +40,14 @@ client = TestClient(app) def test_read_main(): - response = client.post("/items/", json={'name': 'test', 'id': 1, 'category': {'name': 'test cat'}}) + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": {"name": "test cat"}} + ) assert response.status_code == 200 - assert response.json() == {'category': {'id': None, 'name': 'test cat'}, 'id': 1, 'name': 'test'} + assert response.json() == { + "category": {"id": None, "name": "test cat"}, + "id": 1, + "name": "test", + } item = Item(**response.json()) assert item.id == 1 diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index c222cfa..dfba2da 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -88,7 +88,7 @@ async def test_model_crud(): assert len(album.tracks) == 3 assert album.tracks[1].title == "Heart don't stand a chance" - album1 = await Album.objects.get(name='Malibu') + album1 = await Album.objects.get(name="Malibu") assert album1.pk == 1 assert album1.tracks is None @@ -127,7 +127,9 @@ async def test_fk_filter(): malibu = Album(name="Malibu%") await malibu.save() await Track.objects.create(album=malibu, title="The Bird", position=1) - await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) + await Track.objects.create( + album=malibu, title="Heart don't stand a chance", position=2 + ) await Track.objects.create(album=malibu, title="The Waters", position=3) fantasies = await Album.objects.create(name="Fantasies") @@ -135,12 +137,20 @@ async def test_fk_filter(): await Track.objects.create(album=fantasies, title="Sick Muse", position=2) await Track.objects.create(album=fantasies, title="Satellite Mind", position=3) - tracks = await Track.objects.select_related("album").filter(album__name="Fantasies").all() + tracks = ( + await Track.objects.select_related("album") + .filter(album__name="Fantasies") + .all() + ) assert len(tracks) == 3 for track in tracks: assert track.album.name == "Fantasies" - tracks = await Track.objects.select_related("album").filter(album__name__icontains="fan").all() + tracks = ( + await Track.objects.select_related("album") + .filter(album__name__icontains="fan") + .all() + ) assert len(tracks) == 3 for track in tracks: assert track.album.name == "Fantasies" @@ -179,7 +189,11 @@ async def test_multiple_fk(): team = await Team.objects.create(org=other, name="Green Team") await Member.objects.create(team=team, email="e@example.org") - members = await Member.objects.select_related('team__org').filter(team__org__ident="ACME Ltd").all() + members = ( + await Member.objects.select_related("team__org") + .filter(team__org__ident="ACME Ltd") + .all() + ) assert len(members) == 4 for member in members: assert member.team.org.ident == "ACME Ltd" @@ -195,7 +209,11 @@ async def test_pk_filter(): tracks = await Track.objects.select_related("album").filter(pk=1).all() assert len(tracks) == 1 - tracks = await Track.objects.select_related("album").filter(position=2, album__name='Test').all() + tracks = ( + await Track.objects.select_related("album") + .filter(position=2, album__name="Test") + .all() + ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 1b17457..e7f8e0b 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -1,5 +1,4 @@ import datetime -from typing import ClassVar import pydantic import pytest @@ -17,7 +16,7 @@ class ExampleModel(Model): __metadata__ = metadata test = fields.Integer(primary_key=True) test_string = fields.String(length=250) - test_text = fields.Text(default='') + test_text = fields.Text(default="") test_bool = fields.Boolean(nullable=False) test_float = fields.Float() test_datetime = fields.DateTime(default=datetime.datetime.now) @@ -28,33 +27,42 @@ class ExampleModel(Model): test_decimal = fields.Decimal(length=10, precision=2) -fields_to_check = ['test', 'test_text', 'test_string', 'test_datetime', 'test_date', 'test_text', 'test_float', - 'test_bigint', 'test_json'] +fields_to_check = [ + "test", + "test_text", + "test_string", + "test_datetime", + "test_date", + "test_text", + "test_float", + "test_bigint", + "test_json", +] class ExampleModel2(Model): __tablename__ = "example2" __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True) - test_string = fields.String('test_string2', length=250) + test = fields.Integer(name="test12", primary_key=True) + test_string = fields.String("test_string2", length=250) @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string='test', test_bool=True) + return ExampleModel(pk=1, test_string="test", test_bool=True) def test_not_nullable_field_is_required(): with pytest.raises(pydantic.error_wrappers.ValidationError): - ExampleModel(test=1, test_string='test') + ExampleModel(test=1, test_string="test") def test_model_attribute_access(example): assert example.test == 1 - assert example.test_string == 'test' + assert example.test_string == "test" assert example.test_datetime.year == datetime.datetime.now().year assert example.test_date == datetime.date.today() - assert example.test_text == '' + assert example.test_text == "" assert example.test_float is None assert example.test_bigint == 0 assert example.test_json == {} @@ -63,7 +71,7 @@ def test_model_attribute_access(example): assert example.test == 12 example.new_attr = 12 - assert 'new_attr' in example.__dict__ + assert "new_attr" in example.__dict__ def test_primary_key_access_and_setting(example): @@ -87,44 +95,54 @@ def test_sqlalchemy_table_is_created(example): def test_double_column_name_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test_string = fields.String('test_string2', name='test_string2', length=250) + test_string = fields.String("test_string2", name="test_string2", length=250) def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test_string = fields.String(name='test_string2', length=250) + test_string = fields.String(name="test_string2", length=250) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.Integer(name='test12', primary_key=True, pydantic_only=True) + test = fields.Integer(name="test12", primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.Decimal(name='test12', primary_key=True) + test = fields.Decimal(name="test12", primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.String(name='test12', primary_key=True) + test = fields.String(name="test12", primary_key=True) def test_json_conversion_in_model(): with pytest.raises(pydantic.ValidationError): - ExampleModel(test_json=datetime.datetime.now(), test=1, test_string='test', test_bool=True) + ExampleModel( + test_json=datetime.datetime.now(), + test=1, + test_string="test", + test_bool=True, + ) diff --git a/tests/test_models.py b/tests/test_models.py index cf12c3e..8b0bf37 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,6 +3,7 @@ import pytest import sqlalchemy import orm +from orm.exceptions import QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -139,6 +140,13 @@ async def test_model_filter(): assert await products.count() == 3 +@pytest.mark.asyncio +async def test_wrong_query_contains_model(): + with pytest.raises(QueryDefinitionError): + product = Product(name="90%-Cotton", rating=2) + await Product.objects.filter(name__contains=product).count() + + @pytest.mark.asyncio async def test_model_exists(): async with database: @@ -175,7 +183,7 @@ async def test_model_limit_with_filter(): await User.objects.create(name="Tom") await User.objects.create(name="Tom") - assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2 + assert len(await User.objects.limit(2).filter(name__iexact="Tom").all()) == 2 @pytest.mark.asyncio @@ -185,7 +193,7 @@ async def test_offset(): await User.objects.create(name="Jane") users = await User.objects.offset(1).limit(1).all() - assert users[0].name == 'Jane' + assert users[0].name == "Jane" @pytest.mark.asyncio diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index ea85409..66f6ad1 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -69,7 +69,7 @@ def create_test_database(): @pytest.fixture() async def init_relation(): - department = await Department.objects.create(id=1, name='Math Department') + department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") category2 = await Category.objects.create(name="Domestic") @@ -85,35 +85,41 @@ async def init_relation(): @pytest.mark.asyncio async def test_model_multiple_instances_of_same_table_in_schema(init_relation): async with database: - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[0].name == 'Jane' + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[0].name == "Jane" # related fields of main model are only populated by pk # unless there is a required foreign key somewhere along the way # since department is required for schoolclass it was pre loaded (again) # but you can load them anytime - assert classes[0].students[0].schoolclass.name == 'Math' + assert classes[0].students[0].schoolclass.name == "Math" assert classes[0].students[0].schoolclass.department.name is None await classes[0].students[0].schoolclass.department.load() - assert classes[0].students[0].schoolclass.department.name == 'Math Department' + assert classes[0].students[0].schoolclass.department.name == "Math Department" @pytest.mark.asyncio async def test_right_tables_join(init_relation): async with database: - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].teachers[0].category.name == 'Domestic' + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].teachers[0].category.name == "Domestic" assert classes[0].students[0].category.name is None await classes[0].students[0].category.load() - assert classes[0].students[0].category.name == 'Foreign' + assert classes[0].students[0].category.name == "Foreign" @pytest.mark.asyncio async def test_multiple_reverse_related_objects(init_relation): async with database: - classes = await SchoolClass.objects.select_related(['teachers__category', 'students']).all() - assert classes[0].name == 'Math' - assert classes[0].students[1].name == 'Jack' - assert classes[0].teachers[0].category.name == 'Domestic' + classes = await SchoolClass.objects.select_related( + ["teachers__category", "students"] + ).all() + assert classes[0].name == "Math" + assert classes[0].students[1].name == "Jack" + assert classes[0].teachers[0].category.name == "Domestic" From d9755234c1c8ff9fb780326fe6fc6e9a0ac484d2 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 11:08:28 +0200 Subject: [PATCH 34/62] readme formatting --- README.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 69758d9..15272f9 100644 --- a/README.md +++ b/README.md @@ -171,17 +171,17 @@ All fields are required unless one of the following is set: Autoincrement is set by default on int primary keys. Available Model Fields: - * `orm.String(length)` - * `orm.Text()` - * `orm.Boolean()` - * `orm.Integer()` - * `orm.Float()` - * `orm.Date()` - * `orm.Time()` - * `orm.DateTime()` - * `orm.JSON()` - * `orm.BigInteger()` - * `orm.Decimal(lenght, precision)` +* `orm.String(length)` +* `orm.Text()` +* `orm.Boolean()` +* `orm.Integer()` +* `orm.Float()` +* `orm.Date()` +* `orm.Time()` +* `orm.DateTime()` +* `orm.JSON()` +* `orm.BigInteger()` +* `orm.Decimal(lenght, precision)` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases From 3f2568b27e90d4b572e10ad286d5468b6f1ecc4e Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 12:04:44 +0200 Subject: [PATCH 35/62] refactors in fields --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 55 ++++++++---- orm/models.py | 179 +++++++++++++++++++------------------ orm/queryset.py | 2 +- orm/relations.py | 12 +-- tests/test_foreign_keys.py | 6 ++ 6 files changed, 140 insertions(+), 114 deletions(-) diff --git a/.coverage b/.coverage index 6db7a66c4bd5b360314f4d547cfd9b10619f1d46..ecaa279aee7463d6f69f98fe28f82371df0483f0 100644 GIT binary patch delta 179 zcmV;k08IaYpaX!Q1F$MD2QoS^GdeRcvoSB#P#%8(5BU%358e;c52p`w4<`>04+swW z4$=;|4wepovk?$m4wHe7V;K9tZ|DF1ZT|JnoqK!F|L&asPX6w^eLj=?j)qe4`=86R z+0Ff$!~lF>`{??<**3X%aHg_W>vc1OW+B1#T<$ujh8pyYK%e fo4w6uv)ODmo6TnL=AZ0;clXV`{@gdS(vL1c8-`o8 diff --git a/orm/fields.py b/orm/fields.py index 9c9f3f7..eece9ea 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,6 +1,6 @@ import datetime import decimal -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union import orm from orm.exceptions import ModelDefinitionError, RelationshipInstanceError @@ -29,6 +29,9 @@ class BaseField: name = args.pop(0) self.name = name + self._populate_from_kwargs(kwargs) + + def _populate_from_kwargs(self, kwargs: Dict) -> None: self.primary_key = kwargs.pop("primary_key", False) self.autoincrement = kwargs.pop( "autoincrement", self.primary_key and self.__type__ == int @@ -79,7 +82,7 @@ class BaseField: index=self.index, unique=self.unique, default=self.default, - server_default=self.server_default + server_default=self.server_default, ) def get_column_type(self) -> sqlalchemy.types.TypeEngine: @@ -228,13 +231,13 @@ class ForeignKey(BaseField): def expand_relationship( self, value: Any, child: "Model" ) -> Union["Model", List["Model"]]: - if not isinstance(value, (self.to, dict, int, str, list)) or ( - isinstance(value, orm.models.Model) and not isinstance(value, self.to) - ): + + if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( - "Relationship model can be build only from orm.Model, " - "dict and integer or string (pk)." + f"Relationship error - expecting: {self.to.__name__}, " + f"but {value.__class__.__name__} encountered." ) + if isinstance(value, list) and not isinstance(value, self.to): model = [self.expand_relationship(val, child) for val in value] return model @@ -244,9 +247,19 @@ class ForeignKey(BaseField): elif isinstance(value, dict): model = self.to(**value) else: + if not isinstance(value, self.to.pk_type()): + raise RelationshipInstanceError( + f"Relationship error - ForeignKey {self.to.__name__} is of type {self.to.pk_type()} " + f"of type {self.__type__} while {type(value)} passed as a parameter." + ) model = create_dummy_instance(fk=self.to, pk=value) - child_model_name = self.related_name or child.__class__.__name__.lower() + "s" + self.add_to_relationship_registry(model, child) + + return model + + def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: + child_model_name = self.related_name or child.get_name() + "s" model._orm_relationship_manager.add_relation( model.__class__.__name__.lower(), child.__class__.__name__.lower(), @@ -257,16 +270,20 @@ class ForeignKey(BaseField): if ( child_model_name not in model.__fields__ - and child.__class__.__name__.lower() not in model.__fields__ + and child.get_name() not in model.__fields__ ): - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( - child.__class__, name=child_model_name, virtual=True - ) + self.register_reverse_model_fields(model, child, child_model_name) - return model + @staticmethod + def register_reverse_model_fields( + model: "Model", child: "Model", child_model_name: str + ) -> None: + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child.__class__, name=child_model_name, virtual=True + ) diff --git a/orm/models.py b/orm/models.py index c16b21d..6dc7fa3 100644 --- a/orm/models.py +++ b/orm/models.py @@ -32,8 +32,17 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple return pydantic_fields +def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: + child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" + reverse_name = field.related_name or child_relation_name + relation_name = name.lower().title() + "_" + field.to.get_name() + relationship_manager.add_relation_type( + relation_name, reverse_name, field, table_name + ) + + def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, tablename: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] @@ -46,14 +55,7 @@ def sqlalchemy_columns_from_model_fields( if field.primary_key: pkname = field_name if isinstance(field, ForeignKey): - child_relation_name = ( - field.to.get_name(title=True) + "_" + name.lower() + "s" - ) - reverse_name = field.related_name or child_relation_name - relation_name = name.lower().title() + "_" + field.to.get_name() - relationship_manager.add_relation_type( - relation_name, reverse_name, field, tablename - ) + register_relation_on_build(table_name, field, name) columns.append(field.get_column(field_name)) return pkname, columns, model_fields @@ -109,8 +111,8 @@ class ModelMetaclass(type): return new_model -class Model(list, metaclass=ModelMetaclass): - # Model inherits from list in order to be treated as +class FakePydantic(list, metaclass=ModelMetaclass): + # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts __abstract__ = True @@ -125,9 +127,8 @@ class Model(list, metaclass=ModelMetaclass): __database__: databases.Database _orm_relationship_manager: RelationshipManager - objects = qry.QuerySet() - def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() self._orm_id: str = uuid.uuid4().hex self._orm_saved: bool = False self.values: Optional[BaseModel] = None @@ -145,7 +146,7 @@ class Model(list, metaclass=ModelMetaclass): def __setattr__(self, key: str, value: Any) -> None: if key in self.__fields__: - if self.is_conversion_to_json_needed(key) and not isinstance(value, str): + if self._is_conversion_to_json_needed(key) and not isinstance(value, str): try: value = json.dumps(value) except TypeError: # pragma no cover @@ -168,7 +169,7 @@ class Model(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( item is not None - and self.is_conversion_to_json_needed(key) + and self._is_conversion_to_json_needed(key) and isinstance(item, str) ): try: @@ -191,6 +192,79 @@ class Model(list, metaclass=ModelMetaclass): def __repr__(self) -> str: # pragma no cover return self.values.__repr__() + @classmethod + def __get_validators__(cls) -> Callable: # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def get_name(cls, title: bool = False, lower: bool = True) -> str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.__table__.primary_key.columns.values()[0] + + @classmethod + def pk_type(cls): + return cls.__model_fields__[cls.__pkname__].__type__ + + def dict(self) -> Dict: # noqa: A003 + dict_instance = self.values.dict() + for field in self._extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = ( + nested_model.dict() if nested_model is not None else {} + ) + return dict_instance + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def _is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.__model_fields__.get(column_name).__type__ == pydantic.Json + + def _extract_own_model_fields(self) -> Dict: + related_names = self._extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def _extract_related_names(cls) -> Set: + related_names = set() + for name, field in cls.__fields__.items(): + if inspect.isclass(field.type_) and issubclass( + field.type_, pydantic.BaseModel + ): + related_names.add(name) + return related_names + + def _extract_model_db_fields(self) -> Dict: + self_fields = self._extract_own_model_fields() + self_fields = { + k: v for k, v in self_fields.items() if k in self.__table__.columns + } + for field in self._extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr( + getattr(self, field), self.__model_fields__[field].to.__pkname__ + ) + return self_fields + + +class Model(FakePydantic): + __abstract__ = True + + objects = qry.QuerySet() + @classmethod def from_row( cls, @@ -227,30 +301,6 @@ class Model(list, metaclass=ModelMetaclass): return cls(**item) - # @classmethod - # def validate(cls, value: Any) -> 'BaseModel': # pragma no cover - # return cls.__pydantic_model__.validate(value=value) - - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate - - # @classmethod - # def schema(cls, by_alias: bool = True): # pragma no cover - # return cls.__pydantic_model__.schema(by_alias=by_alias) - - @classmethod - def get_name(cls, title: bool = False, lower: bool = True) -> str: - name = cls.__name__ - if lower: - name = name.lower() - if title: - name = name.title() - return name - - def is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json - @property def pk(self) -> str: return getattr(self.values, self.__pkname__) @@ -259,55 +309,8 @@ class Model(list, metaclass=ModelMetaclass): def pk(self, value: Any) -> None: setattr(self.values, self.__pkname__, value) - @property - def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] - - def dict(self) -> Dict: # noqa: A003 - dict_instance = self.values.dict() - for field in self.extract_related_names(): - nested_model = getattr(self, field) - if isinstance(nested_model, list): - dict_instance[field] = [x.dict() for x in nested_model] - else: - dict_instance[field] = ( - nested_model.dict() if nested_model is not None else {} - ) - return dict_instance - - def from_dict(self, value_dict: Dict) -> None: - for key, value in value_dict.items(): - setattr(self, key, value) - - def extract_own_model_fields(self) -> Dict: - related_names = self.extract_related_names() - self_fields = {k: v for k, v in self.dict().items() if k not in related_names} - return self_fields - - @classmethod - def extract_related_names(cls) -> Set: - related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel - ): - related_names.add(name) - return related_names - - def extract_model_db_fields(self) -> Dict: - self_fields = self.extract_own_model_fields() - self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns - } - for field in self.extract_related_names(): - if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ - ) - return self_fields - async def save(self) -> int: - self_fields = self.extract_model_db_fields() + self_fields = self._extract_model_db_fields() if self.__model_fields__.get(self.__pkname__).autoincrement: self_fields.pop(self.__pkname__, None) expr = self.__table__.insert() @@ -321,7 +324,7 @@ class Model(list, metaclass=ModelMetaclass): new_values = {**self.dict(), **kwargs} self.from_dict(new_values) - self_fields = self.extract_model_db_fields() + self_fields = self._extract_model_db_fields() self_fields.pop(self.__pkname__) expr = ( self.__table__.update() diff --git a/orm/queryset.py b/orm/queryset.py index 1b5a5dc..23e4a76 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -527,7 +527,7 @@ class QuerySet: del new_kwargs[pkname] # substitute related models with their pk - for field in self.model_cls.extract_related_names(): + for field in self.model_cls._extract_related_names(): if field in new_kwargs and new_kwargs.get(field) is not None: new_kwargs[field] = getattr( new_kwargs.get(field), diff --git a/orm/relations.py b/orm/relations.py index 3232c8c..f541dfe 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -8,7 +8,7 @@ from weakref import proxy from orm.fields import ForeignKey if TYPE_CHECKING: # pragma no cover - from orm.models import Model + from orm.models import FakePydantic, Model def get_table_alias() -> str: @@ -48,7 +48,7 @@ class RelationshipManager: "reverse", table_name, field ) - def deregister(self, model: "Model") -> None: + def deregister(self, model: "FakePydantic") -> None: # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): if model.__class__.__name__.lower() in rel_type.lower(): @@ -59,8 +59,8 @@ class RelationshipManager: self, parent_name: str, child_name: str, - parent: "Model", - child: "Model", + parent: "FakePydantic", + child: "FakePydantic", virtual: bool = False, ) -> None: parent_id = parent._orm_id @@ -91,13 +91,13 @@ class RelationshipManager: relations_list.append(model) - def contains(self, relations_key: str, instance: "Model") -> bool: + def contains(self, relations_key: str, instance: "FakePydantic") -> bool: if relations_key in self._relations: return instance._orm_id in self._relations[relations_key] return False def get( - self, relations_key: str, instance: "Model" + self, relations_key: str, instance: "FakePydantic" ) -> Union["Model", List["Model"]]: if relations_key in self._relations: if instance._orm_id in self._relations[relations_key]: diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index dfba2da..660f3d4 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -67,6 +67,12 @@ def create_test_database(): metadata.drop_all(engine) +@pytest.mark.asyncio +async def test_wrong_query_foreign_key_type(): + with pytest.raises(RelationshipInstanceError): + Track(title="The Error", album="wrong_pk_type") + + @pytest.mark.asyncio async def test_model_crud(): async with database: From 836836c1367680d85c608552ca8f00b088620ced Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 12:53:28 +0200 Subject: [PATCH 36/62] refactor merging of instances from queryset to fakepydantic --- .coverage | Bin 53248 -> 53248 bytes README.md | 18 ++++++++++++++---- orm/fields.py | 6 ++++-- orm/models.py | 31 ++++++++++++++++++++++++++++++- orm/queryset.py | 31 +------------------------------ 5 files changed, 49 insertions(+), 37 deletions(-) diff --git a/.coverage b/.coverage index ecaa279aee7463d6f69f98fe28f82371df0483f0..09e8343a34172aa3ca53982806eb1a61d15636a3 100644 GIT binary patch delta 117 zcmV-*0E+*BpaX!Q1F$SF1vN4{H8-;{FXB)EaI+ME&JGy!f8Wmk{oDNOojdpTp8wrB z|DF8ZdHZ~m`i?CJM+XG~34soiE{{D6RQcs&4mAb@0SRoAdXFIx*KfCfvi7#W?&%RfB(~m7cVu&`0 diff --git a/README.md b/README.md index 15272f9..3186786 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ CodeFactor + +Codacy +

The `async-orm` package is an async ORM for Python, with support for Postgres, @@ -26,7 +29,7 @@ The goal was to create a simple orm that can be used directly with [`fastapi`][f Initial work was inspired by [`encode/orm`][encode/orm]. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. -**aysn-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1` +**async-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1` **Note**: Use `ipython` to try this from the console, since it supports `await`. @@ -44,8 +47,9 @@ class Note(orm.Model): __database__ = database __metadata__ = metadata + # primary keys of type int by dafault are set to autoincrement id = orm.Integer(primary_key=True) - text = orm.String(max_length=100) + text = orm.String(length=100) completed = orm.Boolean(default=False) # Create the database @@ -97,7 +101,7 @@ class Album(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - name = orm.String(max_length=100) + name = orm.String(length=100) class Track(orm.Model): @@ -107,7 +111,7 @@ class Track(orm.Model): id = orm.Integer(primary_key=True) album = orm.ForeignKey(Album) - title = orm.String(max_length=100) + title = orm.String(length=100) position = orm.Integer() @@ -138,6 +142,12 @@ assert track.album.name == "Malibu" track = await Track.objects.select_related("album").get(title="The Bird") assert track.album.name == "Malibu" +# By default you also get a second side of the relation +# constructed as lowercase source model name +'s' (tracks in this case) +# you can also provide custom name with parameter related_name +album = await Album.objects.select_related("tracks").all() +assert len(album.tracks) == 3 + # Fetch instances, with a filter across an FK relationship. tracks = Track.objects.filter(album__name="Fantasies") assert len(tracks) == 2 diff --git a/orm/fields.py b/orm/fields.py index eece9ea..f101346 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -249,8 +249,10 @@ class ForeignKey(BaseField): else: if not isinstance(value, self.to.pk_type()): raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} is of type {self.to.pk_type()} " - f"of type {self.__type__} while {type(value)} passed as a parameter." + f"Relationship error - ForeignKey {self.to.__name__} " + f"is of type {self.to.pk_type()} " + f"of type {self.__type__} " + f"while {type(value)} passed as a parameter." ) model = create_dummy_instance(fk=self.to, pk=value) diff --git a/orm/models.py b/orm/models.py index 6dc7fa3..d41b09c 100644 --- a/orm/models.py +++ b/orm/models.py @@ -210,7 +210,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): return self.__table__.primary_key.columns.values()[0] @classmethod - def pk_type(cls): + def pk_type(cls) -> Any: return cls.__model_fields__[cls.__pkname__].__type__ def dict(self) -> Dict: # noqa: A003 @@ -259,6 +259,35 @@ class FakePydantic(list, metaclass=ModelMetaclass): ) return self_fields + @classmethod + def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": + for field in one.__model_fields__.keys(): + # print(field, one.dict(), other.dict()) + if isinstance(getattr(one, field), list) and not isinstance( + getattr(one, field), Model + ): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr( + other, + field, + cls.merge_two_instances( + getattr(one, field), getattr(other, field) + ), + ) + return other + class Model(FakePydantic): __abstract__ = True diff --git a/orm/queryset.py b/orm/queryset.py index 23e4a76..727f9f7 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -479,39 +479,10 @@ class QuerySet: for row in rows ] - result_rows = self.merge_result_rows(result_rows) + result_rows = self.model_cls.merge_instances_list(result_rows) return result_rows - @classmethod - def merge_result_rows(cls, result_rows: List["Model"]) -> List["Model"]: - merged_rows = [] - for index, model in enumerate(result_rows): - if index > 0 and model.pk == result_rows[index - 1].pk: - result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) - else: - merged_rows.append(model) - return merged_rows - - @classmethod - def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): - # print(field, one.dict(), other.dict()) - if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), orm.models.Model - ): - setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), orm.models.Model): - if getattr(one, field).pk == getattr(other, field).pk: - setattr( - other, - field, - cls.merge_two_instances( - getattr(one, field), getattr(other, field) - ), - ) - return other - async def create(self, **kwargs: Any) -> "Model": new_kwargs = dict(**kwargs) From becb914e557c197893c0089ab8bd9dcd7c77a250 Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 9 Aug 2020 13:27:53 +0200 Subject: [PATCH 37/62] refactor query and queryclause into separate classes --- .coverage | Bin 53248 -> 53248 bytes orm/queryset.py | 375 +++++++++++++++++++++++++++--------------------- 2 files changed, 215 insertions(+), 160 deletions(-) diff --git a/.coverage b/.coverage index 09e8343a34172aa3ca53982806eb1a61d15636a3..5d7b859a5b4294ba4300067c998729bd4a6b59ad 100644 GIT binary patch delta 148 zcmV;F0Biq%paX!Q1F$MD2QxY{HaammvoSB#P#$;y5BU%358e;c52p`s404+swW z4$=;|4wepovk?$m4wHC}s7Xf$1px_x4hI6DxBQRK`oG=oKKJ|fSNrbX_xt+y_v!Dq v?SAj8-Puxij(`1ozW3Di@49>6KXdJ$e(ukEfBxLvzWY@9 None: - self.model_cls = model_cls - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self._select_related = [] if select_related is None else select_related - self.limit_count = limit_count + self.query_offset = offset + self.limit_count = limit_count + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.__table__ self.auto_related = [] self.used_aliases = [] @@ -67,16 +70,45 @@ class QuerySet: self.columns = None self.order_bys = None - def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": - return self.__class__(model_cls=owner) + def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: + self.columns = list(self.table.columns) + self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.select_from = self.table - @property - def database(self) -> databases.Database: - return self.model_cls.__database__ + for key in self.model_cls.__model_fields__: + if ( + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], orm.fields.ForeignKey + ) + and key not in self._select_related + ): + self._select_related = [key] + self._select_related - @property - def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + self._extract_auto_required_relations(prev_model=start_params.prev_model) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) + + for item in self._select_related: + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + + for part in item.split("__"): + join_parameters = self._build_join_parameters(part, join_parameters) + + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) + + expr = self._apply_expression_modifiers(expr) + + # print(expr.compile(compile_kwargs={"literal_binds": True})) + self._reset_query_parameters() + + return expr, self._select_related @staticmethod def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: @@ -89,6 +121,25 @@ class QuerySet: def prefixed_table_name(alias: str, name: str) -> text: return text(f"{name} {alias}_{name}") + @staticmethod + def _field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: + return isinstance(field, ForeignKey) and field_name not in rel_part + + def _field_qualifies_to_deeper_search( + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) + already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested + def on_clause( self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: @@ -139,25 +190,6 @@ class QuerySet: prev_model = model_cls return JoinParameters(prev_model, previous_alias, from_table, model_cls) - @staticmethod - def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str - ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part - - def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str - ) -> bool: - prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any( - [x.startswith(prev_part_of_related) for x in self._select_related] - ) - already_checked = any([x.startswith(rel_part) for x in self.auto_related]) - return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested - def _extract_auto_required_relations( self, prev_model: Type["Model"], @@ -221,130 +253,21 @@ class QuerySet: self.auto_related = [] self.used_aliases = [] - def build_select_expression(self) -> sqlalchemy.sql.select: - self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] - self.select_from = self.table - for key in self.model_cls.__model_fields__: - if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], orm.fields.ForeignKey - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related +class QueryClause: + def __init__( + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + ) -> None: - start_params = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - self._extract_auto_required_relations(prev_model=start_params.prev_model) - self._include_auto_related_models() - self._select_related.sort(key=lambda item: (-len(item), item)) + self._select_related = select_related + self.filter_clauses = filter_clauses - for item in self._select_related: - join_parameters = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) + self.model_cls = model_cls + self.table = self.model_cls.__table__ - for part in item.split("__"): - join_parameters = self._build_join_parameters(part, join_parameters) - - expr = sqlalchemy.sql.select(self.columns) - expr = expr.select_from(self.select_from) - - expr = self._apply_expression_modifiers(expr) - - # print(expr.compile(compile_kwargs={"literal_binds": True})) - self._reset_query_parameters() - - return expr - - def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str] - ) -> Tuple[List[str], str, "Model"]: - - table_prefix = "" - model_cls = self.model_cls - select_related = [relation for relation in select_related] - - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_table = model_cls.__tablename__ - for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager - table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to - previous_table = current_table - return select_related, table_prefix, model_cls - - def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, - ) -> sqlalchemy.sql.expression.TextClause: - for modifier, modifier_value in modifiers.items(): - clause.modifiers[modifier] = modifier_value - - clause_text = str( - clause.compile( - dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) - clause = text(clause_text) - return clause - - def _escape_characters_in_clause( - self, op: str, value: Union[str, "Model"] - ) -> Tuple[str, bool]: - has_escaped_character = False - - if op in ["contains", "icontains"]: - if isinstance(value, orm.Model): - raise QueryDefinitionError( - "You cannot use contains and icontains with instance of the Model" - ) - - has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS if c in value) - - if has_escaped_character: - # enable escape modifier - for char in self.ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" - - return value, has_escaped_character - - @staticmethod - def _extract_operator_field_and_related( - parts: List[str], - ) -> Tuple[str, str, Optional[List]]: - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - return op, field_name, related_parts - - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + def filter( # noqa: A003 + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -395,9 +318,141 @@ class QuerySet: table_prefix, modifiers={"escape": "\\" if has_escaped_character else None}, ) - filter_clauses.append(clause) + return filter_clauses, select_related + + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.__tablename__ + for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + @staticmethod + def _escape_characters_in_clause( + op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op in ["contains", "icontains"]: + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts + + +class QuerySet: + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + self.order_bys = None + + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": + return self.__class__(model_cls=owner) + + @property + def database(self) -> databases.Database: + return self.model_cls.__database__ + + @property + def table(self) -> sqlalchemy.Table: + return self.model_cls.__table__ + + def build_select_expression(self) -> sqlalchemy.sql.select: + qry = Query( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + offset=self.query_offset, + limit_count=self.limit_count, + ) + exp, self._select_related = qry.build_select_expression() + return exp + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + qryclause = QueryClause( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + ) + filter_clauses, select_related = qryclause.filter(**kwargs) + return self.__class__( model_cls=self.model_cls, filter_clauses=filter_clauses, From ace348e172ff5681bc5f228c77e4da25f85a8b2a Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 15:27:10 +0200 Subject: [PATCH 38/62] refactored reverse relation registration into the metaclass --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 37 +++++---------------- orm/models.py | 58 +++++++++++++++++++++++++-------- tests/test_same_table_joins.py | 25 +++++++------- 4 files changed, 65 insertions(+), 55 deletions(-) diff --git a/.coverage b/.coverage index 5d7b859a5b4294ba4300067c998729bd4a6b59ad..04d6ff895dc3fdaa51adcd2b47292b0028b57483 100644 GIT binary patch delta 166 zcmV;X09pTlpaX!Q1F$A93o$VuF*Q0eHaajclQA!Fld3OO2rM8cEiG_lVzd4)6Ho!F zvl4)94iLJ{n|I#4=RKeQ&o_9!&UcfKj?7i;|LyK?UqAQfy+42MZr^>X{PHmeI0gg( z333L0dn@1aEDKP-w*FL(wZ?;YD-`#KB-f!>q+wGsMz3s1iy8G|_ U_I)P>1OW+9ldX>NrBvl4)9 z4iIkh=AAe1dC$-P=Nmj<=R1>+j?7Z9|F^rpef`{@_x}93yM6bm^2^5@H3kF$32X*_ zy_IkI@%x|4v)Rr4n#2HnU;F6#zS%ape|Nuid%wNcZ?}K4_O`$7>F&Sx+xGz~1q1;J RR+F%gBLl4e2D8(TEkL^@RBQkM diff --git a/orm/fields.py b/orm/fields.py index f101346..a2dca94 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -51,7 +51,7 @@ class BaseField: @property def is_required(self) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not self.nullable and not self.has_default and not self.is_auto_primary_key ) @property @@ -204,12 +204,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -229,7 +229,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Union["Model", List["Model"]]: if isinstance(value, orm.models.Model) and not isinstance(value, self.to): @@ -261,7 +261,6 @@ class ForeignKey(BaseField): return model def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: - child_model_name = self.related_name or child.get_name() + "s" model._orm_relationship_manager.add_relation( model.__class__.__name__.lower(), child.__class__.__name__.lower(), @@ -269,23 +268,3 @@ class ForeignKey(BaseField): child, virtual=self.virtual, ) - - if ( - child_model_name not in model.__fields__ - and child.get_name() not in model.__fields__ - ): - self.register_reverse_model_fields(model, child, child_model_name) - - @staticmethod - def register_reverse_model_fields( - model: "Model", child: "Model", child_model_name: str - ) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( - child.__class__, name=child_model_name, virtual=True - ) diff --git a/orm/models.py b/orm/models.py index d41b09c..f7a6e6a 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Callable, Dict, Set import databases +from pydantic.fields import ModelField import orm.queryset as qry from orm.exceptions import ModelDefinitionError @@ -41,8 +42,35 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> ) +def expand_reverse_relationships(model: Type["Model"]): + for field_name, model_field in model.__model_fields__.items(): + if isinstance(model_field, ForeignKey): + child_model_name = model_field.related_name or model.__name__.lower() + 's' + parent_model = model_field.to + child = model + if ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ + ): + register_reverse_model_fields(parent_model, child, child_model_name) + + +def register_reverse_model_fields( + model: Type["Model"], child: Type["Model"], child_model_name: str +) -> None: + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child, name=child_model_name, virtual=True + ) + + def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] @@ -100,14 +128,16 @@ class ModelMetaclass(type): attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) - attrs["__model_fields__"] = model_fields + attrs["__model_fields__"] = model_fields attrs["_orm_relationship_manager"] = relationship_manager new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) + expand_reverse_relationships(new_model) + return new_model @@ -168,9 +198,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) ): try: item = json.loads(item) @@ -186,7 +216,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.values is not None and other.values is not None and self.pk == other.pk ) def __repr__(self) -> str: # pragma no cover @@ -242,7 +272,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): related_names = set() for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + field.type_, pydantic.BaseModel ): related_names.add(name) return related_names @@ -274,7 +304,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), Model + getattr(one, field), Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), Model): @@ -296,10 +326,10 @@ class Model(FakePydantic): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} @@ -357,8 +387,8 @@ class Model(FakePydantic): self_fields.pop(self.__pkname__) expr = ( self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) ) result = await self.__database__.execute(expr) return result diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 66f6ad1..8c67f4b 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -1,3 +1,5 @@ +import asyncio + import databases import pytest import sqlalchemy @@ -59,16 +61,17 @@ class Teacher(orm.Model): category = orm.ForeignKey(Category, nullable=True) +@pytest.fixture(scope='module') +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + loop.close() + + @pytest.fixture(autouse=True, scope="module") -def create_test_database(): +async def create_test_database(): engine = sqlalchemy.create_engine(DATABASE_URL) metadata.create_all(engine) - yield - metadata.drop_all(engine) - - -@pytest.fixture() -async def init_relation(): department = await Department.objects.create(id=1, name="Math Department") class1 = await SchoolClass.objects.create(name="Math", department=department) category = await Category.objects.create(name="Foreign") @@ -77,13 +80,11 @@ async def init_relation(): await Student.objects.create(name="Jack", category=category2, schoolclass=class1) await Teacher.objects.create(name="Joe", category=category2, schoolclass=class1) yield - engine = sqlalchemy.create_engine(DATABASE_URL) metadata.drop_all(engine) - metadata.create_all(engine) @pytest.mark.asyncio -async def test_model_multiple_instances_of_same_table_in_schema(init_relation): +async def test_model_multiple_instances_of_same_table_in_schema(): async with database: classes = await SchoolClass.objects.select_related( ["teachers__category", "students"] @@ -102,7 +103,7 @@ async def test_model_multiple_instances_of_same_table_in_schema(init_relation): @pytest.mark.asyncio -async def test_right_tables_join(init_relation): +async def test_right_tables_join(): async with database: classes = await SchoolClass.objects.select_related( ["teachers__category", "students"] @@ -115,7 +116,7 @@ async def test_right_tables_join(init_relation): @pytest.mark.asyncio -async def test_multiple_reverse_related_objects(init_relation): +async def test_multiple_reverse_related_objects(): async with database: classes = await SchoolClass.objects.select_related( ["teachers__category", "students"] From 8e19a5b127a8337e6642a680b5d6deb2081e412a Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 15:44:05 +0200 Subject: [PATCH 39/62] add clean script --- scripts/clean.sh | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 scripts/clean.sh diff --git a/scripts/clean.sh b/scripts/clean.sh new file mode 100644 index 0000000..38b8eb7 --- /dev/null +++ b/scripts/clean.sh @@ -0,0 +1,19 @@ +#!/bin/sh -e +PACKAGE="orm" +if [ -d 'dist' ] ; then + rm -r dist +fi +if [ -d 'site' ] ; then + rm -r site +fi +if [ -d 'htmlcov' ] ; then + rm -r htmlcov +fi +if [ -d "${PACKAGE}.egg-info" ] ; then + rm -r "${PACKAGE}.egg-info" +fi +find ${PACKAGE} -type f -name "*.py[co]" -delete +find ${PACKAGE} -type d -name __pycache__ -delete + +find tests -type f -name "*.py[co]" -delete +find tests -type d -name __pycache__ -delete \ No newline at end of file From 704e83fed0e6ca75cb5b4c859b55648900b76550 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 17:18:05 +0200 Subject: [PATCH 40/62] refactor required field in model fields into decorator --- .coverage | Bin 53248 -> 53248 bytes orm/fields.py | 57 +++++++++++++++++++++++++++----------------------- orm/models.py | 40 +++++++++++++++++------------------ 3 files changed, 51 insertions(+), 46 deletions(-) diff --git a/.coverage b/.coverage index 04d6ff895dc3fdaa51adcd2b47292b0028b57483..0c8e0091151cb63f7feb13a3d4d9a61ed1db424a 100644 GIT binary patch delta 116 zcmV-)0E_>CpaX!Q1F$MD2RAw~GdeLfvoSB#P#$*x5BU%358e;c52+7t4<8Q*4*(AG z4$2O&4vr3Yvk?$e4x@OE0Rg>}d5JPKik1ar3JTQp> delta 115 zcmV-(0F3{DpaX!Q1F$MD2Q@k}HaajcvoSB#P#$;y5BU%358e;c52_Dv4 None: + self._required = list(args) + + def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: + old_init = model_field_class.__init__ + model_field_class._old_init = old_init + + def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: + super(instance.__class__, instance).__init__(*args, **kwargs) + for arg in self._required: + if arg not in kwargs: + raise ModelDefinitionError( + f"{instance.__class__.__name__} field requires parameter: {arg}" + ) + setattr(instance, arg, kwargs.pop(arg)) + + model_field_class.__init__ = __init__ + return model_field_class + + class BaseField: __type__ = None @@ -51,7 +71,7 @@ class BaseField: @property def is_required(self) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not self.nullable and not self.has_default and not self.is_auto_primary_key ) @property @@ -95,17 +115,10 @@ class BaseField: return value +@RequiredParams("length") class String(BaseField): __type__ = str - def __init__(self, *args: Any, **kwargs: Any) -> None: - if "length" not in kwargs: - raise ModelDefinitionError( - "Param length is required for String model field." - ) - self.length = kwargs.pop("length") - super().__init__(*args, **kwargs) - def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.String(self.length) @@ -173,18 +186,10 @@ class BigInteger(BaseField): return sqlalchemy.BigInteger() +@RequiredParams("length", "precision") class Decimal(BaseField): __type__ = decimal.Decimal - def __init__(self, *args: Any, **kwargs: Any) -> None: - if "length" not in kwargs or "precision" not in kwargs: - raise ModelDefinitionError( - "Params length and precision are required for Decimal model field." - ) - self.length = kwargs.pop("length") - self.precision = kwargs.pop("precision") - super().__init__(*args, **kwargs) - def get_column_type(self) -> sqlalchemy.Column: return sqlalchemy.DECIMAL(self.length, self.precision) @@ -204,12 +209,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -229,7 +234,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Union["Model", List["Model"]]: if isinstance(value, orm.models.Model) and not isinstance(value, self.to): diff --git a/orm/models.py b/orm/models.py index f7a6e6a..94047d2 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,7 +6,6 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Callable, Dict, Set import databases -from pydantic.fields import ModelField import orm.queryset as qry from orm.exceptions import ModelDefinitionError @@ -15,6 +14,7 @@ from orm.relations import RelationshipManager import pydantic from pydantic import BaseConfig, BaseModel, create_model +from pydantic.fields import ModelField import sqlalchemy @@ -42,21 +42,21 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> ) -def expand_reverse_relationships(model: Type["Model"]): - for field_name, model_field in model.__model_fields__.items(): +def expand_reverse_relationships(model: Type["Model"]) -> None: + for model_field in model.__model_fields__.values(): if isinstance(model_field, ForeignKey): - child_model_name = model_field.related_name or model.__name__.lower() + 's' + child_model_name = model_field.related_name or model.__name__.lower() + "s" parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: model.__fields__[child_model_name] = ModelField( name=child_model_name, @@ -70,7 +70,7 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: pkname: Optional[str] = None columns: List[sqlalchemy.Column] = [] @@ -198,9 +198,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) ): try: item = json.loads(item) @@ -216,7 +216,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.values is not None and other.values is not None and self.pk == other.pk ) def __repr__(self) -> str: # pragma no cover @@ -272,7 +272,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): related_names = set() for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + field.type_, pydantic.BaseModel ): related_names.add(name) return related_names @@ -304,7 +304,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), Model + getattr(one, field), Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), Model): @@ -326,10 +326,10 @@ class Model(FakePydantic): @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} @@ -387,8 +387,8 @@ class Model(FakePydantic): self_fields.pop(self.__pkname__) expr = ( self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) ) result = await self.__database__.execute(expr) return result From 867fc691f70f97b425470bf0f92e3294ea8d204d Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 17:34:19 +0200 Subject: [PATCH 41/62] refactor fields into a package --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 3 +- orm/fields.py | 275 --------------------------------- orm/fields/__init__.py | 31 ++++ orm/fields/base.py | 107 +++++++++++++ orm/fields/foreign_key.py | 91 +++++++++++ orm/fields/model_fields.py | 86 +++++++++++ orm/models.py | 3 +- orm/queryset.py | 6 +- orm/relations.py | 2 +- tests/test_fastapi_usage.py | 3 +- tests/test_foreign_keys.py | 7 +- tests/test_same_table_joins.py | 11 +- 13 files changed, 335 insertions(+), 290 deletions(-) delete mode 100644 orm/fields.py create mode 100644 orm/fields/__init__.py create mode 100644 orm/fields/base.py create mode 100644 orm/fields/foreign_key.py create mode 100644 orm/fields/model_fields.py diff --git a/.coverage b/.coverage index 0c8e0091151cb63f7feb13a3d4d9a61ed1db424a..26e4f9cb550ed84f7390d87f992de099209690d8 100644 GIT binary patch delta 1173 zcmZ8hZA@EL7`}Jswx#X8=e>8pP-Y1ltvV5u@-S{^2!v=p%-`MRZ&gMG?cD0etrojsjCdQ#=vej7%NJ7;aMi=2QMB?sJgiRH;F zjt9w%Lr*T-dCrWORuXcQk{|5?Ya;y)i||gb)8qF#eYB0$c*$YaRTm(pOEZFd)bx}AF>RV$n3vDCd!5=0dTj>EAI))F5|Gve}C`T#J+(c*|bQ3dA9o=7TYF{4vx@r4sTz4bc1~ zI+X)yJJ6KWGmmpd8|u<1i{f#g&R^jETnZ{so8~xjz>$x}us=je*#YT)>FA$9POZA| zfaSk{BF*wdxmCh#3}j#fR$vZ(f+&0r7vXi#;S^rOzumlU7a3Qv(e0aQ(tQHc{S$UYs$p#ZD)9a zV?)MLVWn1CIX=;@=V=+{O~q!qR+xh~|8Y+w?{2A++4*erY1VGivBsf?Oov95%t-%s zAaP^t{JG|VRCZuS30+U8w$kajq>|X3+Klb&+UED}4F4*tm(zibzxx){#By>vn2HBN zdrQmH+1E0wZ{h|cJ;oC<1_w?}ul}H9>VvA+Ul*HDUOmf9j}7<+lHK8@nOJi2Msnr% zXyQa`AR+Jgc9Pl6tW-ctA1zqCd166Lr9KS2v}GtXQ0fW8vA!9Xy6BCqFYl;C5cQON zME~tlcU^ea6&S|^#wys4-RAH8G(9Z5tu9Xn)oL4e>JXz3;tw^JT7jz*P@; z&S(@Yg$^M-Dn+b!gE|ATx_(E6e)oQTBPs&rn}7-a5NvYF?kO%PCnSj}(8 zE5}PD9Oif(t@?|!*u*1R38>z+{ur;O6@ap=6(N3uM2^4mHM9&+r|@e-8Ll#LGz+Nq z)q)!sG+Bpjri;4%VJ(9@4ca2Dl!OY#e8JSyWy8meej;e4kOLV}B^2RDIDtMwCwd5v z!eZ_-_n5O1K?8^1pKhk14s_z}9h>otQHjs-QkjkH<(k3Zf3TQV?E$^(U~E;r5m?qi zh*wJ#SU6Kd%d5dDg3F1qik4M@_SHi5pD~NiovoBG|bMesEk2H(i~&8 zGqz@@60TVLr1xQKCO<9sW37=X6)0UTfQixijF-vOs zOJe-VRN|QGqf$-r!+m1?l diff --git a/orm/__init__.py b/orm/__init__.py index 39adee8..773ab78 100644 --- a/orm/__init__.py +++ b/orm/__init__.py @@ -6,13 +6,13 @@ from orm.fields import ( DateTime, Decimal, Float, - ForeignKey, Integer, JSON, String, Text, Time, ) +from orm.fields.foreign_key import ForeignKey from orm.models import Model __version__ = "0.0.1" @@ -28,7 +28,6 @@ __all__ = [ "Date", "Decimal", "Float", - "ForeignKey", "Model", "ModelDefinitionError", "ModelNotSet", diff --git a/orm/fields.py b/orm/fields.py deleted file mode 100644 index 980e334..0000000 --- a/orm/fields.py +++ /dev/null @@ -1,275 +0,0 @@ -import datetime -import decimal -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type, Union - -import orm -from orm.exceptions import ModelDefinitionError, RelationshipInstanceError - -from pydantic import BaseModel, Json - -import sqlalchemy - -if TYPE_CHECKING: # pragma no cover - from orm.models import Model - - -class RequiredParams: - def __init__(self, *args: str) -> None: - self._required = list(args) - - def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: - old_init = model_field_class.__init__ - model_field_class._old_init = old_init - - def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: - super(instance.__class__, instance).__init__(*args, **kwargs) - for arg in self._required: - if arg not in kwargs: - raise ModelDefinitionError( - f"{instance.__class__.__name__} field requires parameter: {arg}" - ) - setattr(instance, arg, kwargs.pop(arg)) - - model_field_class.__init__ = __init__ - return model_field_class - - -class BaseField: - __type__ = None - - def __init__(self, *args: Any, **kwargs: Any) -> None: - name = kwargs.pop("name", None) - args = list(args) - if args: - if isinstance(args[0], str): - if name is not None: - raise ModelDefinitionError( - "Column name cannot be passed positionally and as a keyword." - ) - name = args.pop(0) - - self.name = name - self._populate_from_kwargs(kwargs) - - def _populate_from_kwargs(self, kwargs: Dict) -> None: - self.primary_key = kwargs.pop("primary_key", False) - self.autoincrement = kwargs.pop( - "autoincrement", self.primary_key and self.__type__ == int - ) - - self.nullable = kwargs.pop("nullable", not self.primary_key) - self.default = kwargs.pop("default", None) - self.server_default = kwargs.pop("server_default", None) - - self.index = kwargs.pop("index", None) - self.unique = kwargs.pop("unique", None) - - self.pydantic_only = kwargs.pop("pydantic_only", False) - if self.pydantic_only and self.primary_key: - raise ModelDefinitionError("Primary key column cannot be pydantic only.") - - @property - def is_required(self) -> bool: - return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key - ) - - @property - def default_value(self) -> Any: - default = self.default if self.default is not None else self.server_default - return default() if callable(default) else default - - @property - def has_default(self) -> bool: - return self.default is not None or self.server_default is not None - - @property - def is_auto_primary_key(self) -> bool: - if self.primary_key: - return self.autoincrement - return False - - def get_column(self, name: str = None) -> sqlalchemy.Column: - self.name = self.name or name - constraints = self.get_constraints() - return sqlalchemy.Column( - self.name, - self.get_column_type(), - *constraints, - primary_key=self.primary_key, - autoincrement=self.autoincrement, - nullable=self.nullable, - index=self.index, - unique=self.unique, - default=self.default, - server_default=self.server_default, - ) - - def get_column_type(self) -> sqlalchemy.types.TypeEngine: - raise NotImplementedError() # pragma: no cover - - def get_constraints(self) -> Optional[List]: - return [] - - def expand_relationship(self, value: Any, child: "Model") -> Any: - return value - - -@RequiredParams("length") -class String(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.String(self.length) - - -class Integer(BaseField): - __type__ = int - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Integer() - - -class Text(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Text() - - -class Float(BaseField): - __type__ = float - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Float() - - -class Boolean(BaseField): - __type__ = bool - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Boolean() - - -class DateTime(BaseField): - __type__ = datetime.datetime - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DateTime() - - -class Date(BaseField): - __type__ = datetime.date - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Date() - - -class Time(BaseField): - __type__ = datetime.time - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Time() - - -class JSON(BaseField): - __type__ = Json - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.JSON() - - -class BigInteger(BaseField): - __type__ = int - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.BigInteger() - - -@RequiredParams("length", "precision") -class Decimal(BaseField): - __type__ = decimal.Decimal - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DECIMAL(self.length, self.precision) - - -def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": - init_dict = {fk.__pkname__: pk or -1} - init_dict = { - **init_dict, - **{ - k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual - }, - } - return fk(**init_dict) - - -class ForeignKey(BaseField): - def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, - ) -> None: - super().__init__(nullable=nullable, name=name) - self.virtual = virtual - self.related_name = related_name - self.to = to - - @property - def __type__(self) -> Type[BaseModel]: - return self.to.__pydantic_model__ - - def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: - fk_string = self.to.__tablename__ + "." + self.to.__pkname__ - return [sqlalchemy.schema.ForeignKey(fk_string)] - - def get_column_type(self) -> sqlalchemy.Column: - to_column = self.to.__model_fields__[self.to.__pkname__] - return to_column.get_column_type() - - def expand_relationship( - self, value: Any, child: "Model" - ) -> Union["Model", List["Model"]]: - - if isinstance(value, orm.models.Model) and not isinstance(value, self.to): - raise RelationshipInstanceError( - f"Relationship error - expecting: {self.to.__name__}, " - f"but {value.__class__.__name__} encountered." - ) - - if isinstance(value, list) and not isinstance(value, self.to): - model = [self.expand_relationship(val, child) for val in value] - return model - - if isinstance(value, self.to): - model = value - elif isinstance(value, dict): - model = self.to(**value) - else: - if not isinstance(value, self.to.pk_type()): - raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " - f"of type {self.__type__} " - f"while {type(value)} passed as a parameter." - ) - model = create_dummy_instance(fk=self.to, pk=value) - - self.add_to_relationship_registry(model, child) - - return model - - def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: - model._orm_relationship_manager.add_relation( - model.__class__.__name__.lower(), - child.__class__.__name__.lower(), - model, - child, - virtual=self.virtual, - ) diff --git a/orm/fields/__init__.py b/orm/fields/__init__.py new file mode 100644 index 0000000..75a1b7d --- /dev/null +++ b/orm/fields/__init__.py @@ -0,0 +1,31 @@ +from orm.fields.model_fields import ( + BigInteger, + Boolean, + Date, + DateTime, + Decimal, + String, + Integer, + Text, + Float, + Time, + JSON, +) +from orm.fields.foreign_key import ForeignKey +from orm.fields.base import BaseField + +__all__ = [ + "Decimal", + "BigInteger", + "Boolean", + "Date", + "DateTime", + "String", + "JSON", + "Integer", + "Text", + "Float", + "Time", + "ForeignKey", + "BaseField", +] diff --git a/orm/fields/base.py b/orm/fields/base.py new file mode 100644 index 0000000..32c0f13 --- /dev/null +++ b/orm/fields/base.py @@ -0,0 +1,107 @@ +from typing import Type, Any, Dict, Optional, List + +import sqlalchemy + +from orm import ModelDefinitionError + + +class RequiredParams: + def __init__(self, *args: str) -> None: + self._required = list(args) + + def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: + old_init = model_field_class.__init__ + model_field_class._old_init = old_init + + def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: + super(instance.__class__, instance).__init__(*args, **kwargs) + for arg in self._required: + if arg not in kwargs: + raise ModelDefinitionError( + f"{instance.__class__.__name__} field requires parameter: {arg}" + ) + setattr(instance, arg, kwargs.pop(arg)) + + model_field_class.__init__ = __init__ + return model_field_class + + +class BaseField: + __type__ = None + + def __init__(self, *args: Any, **kwargs: Any) -> None: + name = kwargs.pop("name", None) + args = list(args) + if args: + if isinstance(args[0], str): + if name is not None: + raise ModelDefinitionError( + "Column name cannot be passed positionally and as a keyword." + ) + name = args.pop(0) + + self.name = name + self._populate_from_kwargs(kwargs) + + def _populate_from_kwargs(self, kwargs: Dict) -> None: + self.primary_key = kwargs.pop("primary_key", False) + self.autoincrement = kwargs.pop( + "autoincrement", self.primary_key and self.__type__ == int + ) + + self.nullable = kwargs.pop("nullable", not self.primary_key) + self.default = kwargs.pop("default", None) + self.server_default = kwargs.pop("server_default", None) + + self.index = kwargs.pop("index", None) + self.unique = kwargs.pop("unique", None) + + self.pydantic_only = kwargs.pop("pydantic_only", False) + if self.pydantic_only and self.primary_key: + raise ModelDefinitionError("Primary key column cannot be pydantic only.") + + @property + def is_required(self) -> bool: + return ( + not self.nullable and not self.has_default and not self.is_auto_primary_key + ) + + @property + def default_value(self) -> Any: + default = self.default if self.default is not None else self.server_default + return default() if callable(default) else default + + @property + def has_default(self) -> bool: + return self.default is not None or self.server_default is not None + + @property + def is_auto_primary_key(self) -> bool: + if self.primary_key: + return self.autoincrement + return False + + def get_column(self, name: str = None) -> sqlalchemy.Column: + self.name = self.name or name + constraints = self.get_constraints() + return sqlalchemy.Column( + self.name, + self.get_column_type(), + *constraints, + primary_key=self.primary_key, + autoincrement=self.autoincrement, + nullable=self.nullable, + index=self.index, + unique=self.unique, + default=self.default, + server_default=self.server_default, + ) + + def get_column_type(self) -> sqlalchemy.types.TypeEngine: + raise NotImplementedError() # pragma: no cover + + def get_constraints(self) -> Optional[List]: + return [] + + def expand_relationship(self, value: Any, child: "Model") -> Any: + return value diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py new file mode 100644 index 0000000..1496fc9 --- /dev/null +++ b/orm/fields/foreign_key.py @@ -0,0 +1,91 @@ +from typing import Type, List, Any, Union, TYPE_CHECKING + +import sqlalchemy +from pydantic import BaseModel + +import orm +from orm.exceptions import RelationshipInstanceError +from orm.fields.base import BaseField + +if TYPE_CHECKING: # pragma no cover + from orm.models import Model + + +def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": + init_dict = {fk.__pkname__: pk or -1} + init_dict = { + **init_dict, + **{ + k: create_dummy_instance(v.to) + for k, v in fk.__model_fields__.items() + if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + }, + } + return fk(**init_dict) + + +class ForeignKey(BaseField): + def __init__( + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, + ) -> None: + super().__init__(nullable=nullable, name=name) + self.virtual = virtual + self.related_name = related_name + self.to = to + + @property + def __type__(self) -> Type[BaseModel]: + return self.to.__pydantic_model__ + + def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: + fk_string = self.to.__tablename__ + "." + self.to.__pkname__ + return [sqlalchemy.schema.ForeignKey(fk_string)] + + def get_column_type(self) -> sqlalchemy.Column: + to_column = self.to.__model_fields__[self.to.__pkname__] + return to_column.get_column_type() + + def expand_relationship( + self, value: Any, child: "Model" + ) -> Union["Model", List["Model"]]: + + if isinstance(value, orm.models.Model) and not isinstance(value, self.to): + raise RelationshipInstanceError( + f"Relationship error - expecting: {self.to.__name__}, " + f"but {value.__class__.__name__} encountered." + ) + + if isinstance(value, list) and not isinstance(value, self.to): + model = [self.expand_relationship(val, child) for val in value] + return model + + if isinstance(value, self.to): + model = value + elif isinstance(value, dict): + model = self.to(**value) + else: + if not isinstance(value, self.to.pk_type()): + raise RelationshipInstanceError( + f"Relationship error - ForeignKey {self.to.__name__} " + f"is of type {self.to.pk_type()} " + f"while {type(value)} passed as a parameter." + ) + model = create_dummy_instance(fk=self.to, pk=value) + + self.add_to_relationship_registry(model, child) + + return model + + def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: + model._orm_relationship_manager.add_relation( + model.__class__.__name__.lower(), + child.__class__.__name__.lower(), + model, + child, + virtual=self.virtual, + ) diff --git a/orm/fields/model_fields.py b/orm/fields/model_fields.py new file mode 100644 index 0000000..3ddae06 --- /dev/null +++ b/orm/fields/model_fields.py @@ -0,0 +1,86 @@ +import datetime +import decimal + +import sqlalchemy +from pydantic import Json + +from orm.fields.base import BaseField, RequiredParams + + +@RequiredParams("length") +class String(BaseField): + __type__ = str + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.String(self.length) + + +class Integer(BaseField): + __type__ = int + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.Integer() + + +class Text(BaseField): + __type__ = str + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.Text() + + +class Float(BaseField): + __type__ = float + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.Float() + + +class Boolean(BaseField): + __type__ = bool + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.Boolean() + + +class DateTime(BaseField): + __type__ = datetime.datetime + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.DateTime() + + +class Date(BaseField): + __type__ = datetime.date + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.Date() + + +class Time(BaseField): + __type__ = datetime.time + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.Time() + + +class JSON(BaseField): + __type__ = Json + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.JSON() + + +class BigInteger(BaseField): + __type__ = int + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.BigInteger() + + +@RequiredParams("length", "precision") +class Decimal(BaseField): + __type__ = decimal.Decimal + + def get_column_type(self) -> sqlalchemy.Column: + return sqlalchemy.DECIMAL(self.length, self.precision) diff --git a/orm/models.py b/orm/models.py index 94047d2..9fde6e7 100644 --- a/orm/models.py +++ b/orm/models.py @@ -9,7 +9,8 @@ import databases import orm.queryset as qry from orm.exceptions import ModelDefinitionError -from orm.fields import BaseField, ForeignKey +from orm import ForeignKey +from orm.fields.base import BaseField from orm.relations import RelationshipManager import pydantic diff --git a/orm/queryset.py b/orm/queryset.py index ae3cac2..ea6c264 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -13,9 +13,10 @@ from typing import ( import databases import orm +import orm.fields.foreign_key from orm import ForeignKey from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError -from orm.fields import BaseField +from orm.fields.base import BaseField import sqlalchemy from sqlalchemy import text @@ -79,7 +80,8 @@ class Query: if ( not self.model_cls.__model_fields__[key].nullable and isinstance( - self.model_cls.__model_fields__[key], orm.fields.ForeignKey + self.model_cls.__model_fields__[key], + orm.fields.foreign_key.ForeignKey, ) and key not in self._select_related ): diff --git a/orm/relations.py b/orm/relations.py index f541dfe..08f49fc 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -5,7 +5,7 @@ from random import choices from typing import Dict, List, TYPE_CHECKING, Union from weakref import proxy -from orm.fields import ForeignKey +from orm import ForeignKey if TYPE_CHECKING: # pragma no cover from orm.models import FakePydantic, Model diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 1c3d5dd..bef1607 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -4,6 +4,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient import orm +import orm.fields.foreign_key from tests.settings import DATABASE_URL app = FastAPI() @@ -28,7 +29,7 @@ class Item(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) - category = orm.ForeignKey(Category, nullable=True) + category = orm.fields.foreign_key.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 660f3d4..6b69507 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -3,6 +3,7 @@ import pytest import sqlalchemy import orm +import orm.fields.foreign_key from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from tests.settings import DATABASE_URL @@ -25,7 +26,7 @@ class Track(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) + album = orm.fields.foreign_key.ForeignKey(Album) title = orm.String(length=100) position = orm.Integer() @@ -45,7 +46,7 @@ class Team(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - org = orm.ForeignKey(Organisation) + org = orm.fields.foreign_key.ForeignKey(Organisation) name = orm.String(length=100) @@ -55,7 +56,7 @@ class Member(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - team = orm.ForeignKey(Team) + team = orm.fields.foreign_key.ForeignKey(Team) email = orm.String(length=100) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 8c67f4b..4c27446 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -5,6 +5,7 @@ import pytest import sqlalchemy import orm +import orm.fields.foreign_key from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -27,7 +28,7 @@ class SchoolClass(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) - department = orm.ForeignKey(Department, nullable=False) + department = orm.fields.foreign_key.ForeignKey(Department, nullable=False) class Category(orm.Model): @@ -46,8 +47,8 @@ class Student(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) - schoolclass = orm.ForeignKey(SchoolClass) - category = orm.ForeignKey(Category, nullable=True) + schoolclass = orm.fields.foreign_key.ForeignKey(SchoolClass) + category = orm.fields.foreign_key.ForeignKey(Category, nullable=True) class Teacher(orm.Model): @@ -57,8 +58,8 @@ class Teacher(orm.Model): id = orm.Integer(primary_key=True) name = orm.String(length=100) - schoolclass = orm.ForeignKey(SchoolClass) - category = orm.ForeignKey(Category, nullable=True) + schoolclass = orm.fields.foreign_key.ForeignKey(SchoolClass) + category = orm.fields.foreign_key.ForeignKey(Category, nullable=True) @pytest.fixture(scope='module') From 7083b507122fdfd0a331f7d342792dab19d6de92 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 17:47:06 +0200 Subject: [PATCH 42/62] simplify adding relations --- .coverage | Bin 53248 -> 53248 bytes orm/fields/foreign_key.py | 33 +++++++++++++++------------------ orm/relations.py | 22 +++++++++++----------- tests/test_foreign_keys.py | 23 +++++++++++++++-------- 4 files changed, 41 insertions(+), 37 deletions(-) diff --git a/.coverage b/.coverage index 26e4f9cb550ed84f7390d87f992de099209690d8..80aea2e33e0baceb6a02d43b32739a4bda7b4474 100644 GIT binary patch delta 211 zcmV;^04)E2paX!Q1F$tO1vE7}H8Zn1FU(LRTL2IF59$xz57Q6G54I1X4~`F14+IYF z4&@Hl4z3QA4ss4#vk?$64wGGuQz<771OW+94(`pH_q=)MouB9Pf8Wmk{oDNOojdpT zp8wrB|DF8ZdHZ~m*p5FwU+vD8`rP;Xe&7CT-`(rq->1Ldw)=V8cklaWuKlY!$G`qP z-}lt@@3Q~5yT5(?+@JUU{JFb*_o?#B#~hPYk3$Nd?0hSf3wq%EkN`~aE<@~ delta 209 zcmV;?051Q4paX!Q1F$tO1v4``G&Hk1FU(LRT>uaH59$xz57Q6H54R7Z504L34+ReH z4(1Nn4zLcE4s;G(vk?$84wGAsQz#}51OW+74(`pH_q=)MouB9bzMcR3xB1sQckb;y z|GRVkJNdiw_W6_7jz2!Hc4td{?)!beZ-2G#?)C5Q)8B8~{k-kF_x&^1{?(o1U;m!( zd+Pdk+5g+!-@bnC&wGFV+}*zWRQcs&4wF=mLkd6H|L*RNdq?#LlcA3%3J3)R0SOWX L+W!x;(~m7cdF*cq diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 1496fc9..a5e65dd 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Type, List, Any, Union, TYPE_CHECKING +from typing import Type, List, Any, Union, TYPE_CHECKING, Optional import sqlalchemy from pydantic import BaseModel @@ -12,9 +12,8 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": - init_dict = {fk.__pkname__: pk or -1} init_dict = { - **init_dict, + **{fk.__pkname__: pk or -1}, **{ k: create_dummy_instance(v.to) for k, v in fk.__model_fields__.items() @@ -26,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -51,8 +50,11 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" - ) -> Union["Model", List["Model"]]: + self, value: Any, child: "Model" + ) -> Optional[Union["Model", List["Model"]]]: + + if value is None: + return None if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( @@ -77,15 +79,10 @@ class ForeignKey(BaseField): ) model = create_dummy_instance(fk=self.to, pk=value) - self.add_to_relationship_registry(model, child) - - return model - - def add_to_relationship_registry(self, model: "Model", child: "Model") -> None: model._orm_relationship_manager.add_relation( - model.__class__.__name__.lower(), - child.__class__.__name__.lower(), model, child, virtual=self.virtual, ) + + return model diff --git a/orm/relations.py b/orm/relations.py index 08f49fc..fa0dbcf 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -16,7 +16,7 @@ def get_table_alias() -> str: def get_relation_config( - relation_type: str, table_name: str, field: ForeignKey + relation_type: str, table_name: str, field: ForeignKey ) -> Dict[str, str]: alias = get_table_alias() config = { @@ -37,7 +37,7 @@ class RelationshipManager: self._relations = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = get_relation_config( @@ -56,15 +56,15 @@ class RelationshipManager: del self._relations[rel_type][model._orm_id] def add_relation( - self, - parent_name: str, - child_name: str, - parent: "FakePydantic", - child: "FakePydantic", - virtual: bool = False, + self, + parent: "FakePydantic", + child: "FakePydantic", + virtual: bool = False, ) -> None: parent_id = parent._orm_id child_id = child._orm_id + parent_name = parent.get_name() + child_name = child.get_name() if virtual: child_name, parent_name = parent_name, child_name child_id, parent_id = parent_id, child_id @@ -97,7 +97,7 @@ class RelationshipManager: return False def get( - self, relations_key: str, instance: "FakePydantic" + self, relations_key: str, instance: "FakePydantic" ) -> Union["Model", List["Model"]]: if relations_key in self._relations: if instance._orm_id in self._relations[relations_key]: @@ -108,8 +108,8 @@ class RelationshipManager: def resolve_relation_join(self, from_table: str, to_table: str) -> str: for relation_name, relation in self._relations.items(): if ( - relation["source_table"] == from_table - and relation["target_table"] == to_table + relation["source_table"] == from_table + and relation["target_table"] == to_table ): return self._relations[relation_name]["table_alias"] return "" diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 6b69507..7121e5a 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -74,6 +74,13 @@ async def test_wrong_query_foreign_key_type(): Track(title="The Error", album="wrong_pk_type") +@pytest.mark.asyncio +async def test_setting_explicitly_empty_relation(): + async with database: + track = Track(album=None, title="The Bird", position=1) + assert track.album is None + + @pytest.mark.asyncio async def test_model_crud(): async with database: @@ -146,8 +153,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -155,8 +162,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -198,8 +205,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -218,8 +225,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 From 4e91b3837bb731c7e811c7212262b07e043e08ae Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 17:58:34 +0200 Subject: [PATCH 43/62] some cleanup --- .coverage | Bin 53248 -> 53248 bytes .flake8 | 1 + orm/__init__.py | 1 + orm/fields/__init__.py | 12 ++++++------ orm/fields/base.py | 7 +++++-- orm/fields/foreign_key.py | 22 ++++++++++------------ orm/fields/model_fields.py | 2 +- orm/models.py | 14 ++++++-------- orm/queryset.py | 7 +++---- orm/relations.py | 15 ++++++--------- 10 files changed, 39 insertions(+), 42 deletions(-) diff --git a/.coverage b/.coverage index 80aea2e33e0baceb6a02d43b32739a4bda7b4474..52c345a480d2a4921903b810f9d261a68e88c907 100644 GIT binary patch delta 491 zcmYk0KS(1%7{zCIW@oe6nfWG06cm(Lh>5a#aKXdM!qSb6po0I!2m$dp5z)d6SKc-1 zR5=7~1Z_lFut=leU87>5*s9oB+G(TCC{ zul$s+@kyR$FYGc($Q@>966!8&B@&|{b)2hH=^S%90iy*a!<2gJt+s^XVLqPK}32!oT9#N~&a*W$|3a(1VDd z#sk`@kNQXVQ%wK`1kdmY4{!t5Z~}X<4I8iyzu^}w!w)FveLe7#0Kpo_vAAYc+`bd8r{!Ipj3a~qO3Gbr)H~-PDs9emya5)d Bb>sj5 delta 816 zcmYk3ZAep59LDe6y_??co^$Tam&%MXqEt+KalU3k-*rNQgrOKYDb$6VjIlTSEJS!C zh-I@6krG%maJZ@SLt-F0$U<~(7XlsS%Oug8O?Nv+3qSoIp7T8a{|`>j0M#=L`O0wY}Ifi5XjnMMLsLk=}zKpDqw6ineKPC9oF+SeKmQ?Ts1{B z;7*pOzV;(|3~JC`RRJ!J@Lmc)94lBra@@Y*T)qo49JoQcQb1v|(kwAOs2a2G1^a+ii zSLiXSM>O1kD{ui0!>6zdHe$jn(3!#0BvZiqT&1xwOVD~lb{n>A{jswgwr_HqtFu-| z;=1UK+00RIIAn7Lg9gXmiYrE*Ksa#~`DWyx{OiTGxVl_Hvw@ZZV8!V z@wVZEx0Ta@M3pFg*@=CSz5amxVA(At_DB78Y3S$7baoko&p5jnhrMB9W!kV;>X40= z691^-)?H%ao!w*?+g-~6zvvqmBWvDJ^=)gYSu#mtd?U`M;_CU-r78Ei(`>HXJY7Ym zrC=*JCEpy-%S%RoG%U#i;8fT=ry5yqFLBN4J&Wy~j1oKPO2uTt$ "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -50,7 +50,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: @@ -80,9 +80,7 @@ class ForeignKey(BaseField): model = create_dummy_instance(fk=self.to, pk=value) model._orm_relationship_manager.add_relation( - model, - child, - virtual=self.virtual, + model, child, virtual=self.virtual, ) return model diff --git a/orm/fields/model_fields.py b/orm/fields/model_fields.py index 3ddae06..813f190 100644 --- a/orm/fields/model_fields.py +++ b/orm/fields/model_fields.py @@ -4,7 +4,7 @@ import decimal import sqlalchemy from pydantic import Json -from orm.fields.base import BaseField, RequiredParams +from orm.fields.base import BaseField, RequiredParams # noqa I101 @RequiredParams("length") diff --git a/orm/models.py b/orm/models.py index 9fde6e7..22750eb 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,18 +6,16 @@ from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar from typing import Callable, Dict, Set import databases - -import orm.queryset as qry -from orm.exceptions import ModelDefinitionError -from orm import ForeignKey -from orm.fields.base import BaseField -from orm.relations import RelationshipManager - import pydantic +import sqlalchemy from pydantic import BaseConfig, BaseModel, create_model from pydantic.fields import ModelField -import sqlalchemy +import orm.queryset as qry # noqa I100 +from orm import ForeignKey +from orm.exceptions import ModelDefinitionError +from orm.fields.base import BaseField +from orm.relations import RelationshipManager relationship_manager = RelationshipManager() diff --git a/orm/queryset.py b/orm/queryset.py index ea6c264..5033f13 100644 --- a/orm/queryset.py +++ b/orm/queryset.py @@ -11,16 +11,15 @@ from typing import ( ) import databases +import sqlalchemy +from sqlalchemy import text -import orm +import orm # noqa I100 import orm.fields.foreign_key from orm import ForeignKey from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError from orm.fields.base import BaseField -import sqlalchemy -from sqlalchemy import text - if TYPE_CHECKING: # pragma no cover from orm.models import Model diff --git a/orm/relations.py b/orm/relations.py index fa0dbcf..7cf5ecf 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -16,7 +16,7 @@ def get_table_alias() -> str: def get_relation_config( - relation_type: str, table_name: str, field: ForeignKey + relation_type: str, table_name: str, field: ForeignKey ) -> Dict[str, str]: alias = get_table_alias() config = { @@ -37,7 +37,7 @@ class RelationshipManager: self._relations = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = get_relation_config( @@ -56,10 +56,7 @@ class RelationshipManager: del self._relations[rel_type][model._orm_id] def add_relation( - self, - parent: "FakePydantic", - child: "FakePydantic", - virtual: bool = False, + self, parent: "FakePydantic", child: "FakePydantic", virtual: bool = False, ) -> None: parent_id = parent._orm_id child_id = child._orm_id @@ -97,7 +94,7 @@ class RelationshipManager: return False def get( - self, relations_key: str, instance: "FakePydantic" + self, relations_key: str, instance: "FakePydantic" ) -> Union["Model", List["Model"]]: if relations_key in self._relations: if instance._orm_id in self._relations[relations_key]: @@ -108,8 +105,8 @@ class RelationshipManager: def resolve_relation_join(self, from_table: str, to_table: str) -> str: for relation_name, relation in self._relations.items(): if ( - relation["source_table"] == from_table - and relation["target_table"] == to_table + relation["source_table"] == from_table + and relation["target_table"] == to_table ): return self._relations[relation_name]["table_alias"] return "" From d82340bcb1e6960a09f730d18b409fa098659115 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 18:11:34 +0200 Subject: [PATCH 44/62] refactors in foreign key --- .coverage | Bin 53248 -> 53248 bytes orm/__init__.py | 2 +- orm/fields/foreign_key.py | 51 ++++++++++++++++++++++++-------------- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/.coverage b/.coverage index 52c345a480d2a4921903b810f9d261a68e88c907..105fceb53d64948024e1f0c6aae03edf3748d6b8 100644 GIT binary patch delta 914 zcmYjOZAepL6ux)wy_;8*HM1}T&qFh;hFUJp-RA7vE>6e34hZ~hmaY*1aY}q@TrEB^=|qzu zkKzGgSq7lhcC5A2UUs6l2Cd`;F%JU5y37-24Lr$-q_Ij>%8Cu9b$FVZv$uA4AMa|< z%1dO z!=4&B)EjZsD}LLo(<^V{F}$BoOokn<%pNHyJA0;rv8dA%iP&5VpFJ|U7GXvu8h)un zlwUXJq{+!T$Naa6z!cgFai;+$!k%QPcgpLZnsZC9lyr3It9N&D-lxRs6z@{3H|1nG z%1tu)T)f&AuN$(cS*(MsHcy9Rsm}Fm$vPx`Nt=6=Gs@r=qn0L_a~rg(kE2Vb`qaow z{Ql|o17{Pdvrla9hjQYlERRXHXgW0WIGNUuXWn=}nH8Vxm|pCeP@*wu)Rl-h+?fxt z(NyEY=cBNe!COwc8JBe3w($L6=8~i|`@&s`=+(7-a!0`-sU}*SloP=&Ihb6WnK!+S z&yVkorIW5yJ)eW6MLB~XI(K@)ZGnl-E>_2xeE1g77EseBg$^z2Y}gx@Ext delta 627 zcmYk1KWGzC9LHboy?b|;d++@&CbZB{HJO?;T-9rgcIx2hU6(>F*!+{E*is`6Me3lz zMbV}9O%cH^1sxP6;1ER-aZzzlx;SVz?N}#Wtna8Coxb1q;meQr>1^Q627c;|(b7cm z_(ZXHdhDw+_*1+PH$_Dh`5u4I@A5j&vA^tdFX237T0SAS{ZZmzG%Mc(LoydIy#yGo zsMmmeJ!Q+^XH!IAREt2>*R8gDd9}n3&Z>TjFpQR>S=V=njZrHCpL|CjiDNXQ$dH^4 zJne)jUj-g9Fq%^=!|JtmLml?${+*0`9Dp32`jvhuZiyXH6(`vqY_Tt}DG>j_*ZBpW zR~nC){605I5_vo;GnI7cK%7%}MLh=pDhMVMW}E zfHEL<4$ml%meb{N?BPY#v(RiUH!jX!sw_5E)m2mOmJ7t#Ka~w7dc&0;N3#G62>!q> z?7%nJg17JlHsL<3!)>?$S78|zl*}2pPJzTM0evdM6fWgq@;+r@GMCz6+GR7Y@=#pc z7KWvoV_mBlW?0O7IBBRdV&t#Hnz2dLseN?N8gZzCS^E5ZVc^cY`m=Kd*U-bgu1#7* pl@w_ Tuple[Union["Model", List["Model"]], bool]: + if isinstance(value, list) and not isinstance(value, self.to): + model = [self.expand_relationship(val, child) for val in value] + return model, True + + if isinstance(value, self.to): + model = value + else: + model = self.to(**value) + return model, False + + def construct_model_from_pk(self, value: Any) -> "Model": + if not isinstance(value, self.to.pk_type()): + raise RelationshipInstanceError( + f"Relationship error - ForeignKey {self.to.__name__} " + f"is of type {self.to.pk_type()} " + f"while {type(value)} passed as a parameter." + ) + return create_dummy_instance(fk=self.to, pk=value) + def expand_relationship( self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: @@ -56,31 +78,22 @@ class ForeignKey(BaseField): if value is None: return None + is_sequence = False + if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( f"Relationship error - expecting: {self.to.__name__}, " f"but {value.__class__.__name__} encountered." ) - if isinstance(value, list) and not isinstance(value, self.to): - model = [self.expand_relationship(val, child) for val in value] - return model - - if isinstance(value, self.to): - model = value - elif isinstance(value, dict): - model = self.to(**value) + if isinstance(value, (dict, list, self.to)): + model, is_sequence = self.extract_model_from_sequence(value, child) else: - if not isinstance(value, self.to.pk_type()): - raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " - f"while {type(value)} passed as a parameter." - ) - model = create_dummy_instance(fk=self.to, pk=value) + model = self.construct_model_from_pk(value) - model._orm_relationship_manager.add_relation( - model, child, virtual=self.virtual, - ) + if not is_sequence: + model._orm_relationship_manager.add_relation( + model, child, virtual=self.virtual + ) return model From 3e04646fd4a1bb64f6cd5a1cd3058e603b5cb1ae Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 18:32:48 +0200 Subject: [PATCH 45/62] refactors in fk --- .coverage | Bin 53248 -> 53248 bytes orm/fields/foreign_key.py | 45 +++++++++++++++++++------------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/.coverage b/.coverage index 105fceb53d64948024e1f0c6aae03edf3748d6b8..12fa709d6622e39bed47bdc45522a017c60c3638 100644 GIT binary patch delta 93 zcmV-j0HXhZpaX!Q1F$tO1u{80H8`_6FU(LRTL2IF59$xz57Q6G53>)S4~-8~4+9SE z4&@Hl4z3QA4ss4#vk?$64zpa2L;(~M2Lu5LCI{NZpX`5k_uD=0+>RBqvyV+c%eo=u delta 92 zcmV-i0HgnapaX!Q1F$tO1u!@|Ffg+^FU(LRTmTRG59$xz57Q6G53>)T4~`F14+IYG y4(1Nn4zCWC4s#A%vk?$74zpX1L;({K2Lu5LBnR5XpX`5k_qJQdj!v_)k4-?Mtsx8m diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index b5659b4..827ab19 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -50,36 +50,42 @@ class ForeignKey(BaseField): return to_column.get_column_type() def extract_model_from_sequence( - self, value: Any, child: "Model" - ) -> Tuple[Union["Model", List["Model"]], bool]: + self, value: Any, child: "Model" + ) -> Union["Model", List["Model"]]: if isinstance(value, list) and not isinstance(value, self.to): model = [self.expand_relationship(val, child) for val in value] - return model, True + return model if isinstance(value, self.to): model = value else: model = self.to(**value) - return model, False + self.register_relation(model, child) + return model - def construct_model_from_pk(self, value: Any) -> "Model": + def construct_model_from_pk(self, value: Any, child: "Model") -> "Model": if not isinstance(value, self.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {self.to.__name__} " f"is of type {self.to.pk_type()} " f"while {type(value)} passed as a parameter." ) - return create_dummy_instance(fk=self.to, pk=value) + model = create_dummy_instance(fk=self.to, pk=value) + self.register_relation(model, child) + return model + + def register_relation(self, model, child): + model._orm_relationship_manager.add_relation( + model, child, virtual=self.virtual + ) def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None - is_sequence = False - if isinstance(value, orm.models.Model) and not isinstance(value, self.to): raise RelationshipInstanceError( f"Relationship error - expecting: {self.to.__name__}, " @@ -87,13 +93,8 @@ class ForeignKey(BaseField): ) if isinstance(value, (dict, list, self.to)): - model, is_sequence = self.extract_model_from_sequence(value, child) + model = self.extract_model_from_sequence(value, child) else: - model = self.construct_model_from_pk(value) - - if not is_sequence: - model._orm_relationship_manager.add_relation( - model, child, virtual=self.virtual - ) + model = self.construct_model_from_pk(value, child) return model From 24b5649c565648196119b8e6ba455f2e801c32a8 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 18:56:40 +0200 Subject: [PATCH 46/62] refactor expanding of relationship into constructors --- .coverage | Bin 53248 -> 53248 bytes orm/fields/foreign_key.py | 60 ++++++++++++++++++-------------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/.coverage b/.coverage index 12fa709d6622e39bed47bdc45522a017c60c3638..d8930bb67053fbd25ac27e8ab0d7f1591e81a002 100644 GIT binary patch delta 32 ocmZozz}&Ead4s7wtEr)tvGHar{o(|+ced~D%wycVx!>0T0J}#FaR2}S delta 32 ocmZozz}&Ead4s7wtC6LZsl{e1{o(|++qUoS^oefX-0$lE0J~%iZU6uP diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 827ab19..0788ce2 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy from pydantic import BaseModel @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -49,21 +49,21 @@ class ForeignKey(BaseField): to_column = self.to.__model_fields__[self.to.__pkname__] return to_column.get_column_type() - def extract_model_from_sequence( - self, value: Any, child: "Model" + def _extract_model_from_sequence( + self, value: List, child: "Model" ) -> Union["Model", List["Model"]]: - if isinstance(value, list) and not isinstance(value, self.to): - model = [self.expand_relationship(val, child) for val in value] - return model + return [self.expand_relationship(val, child) for val in value] - if isinstance(value, self.to): - model = value - else: - model = self.to(**value) + def _register_existing_model(self, value: "Model", child: "Model") -> "Model": + self.register_relation(value, child) + return value + + def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": + model = self.to(**value) self.register_relation(model, child) return model - def construct_model_from_pk(self, value: Any, child: "Model") -> "Model": + def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": if not isinstance(value, self.to.pk_type()): raise RelationshipInstanceError( f"Relationship error - ForeignKey {self.to.__name__} " @@ -74,27 +74,23 @@ class ForeignKey(BaseField): self.register_relation(model, child) return model - def register_relation(self, model, child): - model._orm_relationship_manager.add_relation( - model, child, virtual=self.virtual - ) + def register_relation(self, model: "Model", child: "Model") -> None: + model._orm_relationship_manager.add_relation(model, child, virtual=self.virtual) def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None - if isinstance(value, orm.models.Model) and not isinstance(value, self.to): - raise RelationshipInstanceError( - f"Relationship error - expecting: {self.to.__name__}, " - f"but {value.__class__.__name__} encountered." - ) - - if isinstance(value, (dict, list, self.to)): - model = self.extract_model_from_sequence(value, child) - else: - model = self.construct_model_from_pk(value, child) + constructors = { + f"{self.to.__name__}": self._register_existing_model, + "dict": self._construct_model_from_dict, + "list": self._extract_model_from_sequence, + } + model = constructors.get( + value.__class__.__name__, self._construct_model_from_pk + )(value, child) return model From 146dbea0151956f5724dd48314000c467e2b145d Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 19:03:02 +0200 Subject: [PATCH 47/62] refactor decorator to separate file --- .coverage | Bin 53248 -> 53248 bytes orm/fields/base.py | 23 +---------------------- orm/fields/foreign_key.py | 2 +- orm/fields/model_fields.py | 3 ++- orm/fields/required_decorator.py | 27 +++++++++++++++++++++++++++ 5 files changed, 31 insertions(+), 24 deletions(-) create mode 100644 orm/fields/required_decorator.py diff --git a/.coverage b/.coverage index d8930bb67053fbd25ac27e8ab0d7f1591e81a002..b5b4f932fd3d5ebc39d84ab25846aa7fd47a3799 100644 GIT binary patch delta 774 zcmZvaUr19?9LMkN-2J)RJ?GrrR*{)F|IB}(*)rXz9-^1RJ41+Nlsa;OkQ^#uHm$UC0+y>U|CLYP*Se-FBUm2aG~9TSCxR7pe~Wsw%2{ zHv&O#Lm)_74&&~7Hg(YLVX0&%gTuH4C!rtc6?%YVln#Hx8T=9U;eJ>PosbdU3!PN5 zMj*4-ia043c!{mz&_7SAWsGTAia=-&+|P+8fj+s z((!~i`c{i-TaSJ&E-g938!L0_GS0v>xrE73d#VG`F##Jir@PHMmcN&1Cg#Wmy@@u0CUQ06IM4HRw({Al%}_MbF~vAMGFNk| z``t}-ZeB7_VmExA>{+?nxVd_!_T@5b(NmVDfA(p7V}4cH+|jB4nTnhS5^g1ihpGE; lD|Nhr4Y3+4Noreoc8aW?|siH3tT3fmZpSaYi`c$qPw7Nio`B-u0a;T{+JucG)k`u zB#pi<0?T<(pqUXIkQA+M1S!IFg;+ALG757>&>t+;HzMikd!FZg--qAVGc5EB3y*A7 zdZ)`??XuZ4SJiW+RFxMwe2sUp3q80GHenjBKmhDuB55*0V#Gxxc~Mi95joGHQr^9v z&Ya8_6@d%VM$aDl)GJBh=$W=qtg=1fK(^o%LEyswhP2hK@PtO)c3!3p-U4=@BeAx4 z{CwnG?xb9z3+@WhByc`9$#3MqOUhtRDfE#$A5xOkepcM5$cxQEOl?!;(Z(nb4FoWE;zvE3*cQweeP;wN!;>4u9e&Ok)b?P{Sl% z!zL8qGrWf-cnCM41DY9CH5BbuMc#Nwof>pxrW5ns$Fj-j=EXZ1i^AgL3T^K%v>=m* zaq2~OKKtYL*VVPP+U{@b%i0Gl12$;@`k=qWMwm2{v^|=xoeb(^=0G;(OYE(XF3tW5 zPp*Fa9q7z&EU(k+HxqeQ_qr8 z{oC2(`sL2=8}%D2L;Nls%MR-9Ki;x$Ee*Y^&pgQ{Lv(K7sGwpW)n6}R{|chkmrPh# QGDh{!a!D`JvH4Sf0R3RjZ2$lO diff --git a/orm/fields/base.py b/orm/fields/base.py index 9d321f2..d405679 100644 --- a/orm/fields/base.py +++ b/orm/fields/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Type +from typing import Any, Dict, List, Optional, TYPE_CHECKING import sqlalchemy @@ -8,27 +8,6 @@ if TYPE_CHECKING: # pragma no cover from orm.models import Model -class RequiredParams: - def __init__(self, *args: str) -> None: - self._required = list(args) - - def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: - old_init = model_field_class.__init__ - model_field_class._old_init = old_init - - def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: - super(instance.__class__, instance).__init__(*args, **kwargs) - for arg in self._required: - if arg not in kwargs: - raise ModelDefinitionError( - f"{instance.__class__.__name__} field requires parameter: {arg}" - ) - setattr(instance, arg, kwargs.pop(arg)) - - model_field_class.__init__ = __init__ - return model_field_class - - class BaseField: __type__ = None diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 0788ce2..b32a887 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: # pragma no cover from orm.models import Model -def create_dummy_instance(fk: Type["Model"], pk: int = None) -> "Model": +def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { **{fk.__pkname__: pk or -1}, **{ diff --git a/orm/fields/model_fields.py b/orm/fields/model_fields.py index 813f190..4d841e1 100644 --- a/orm/fields/model_fields.py +++ b/orm/fields/model_fields.py @@ -4,7 +4,8 @@ import decimal import sqlalchemy from pydantic import Json -from orm.fields.base import BaseField, RequiredParams # noqa I101 +from orm.fields.base import BaseField # noqa I101 +from orm.fields.required_decorator import RequiredParams @RequiredParams("length") diff --git a/orm/fields/required_decorator.py b/orm/fields/required_decorator.py new file mode 100644 index 0000000..4deb597 --- /dev/null +++ b/orm/fields/required_decorator.py @@ -0,0 +1,27 @@ +from typing import Any, TYPE_CHECKING, Type + +from orm import ModelDefinitionError + +if TYPE_CHECKING: # pragma no cover + from orm.fields import BaseField + + +class RequiredParams: + def __init__(self, *args: str) -> None: + self._required = list(args) + + def __call__(self, model_field_class: Type["BaseField"]) -> Type["BaseField"]: + old_init = model_field_class.__init__ + model_field_class._old_init = old_init + + def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: + super(instance.__class__, instance).__init__(*args, **kwargs) + for arg in self._required: + if arg not in kwargs: + raise ModelDefinitionError( + f"{instance.__class__.__name__} field requires parameter: {arg}" + ) + setattr(instance, arg, kwargs.pop(arg)) + + model_field_class.__init__ = __init__ + return model_field_class From 45653d36c7c9c902c345b848fa35972937384500 Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 19:43:42 +0200 Subject: [PATCH 48/62] refactori into packages --- .coverage | Bin 53248 -> 53248 bytes .../{required_decorator.py => decorators.py} | 0 orm/fields/model_fields.py | 2 +- orm/models.py | 405 ------------- orm/models/__init__.py | 5 + orm/models/fakepydantic.py | 195 ++++++ orm/models/metaclass.py | 132 ++++ orm/models/model.py | 93 +++ orm/queryset.py | 571 ------------------ orm/queryset/__init__.py | 0 orm/queryset/clause.py | 176 ++++++ orm/queryset/query.py | 228 +++++++ orm/queryset/queryset.py | 175 ++++++ 13 files changed, 1005 insertions(+), 977 deletions(-) rename orm/fields/{required_decorator.py => decorators.py} (100%) delete mode 100644 orm/models.py create mode 100644 orm/models/__init__.py create mode 100644 orm/models/fakepydantic.py create mode 100644 orm/models/metaclass.py create mode 100644 orm/models/model.py delete mode 100644 orm/queryset.py create mode 100644 orm/queryset/__init__.py create mode 100644 orm/queryset/clause.py create mode 100644 orm/queryset/query.py create mode 100644 orm/queryset/queryset.py diff --git a/.coverage b/.coverage index b5b4f932fd3d5ebc39d84ab25846aa7fd47a3799..2007c584c86b15bdaf4a5b56e47ec8e955810cb9 100644 GIT binary patch delta 1322 zcmZvce@I(b6vq=^nwq@#?#b`G#x~VxH!3!nRi}2$KgQT!i*c|~u+c|I(kz=v{1IDh zs`Zb;gwa}h86DKvmTrSV%L?f_tV+g4S=lC8*=(sL5w@YsE^0Rxr!L9vOH6RMf9^T& zeD3+)d%hRm@KtK~Ds`*DQc_N(maf}h|cI0db+53Qmp^d5Q>RkDAvU$LWX zCu?TbnQxderkgp!?4rM?!?Y`+RMgNVjT{zRtMK&e8i}3~Iq(~;Ivlmpv@hUtclDXO z2O4!`Lz6^HiF^f!JFN<_1_!JSVl{5F?I8_Pw$B|5di@@u<)8{bupZp?f|3$>NRNTb zi`jkdpu^ea=o9wTDe{7bvMqSr@vggjz~%4-8(?k?YV>?3&Q(t z`}D;+k?ikt6Hv=iBCpW1r^pJ4K$#ts=e>S!kmm`lp(znB^ePLN6$}tgljly`s&4D3 zmF4v+R4%tO(CY{WdUJ!7<_D{Y+-yB6CNJX_K?rPMM96DrSSZ&!*U~S)R49 z7nm^2GClAy^AaPaXXtU-O#zg~_z~{I=Wq>1;Vhg6FF4>;cn(U?&*)op8{I(Ts2?3e&B#O?Ut*0~P9m1oaCa{^ zu1pQYLM_=Vfz{LDm2y2PzoN(8=l7QDNP1Xz_q%K?`}f$N^9u{jq15k>Ba867mXJ0r zzHz?tjE1BJO`;)?Xr6FL)#MAQR|Hv)6L^ge$B+w50mLsXEdIXnSk)uVD6qL`bB(m zcDd`**kE*JE%{k0K7A&W9^BaUP43@|$fhG`?vLnHa&j$Q_b4+P{)NwuuY{I1PKwJ| zLYvsR_tVjMI{9=lGqtuDizg>-AMh;;QiedplqhKutLS$gCQ^@Hj;Gg#Hd~}}0m%(A z8HpNM_3{v_C?%<>^v=ZNt`{%Q;P9x2;-DSxzHoqnK{x>cc$;|NPFyxYwUq1?b6hUy z^|_$a<$_k53mQ!>sMV3PSBA@p(g@!DB~iKtV{i#B!Z{crN^M}II7-A)lUVz|ZK}b3 delta 754 zcmZvYZAep57{~ARUSBrvIcFONX=ZA^8(NvU1xfU=Efoz`&6=V^M7Nt$lQ~}&2t|LlA^cF%y|?L1iSlt3)oH^LFNkKK1E5zyJThd49*%$JqLqm)V&? zCr{aPGG}Bd3a*nvAJG76M0w~a+=P?xF4VvTunQdE888ARCvtOKr(G_q;(~G%cvhW& z@2ewdIqp?w&@#eG7BQ3&&#U8bx<v6>K2YpAQMG*#4Q+_I!{ zxJQ#3oD$AZTHLRR!p|>M|_TfA&O=P+#WJUQX25!MAcn@BNiNFWu&>|Q>gTMq-fa2!4E+SLS zVS`RdD{zl4gOw|pdfcx|#$GjKWoH;8fp{W=Sd`812T-Ds(Gb@Zr?Vi9xrXEPNAQG( zW4TmDdjLcslq?jU%oGI}{4DHEe-ey2q+n^u3B~x69AVPpu;z-r{AN?P5GyC~MLE9E z9vv$q<7wIW2hkz^dHTz}wx+Uf|6I0zN6CmZ1qe5_A8&k zwsiXs-C}XijLGACeAp&b9lPPqaVB{M&!|Zl_3mzL=YMi-FP&fYds}?DxUIv+C{P;? ceUwQ4<0`VZ_1a=dGK53l1d>LChaA`c0w}T~bN~PV diff --git a/orm/fields/required_decorator.py b/orm/fields/decorators.py similarity index 100% rename from orm/fields/required_decorator.py rename to orm/fields/decorators.py diff --git a/orm/fields/model_fields.py b/orm/fields/model_fields.py index 4d841e1..f14391e 100644 --- a/orm/fields/model_fields.py +++ b/orm/fields/model_fields.py @@ -5,7 +5,7 @@ import sqlalchemy from pydantic import Json from orm.fields.base import BaseField # noqa I101 -from orm.fields.required_decorator import RequiredParams +from orm.fields.decorators import RequiredParams @RequiredParams("length") diff --git a/orm/models.py b/orm/models.py deleted file mode 100644 index 22750eb..0000000 --- a/orm/models.py +++ /dev/null @@ -1,405 +0,0 @@ -import copy -import inspect -import json -import uuid -from typing import Any, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar -from typing import Callable, Dict, Set - -import databases -import pydantic -import sqlalchemy -from pydantic import BaseConfig, BaseModel, create_model -from pydantic.fields import ModelField - -import orm.queryset as qry # noqa I100 -from orm import ForeignKey -from orm.exceptions import ModelDefinitionError -from orm.fields.base import BaseField -from orm.relations import RelationshipManager - -relationship_manager = RelationshipManager() - - -def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: - pydantic_fields = { - field_name: ( - base_field.__type__, - ... if base_field.is_required else base_field.default_value, - ) - for field_name, base_field in object_dict.items() - if isinstance(base_field, BaseField) - } - return pydantic_fields - - -def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: - child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" - reverse_name = field.related_name or child_relation_name - relation_name = name.lower().title() + "_" + field.to.get_name() - relationship_manager.add_relation_type( - relation_name, reverse_name, field, table_name - ) - - -def expand_reverse_relationships(model: Type["Model"]) -> None: - for model_field in model.__model_fields__.values(): - if isinstance(model_field, ForeignKey): - child_model_name = model_field.related_name or model.__name__.lower() + "s" - parent_model = model_field.to - child = model - if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ - ): - register_reverse_model_fields(parent_model, child, child_model_name) - - -def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str -) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( - child, name=child_model_name, virtual=True - ) - - -def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str -) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: - pkname: Optional[str] = None - columns: List[sqlalchemy.Column] = [] - model_fields: Dict[str, BaseField] = {} - - for field_name, field in object_dict.items(): - if isinstance(field, BaseField): - model_fields[field_name] = field - if not field.pydantic_only: - if field.primary_key: - pkname = field_name - if isinstance(field, ForeignKey): - register_relation_on_build(table_name, field, name) - columns.append(field.get_column(field_name)) - return pkname, columns, model_fields - - -def get_pydantic_base_orm_config() -> Type[BaseConfig]: - class Config(BaseConfig): - orm_mode = True - - return Config - - -class ModelMetaclass(type): - def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) - - if attrs.get("__abstract__"): - return new_model - - tablename = attrs["__tablename__"] - metadata = attrs["__metadata__"] - - # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, tablename - ) - attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) - attrs["__columns__"] = columns - attrs["__pkname__"] = pkname - - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") - - # pydantic model creation - pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **pydantic_fields - ) - attrs["__pydantic_fields__"] = pydantic_fields - attrs["__pydantic_model__"] = pydantic_model - attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) - attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) - attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) - - attrs["__model_fields__"] = model_fields - attrs["_orm_relationship_manager"] = relationship_manager - - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) - - expand_reverse_relationships(new_model) - - return new_model - - -class FakePydantic(list, metaclass=ModelMetaclass): - # FakePydantic inherits from list in order to be treated as - # request.Body parameter in fastapi routes, - # inheriting from pydantic.BaseModel causes metaclass conflicts - __abstract__ = True - if TYPE_CHECKING: # pragma no cover - __model_fields__: Dict[str, TypeVar[BaseField]] - __table__: sqlalchemy.Table - __fields__: Dict[str, pydantic.fields.ModelField] - __pydantic_model__: Type[BaseModel] - __pkname__: str - __tablename__: str - __metadata__: sqlalchemy.MetaData - __database__: databases.Database - _orm_relationship_manager: RelationshipManager - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._orm_id: str = uuid.uuid4().hex - self._orm_saved: bool = False - self.values: Optional[BaseModel] = None - - if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") - kwargs = { - k: self.__model_fields__[k].expand_relationship(v, self) - for k, v in kwargs.items() - } - self.values = self.__pydantic_model__(**kwargs) - - def __del__(self) -> None: - self._orm_relationship_manager.deregister(self) - - def __setattr__(self, key: str, value: Any) -> None: - if key in self.__fields__: - if self._is_conversion_to_json_needed(key) and not isinstance(value, str): - try: - value = json.dumps(value) - except TypeError: # pragma no cover - pass - - value = self.__model_fields__[key].expand_relationship(value, self) - - relation_key = self.__class__.__name__.title() + "_" + key - if not self._orm_relationship_manager.contains(relation_key, self): - setattr(self.values, key, value) - else: - super().__setattr__(key, value) - - def __getattribute__(self, key: str) -> Any: - if key != "__fields__" and key in self.__fields__: - relation_key = self.__class__.__name__.title() + "_" + key - if self._orm_relationship_manager.contains(relation_key, self): - return self._orm_relationship_manager.get(relation_key, self) - - item = getattr(self.values, key, None) - if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) - ): - try: - item = json.loads(item) - except TypeError: # pragma no cover - pass - return item - return super().__getattribute__(key) - - def __eq__(self, other: "Model") -> bool: - return self.values.dict() == other.values.dict() - - def __same__(self, other: "Model") -> bool: - if self.__class__ != other.__class__: # pragma no cover - return False - return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk - ) - - def __repr__(self) -> str: # pragma no cover - return self.values.__repr__() - - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate - - @classmethod - def get_name(cls, title: bool = False, lower: bool = True) -> str: - name = cls.__name__ - if lower: - name = name.lower() - if title: - name = name.title() - return name - - @property - def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] - - @classmethod - def pk_type(cls) -> Any: - return cls.__model_fields__[cls.__pkname__].__type__ - - def dict(self) -> Dict: # noqa: A003 - dict_instance = self.values.dict() - for field in self._extract_related_names(): - nested_model = getattr(self, field) - if isinstance(nested_model, list): - dict_instance[field] = [x.dict() for x in nested_model] - else: - dict_instance[field] = ( - nested_model.dict() if nested_model is not None else {} - ) - return dict_instance - - def from_dict(self, value_dict: Dict) -> None: - for key, value in value_dict.items(): - setattr(self, key, value) - - def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json - - def _extract_own_model_fields(self) -> Dict: - related_names = self._extract_related_names() - self_fields = {k: v for k, v in self.dict().items() if k not in related_names} - return self_fields - - @classmethod - def _extract_related_names(cls) -> Set: - related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel - ): - related_names.add(name) - return related_names - - def _extract_model_db_fields(self) -> Dict: - self_fields = self._extract_own_model_fields() - self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns - } - for field in self._extract_related_names(): - if getattr(self, field) is not None: - self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ - ) - return self_fields - - @classmethod - def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: - merged_rows = [] - for index, model in enumerate(result_rows): - if index > 0 and model.pk == result_rows[index - 1].pk: - result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) - else: - merged_rows.append(model) - return merged_rows - - @classmethod - def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): - # print(field, one.dict(), other.dict()) - if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), Model - ): - setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), Model): - if getattr(one, field).pk == getattr(other, field).pk: - setattr( - other, - field, - cls.merge_two_instances( - getattr(one, field), getattr(other, field) - ), - ) - return other - - -class Model(FakePydantic): - __abstract__ = True - - objects = qry.QuerySet() - - @classmethod - def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, - ) -> "Model": - - item = {} - select_related = select_related or [] - - table_prefix = cls._orm_relationship_manager.resolve_relation_join( - previous_table, cls.__table__.name - ) - previous_table = cls.__table__.name - for related in select_related: - if "__" in related: - first_part, remainder = related.split("__", 1) - model_cls = cls.__model_fields__[first_part].to - child = model_cls.from_row( - row, select_related=[remainder], previous_table=previous_table - ) - item[first_part] = child - else: - model_cls = cls.__model_fields__[related].to - child = model_cls.from_row(row, previous_table=previous_table) - item[related] = child - - for column in cls.__table__.columns: - if column.name not in item: - item[column.name] = row[ - f'{table_prefix + "_" if table_prefix else ""}{column.name}' - ] - - return cls(**item) - - @property - def pk(self) -> str: - return getattr(self.values, self.__pkname__) - - @pk.setter - def pk(self, value: Any) -> None: - setattr(self.values, self.__pkname__, value) - - async def save(self) -> int: - self_fields = self._extract_model_db_fields() - if self.__model_fields__.get(self.__pkname__).autoincrement: - self_fields.pop(self.__pkname__, None) - expr = self.__table__.insert() - expr = expr.values(**self_fields) - item_id = await self.__database__.execute(expr) - self.pk = item_id - return item_id - - async def update(self, **kwargs: Any) -> int: - if kwargs: - new_values = {**self.dict(), **kwargs} - self.from_dict(new_values) - - self_fields = self._extract_model_db_fields() - self_fields.pop(self.__pkname__) - expr = ( - self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) - ) - result = await self.__database__.execute(expr) - return result - - async def delete(self) -> int: - expr = self.__table__.delete() - expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) - result = await self.__database__.execute(expr) - return result - - async def load(self) -> "Model": - expr = self.__table__.select().where(self.pk_column == self.pk) - row = await self.__database__.fetch_one(expr) - self.from_dict(dict(row)) - return self diff --git a/orm/models/__init__.py b/orm/models/__init__.py new file mode 100644 index 0000000..00fac17 --- /dev/null +++ b/orm/models/__init__.py @@ -0,0 +1,5 @@ +from orm.models.model import Model + +__all__ = [ + "Model" +] diff --git a/orm/models/fakepydantic.py b/orm/models/fakepydantic.py new file mode 100644 index 0000000..bbced86 --- /dev/null +++ b/orm/models/fakepydantic.py @@ -0,0 +1,195 @@ +import inspect +import json +import uuid +from typing import TYPE_CHECKING, Dict, TypeVar, Type, Any, Optional, Callable, Set, List + +import databases +import pydantic +import sqlalchemy +from pydantic import BaseModel + +import orm +from orm.fields import BaseField +from orm.models.metaclass import ModelMetaclass +from orm.relations import RelationshipManager + +if TYPE_CHECKING: #pragma no cover + from orm.models.model import Model + + +class FakePydantic(list, metaclass=ModelMetaclass): + # FakePydantic inherits from list in order to be treated as + # request.Body parameter in fastapi routes, + # inheriting from pydantic.BaseModel causes metaclass conflicts + __abstract__ = True + if TYPE_CHECKING: # pragma no cover + __model_fields__: Dict[str, TypeVar[BaseField]] + __table__: sqlalchemy.Table + __fields__: Dict[str, pydantic.fields.ModelField] + __pydantic_model__: Type[BaseModel] + __pkname__: str + __tablename__: str + __metadata__: sqlalchemy.MetaData + __database__: databases.Database + _orm_relationship_manager: RelationshipManager + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self._orm_id: str = uuid.uuid4().hex + self._orm_saved: bool = False + self.values: Optional[BaseModel] = None + + if "pk" in kwargs: + kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs = { + k: self.__model_fields__[k].expand_relationship(v, self) + for k, v in kwargs.items() + } + self.values = self.__pydantic_model__(**kwargs) + + def __del__(self) -> None: + self._orm_relationship_manager.deregister(self) + + def __setattr__(self, key: str, value: Any) -> None: + if key in self.__fields__: + if self._is_conversion_to_json_needed(key) and not isinstance(value, str): + try: + value = json.dumps(value) + except TypeError: # pragma no cover + pass + + value = self.__model_fields__[key].expand_relationship(value, self) + + relation_key = self.__class__.__name__.title() + "_" + key + if not self._orm_relationship_manager.contains(relation_key, self): + setattr(self.values, key, value) + else: + super().__setattr__(key, value) + + def __getattribute__(self, key: str) -> Any: + if key != "__fields__" and key in self.__fields__: + relation_key = self.__class__.__name__.title() + "_" + key + if self._orm_relationship_manager.contains(relation_key, self): + return self._orm_relationship_manager.get(relation_key, self) + + item = getattr(self.values, key, None) + if ( + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) + ): + try: + item = json.loads(item) + except TypeError: # pragma no cover + pass + return item + return super().__getattribute__(key) + + def __eq__(self, other: "Model") -> bool: + return self.values.dict() == other.values.dict() + + def __same__(self, other: "Model") -> bool: + if self.__class__ != other.__class__: # pragma no cover + return False + return self._orm_id == other._orm_id or ( + self.values is not None and other.values is not None and self.pk == other.pk + ) + + def __repr__(self) -> str: # pragma no cover + return self.values.__repr__() + + @classmethod + def __get_validators__(cls) -> Callable: # pragma no cover + yield cls.__pydantic_model__.validate + + @classmethod + def get_name(cls, title: bool = False, lower: bool = True) -> str: + name = cls.__name__ + if lower: + name = name.lower() + if title: + name = name.title() + return name + + @property + def pk_column(self) -> sqlalchemy.Column: + return self.__table__.primary_key.columns.values()[0] + + @classmethod + def pk_type(cls) -> Any: + return cls.__model_fields__[cls.__pkname__].__type__ + + def dict(self) -> Dict: # noqa: A003 + dict_instance = self.values.dict() + for field in self._extract_related_names(): + nested_model = getattr(self, field) + if isinstance(nested_model, list): + dict_instance[field] = [x.dict() for x in nested_model] + else: + dict_instance[field] = ( + nested_model.dict() if nested_model is not None else {} + ) + return dict_instance + + def from_dict(self, value_dict: Dict) -> None: + for key, value in value_dict.items(): + setattr(self, key, value) + + def _is_conversion_to_json_needed(self, column_name: str) -> bool: + return self.__model_fields__.get(column_name).__type__ == pydantic.Json + + def _extract_own_model_fields(self) -> Dict: + related_names = self._extract_related_names() + self_fields = {k: v for k, v in self.dict().items() if k not in related_names} + return self_fields + + @classmethod + def _extract_related_names(cls) -> Set: + related_names = set() + for name, field in cls.__fields__.items(): + if inspect.isclass(field.type_) and issubclass( + field.type_, pydantic.BaseModel + ): + related_names.add(name) + return related_names + + def _extract_model_db_fields(self) -> Dict: + self_fields = self._extract_own_model_fields() + self_fields = { + k: v for k, v in self_fields.items() if k in self.__table__.columns + } + for field in self._extract_related_names(): + if getattr(self, field) is not None: + self_fields[field] = getattr( + getattr(self, field), self.__model_fields__[field].to.__pkname__ + ) + return self_fields + + @classmethod + def merge_instances_list(cls, result_rows: List["Model"]) -> List["Model"]: + merged_rows = [] + for index, model in enumerate(result_rows): + if index > 0 and model.pk == result_rows[index - 1].pk: + result_rows[-1] = cls.merge_two_instances(model, merged_rows[-1]) + else: + merged_rows.append(model) + return merged_rows + + @classmethod + def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": + for field in one.__model_fields__.keys(): + # print(field, one.dict(), other.dict()) + if isinstance(getattr(one, field), list) and not isinstance( + getattr(one, field), orm.Model + ): + setattr(other, field, getattr(one, field) + getattr(other, field)) + elif isinstance(getattr(one, field), orm.Model): + if getattr(one, field).pk == getattr(other, field).pk: + setattr( + other, + field, + cls.merge_two_instances( + getattr(one, field), getattr(other, field) + ), + ) + return other diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py new file mode 100644 index 0000000..b3546bb --- /dev/null +++ b/orm/models/metaclass.py @@ -0,0 +1,132 @@ +import copy +from typing import Dict, Tuple, Type, Optional, List, Any + +import sqlalchemy +from pydantic import BaseConfig, create_model +from pydantic.fields import ModelField + +from orm import ForeignKey, ModelDefinitionError +from orm.fields import BaseField +from orm.relations import RelationshipManager + +relationship_manager = RelationshipManager() + + +def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: + pydantic_fields = { + field_name: ( + base_field.__type__, + ... if base_field.is_required else base_field.default_value, + ) + for field_name, base_field in object_dict.items() + if isinstance(base_field, BaseField) + } + return pydantic_fields + + +def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: + child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" + reverse_name = field.related_name or child_relation_name + relation_name = name.lower().title() + "_" + field.to.get_name() + relationship_manager.add_relation_type( + relation_name, reverse_name, field, table_name + ) + + +def expand_reverse_relationships(model: Type["Model"]) -> None: + for model_field in model.__model_fields__.values(): + if isinstance(model_field, ForeignKey): + child_model_name = model_field.related_name or model.get_name() + "s" + parent_model = model_field.to + child = model + if ( + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ + ): + register_reverse_model_fields(parent_model, child, child_model_name) + + +def register_reverse_model_fields( + model: Type["Model"], child: Type["Model"], child_model_name: str +) -> None: + model.__fields__[child_model_name] = ModelField( + name=child_model_name, + type_=Optional[child.__pydantic_model__], + model_config=child.__pydantic_model__.__config__, + class_validators=child.__pydantic_model__.__validators__, + ) + model.__model_fields__[child_model_name] = ForeignKey( + child, name=child_model_name, virtual=True + ) + + +def sqlalchemy_columns_from_model_fields( + name: str, object_dict: Dict, table_name: str +) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: + pkname: Optional[str] = None + columns: List[sqlalchemy.Column] = [] + model_fields: Dict[str, BaseField] = {} + + for field_name, field in object_dict.items(): + if isinstance(field, BaseField): + model_fields[field_name] = field + if not field.pydantic_only: + if field.primary_key: + pkname = field_name + if isinstance(field, ForeignKey): + register_relation_on_build(table_name, field, name) + columns.append(field.get_column(field_name)) + return pkname, columns, model_fields + + +def get_pydantic_base_orm_config() -> Type[BaseConfig]: + class Config(BaseConfig): + orm_mode = True + + return Config + + +class ModelMetaclass(type): + def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + + if attrs.get("__abstract__"): + return new_model + + tablename = attrs["__tablename__"] + metadata = attrs["__metadata__"] + + # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, tablename + ) + attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) + attrs["__columns__"] = columns + attrs["__pkname__"] = pkname + + if not pkname: + raise ModelDefinitionError("Table has to have a primary key.") + + # pydantic model creation + pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **pydantic_fields + ) + attrs["__pydantic_fields__"] = pydantic_fields + attrs["__pydantic_model__"] = pydantic_model + attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) + attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) + attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + + attrs["__model_fields__"] = model_fields + attrs["_orm_relationship_manager"] = relationship_manager + + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + + expand_reverse_relationships(new_model) + + return new_model \ No newline at end of file diff --git a/orm/models/model.py b/orm/models/model.py new file mode 100644 index 0000000..d949d95 --- /dev/null +++ b/orm/models/model.py @@ -0,0 +1,93 @@ +from typing import List, Any + +import sqlalchemy + +import orm.queryset.queryset +from orm.models.fakepydantic import FakePydantic + + +class Model(FakePydantic): + __abstract__ = True + + objects = orm.queryset.queryset.QuerySet() + + @classmethod + def from_row( + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, + ) -> "Model": + + item = {} + select_related = select_related or [] + + table_prefix = cls._orm_relationship_manager.resolve_relation_join( + previous_table, cls.__table__.name + ) + previous_table = cls.__table__.name + for related in select_related: + if "__" in related: + first_part, remainder = related.split("__", 1) + model_cls = cls.__model_fields__[first_part].to + child = model_cls.from_row( + row, select_related=[remainder], previous_table=previous_table + ) + item[first_part] = child + else: + model_cls = cls.__model_fields__[related].to + child = model_cls.from_row(row, previous_table=previous_table) + item[related] = child + + for column in cls.__table__.columns: + if column.name not in item: + item[column.name] = row[ + f'{table_prefix + "_" if table_prefix else ""}{column.name}' + ] + + return cls(**item) + + @property + def pk(self) -> str: + return getattr(self.values, self.__pkname__) + + @pk.setter + def pk(self, value: Any) -> None: + setattr(self.values, self.__pkname__, value) + + async def save(self) -> int: + self_fields = self._extract_model_db_fields() + if self.__model_fields__.get(self.__pkname__).autoincrement: + self_fields.pop(self.__pkname__, None) + expr = self.__table__.insert() + expr = expr.values(**self_fields) + item_id = await self.__database__.execute(expr) + self.pk = item_id + return item_id + + async def update(self, **kwargs: Any) -> int: + if kwargs: + new_values = {**self.dict(), **kwargs} + self.from_dict(new_values) + + self_fields = self._extract_model_db_fields() + self_fields.pop(self.__pkname__) + expr = ( + self.__table__.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.__pkname__)) + ) + result = await self.__database__.execute(expr) + return result + + async def delete(self) -> int: + expr = self.__table__.delete() + expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) + result = await self.__database__.execute(expr) + return result + + async def load(self) -> "Model": + expr = self.__table__.select().where(self.pk_column == self.pk) + row = await self.__database__.fetch_one(expr) + self.from_dict(dict(row)) + return self \ No newline at end of file diff --git a/orm/queryset.py b/orm/queryset.py deleted file mode 100644 index 5033f13..0000000 --- a/orm/queryset.py +++ /dev/null @@ -1,571 +0,0 @@ -from typing import ( - Any, - Dict, - List, - NamedTuple, - Optional, - TYPE_CHECKING, - Tuple, - Type, - Union, -) - -import databases -import sqlalchemy -from sqlalchemy import text - -import orm # noqa I100 -import orm.fields.foreign_key -from orm import ForeignKey -from orm.exceptions import MultipleMatches, NoMatch, QueryDefinitionError -from orm.fields.base import BaseField - -if TYPE_CHECKING: # pragma no cover - from orm.models import Model - -FILTER_OPERATORS = { - "exact": "__eq__", - "iexact": "ilike", - "contains": "like", - "icontains": "ilike", - "in": "in_", - "gt": "__gt__", - "gte": "__ge__", - "lt": "__lt__", - "lte": "__le__", -} - -ESCAPE_CHARACTERS = ["%", "_"] - - -class JoinParameters(NamedTuple): - prev_model: Type["Model"] - previous_alias: str - from_table: str - model_cls: Type["Model"] - - -class Query: - def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, - ) -> None: - - self.query_offset = offset - self.limit_count = limit_count - self._select_related = select_related - self.filter_clauses = filter_clauses - - self.model_cls = model_cls - self.table = self.model_cls.__table__ - - self.auto_related = [] - self.used_aliases = [] - - self.select_from = None - self.columns = None - self.order_bys = None - - def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: - self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] - self.select_from = self.table - - for key in self.model_cls.__model_fields__: - if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], - orm.fields.foreign_key.ForeignKey, - ) - and key not in self._select_related - ): - self._select_related = [key] + self._select_related - - start_params = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - self._extract_auto_required_relations(prev_model=start_params.prev_model) - self._include_auto_related_models() - self._select_related.sort(key=lambda item: (-len(item), item)) - - for item in self._select_related: - join_parameters = JoinParameters( - self.model_cls, "", self.table.name, self.model_cls - ) - - for part in item.split("__"): - join_parameters = self._build_join_parameters(part, join_parameters) - - expr = sqlalchemy.sql.select(self.columns) - expr = expr.select_from(self.select_from) - - expr = self._apply_expression_modifiers(expr) - - # print(expr.compile(compile_kwargs={"literal_binds": True})) - self._reset_query_parameters() - - return expr, self._select_related - - @staticmethod - def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: - return [ - text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") - for column in table.columns - ] - - @staticmethod - def prefixed_table_name(alias: str, name: str) -> text: - return text(f"{name} {alias}_{name}") - - @staticmethod - def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str - ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part - - def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str - ) -> bool: - prev_part_of_related = "__".join(rel_part.split("__")[:-1]) - partial_match = any( - [x.startswith(prev_part_of_related) for x in self._select_related] - ) - already_checked = any([x.startswith(rel_part) for x in self.auto_related]) - return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested - - def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, - ) -> text: - left_part = f"{alias}_{to_clause}" - right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" - return text(f"{left_part}={right_part}") - - def _build_join_parameters( - self, part: str, join_params: JoinParameters - ) -> JoinParameters: - model_cls = join_params.model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name - - alias = model_cls._orm_relationship_manager.resolve_relation_join( - join_params.from_table, to_table - ) - if alias not in self.used_aliases: - if join_params.prev_model.__model_fields__[part].virtual: - to_key = next( - ( - v - for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model - ), - None, - ).name - from_key = model_cls.__pkname__ - else: - to_key = model_cls.__pkname__ - from_key = part - - on_clause = self.on_clause( - previous_alias=join_params.previous_alias, - alias=alias, - from_clause=f"{join_params.from_table}.{from_key}", - to_clause=f"{to_table}.{to_key}", - ) - target_table = self.prefixed_table_name(alias, to_table) - self.select_from = sqlalchemy.sql.outerjoin( - self.select_from, target_table, on_clause - ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) - self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) - self.used_aliases.append(alias) - - previous_alias = alias - from_table = to_table - prev_model = model_cls - return JoinParameters(prev_model, previous_alias, from_table, model_cls) - - def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, - ) -> None: - for field_name, field in prev_model.__model_fields__.items(): - if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part - ): - rel_part = field_name if not rel_part else rel_part + "__" + field_name - if not field.nullable: - if rel_part not in self._select_related: - self.auto_related.append("__".join(rel_part.split("__")[:-1])) - rel_part = "" - elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part - ): - self._extract_auto_required_relations( - prev_model=field.to, - rel_part=rel_part, - nested=True, - parent_virtual=field.virtual, - ) - else: - rel_part = "" - - def _include_auto_related_models(self) -> None: - if self.auto_related: - new_joins = [] - for join in self._select_related: - if not any([x.startswith(join) for x in self.auto_related]): - new_joins.append(join) - self._select_related = new_joins + self.auto_related - - def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select - ) -> sqlalchemy.sql.select: - if self.filter_clauses: - if len(self.filter_clauses) == 1: - clause = self.filter_clauses[0] - else: - clause = sqlalchemy.sql.and_(*self.filter_clauses) - expr = expr.where(clause) - - if self.limit_count: - expr = expr.limit(self.limit_count) - - if self.query_offset: - expr = expr.offset(self.query_offset) - - for order in self.order_bys: - expr = expr.order_by(order) - return expr - - def _reset_query_parameters(self) -> None: - self.select_from = None - self.columns = None - self.order_bys = None - self.auto_related = [] - self.used_aliases = [] - - -class QueryClause: - def __init__( - self, model_cls: Type["Model"], filter_clauses: List, select_related: List, - ) -> None: - - self._select_related = select_related - self.filter_clauses = filter_clauses - - self.model_cls = model_cls - self.table = self.model_cls.__table__ - - def filter( # noqa: A003 - self, **kwargs: Any - ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: - filter_clauses = self.filter_clauses - select_related = list(self._select_related) - - if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ - kwargs[pk_name] = kwargs.pop("pk") - - for key, value in kwargs.items(): - table_prefix = "" - if "__" in key: - parts = key.split("__") - - ( - op, - field_name, - related_parts, - ) = self._extract_operator_field_and_related(parts) - - model_cls = self.model_cls - if related_parts: - ( - select_related, - table_prefix, - model_cls, - ) = self._determine_filter_target_table( - related_parts, select_related - ) - - table = model_cls.__table__ - column = model_cls.__table__.columns[field_name] - - else: - op = "exact" - column = self.table.columns[key] - table = self.table - - value, has_escaped_character = self._escape_characters_in_clause(op, value) - - if isinstance(value, orm.Model): - value = value.pk - - op_attr = FILTER_OPERATORS[op] - clause = getattr(column, op_attr)(value) - clause = self._compile_clause( - clause, - column, - table, - table_prefix, - modifiers={"escape": "\\" if has_escaped_character else None}, - ) - filter_clauses.append(clause) - - return filter_clauses, select_related - - def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str] - ) -> Tuple[List[str], str, "Model"]: - - table_prefix = "" - model_cls = self.model_cls - select_related = [relation for relation in select_related] - - # Add any implied select_related - related_str = "__".join(related_parts) - if related_str not in select_related: - select_related.append(related_str) - - # Walk the relationships to the actual model class - # against which the comparison is being made. - previous_table = model_cls.__tablename__ - for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager - table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to - previous_table = current_table - return select_related, table_prefix, model_cls - - def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, - ) -> sqlalchemy.sql.expression.TextClause: - for modifier, modifier_value in modifiers.items(): - clause.modifiers[modifier] = modifier_value - - clause_text = str( - clause.compile( - dialect=self.model_cls.__database__._backend._dialect, - compile_kwargs={"literal_binds": True}, - ) - ) - alias = f"{table_prefix}_" if table_prefix else "" - aliased_name = f"{alias}{table.name}.{column.name}" - clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) - clause = text(clause_text) - return clause - - @staticmethod - def _escape_characters_in_clause( - op: str, value: Union[str, "Model"] - ) -> Tuple[str, bool]: - has_escaped_character = False - - if op in ["contains", "icontains"]: - if isinstance(value, orm.Model): - raise QueryDefinitionError( - "You cannot use contains and icontains with instance of the Model" - ) - - has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) - - if has_escaped_character: - # enable escape modifier - for char in ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" - - return value, has_escaped_character - - @staticmethod - def _extract_operator_field_and_related( - parts: List[str], - ) -> Tuple[str, str, Optional[List]]: - if parts[-1] in FILTER_OPERATORS: - op = parts[-1] - field_name = parts[-2] - related_parts = parts[:-2] - else: - op = "exact" - field_name = parts[-1] - related_parts = parts[:-1] - - return op, field_name, related_parts - - -class QuerySet: - def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, - ) -> None: - self.model_cls = model_cls - self.filter_clauses = [] if filter_clauses is None else filter_clauses - self._select_related = [] if select_related is None else select_related - self.limit_count = limit_count - self.query_offset = offset - self.order_bys = None - - def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": - return self.__class__(model_cls=owner) - - @property - def database(self) -> databases.Database: - return self.model_cls.__database__ - - @property - def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ - - def build_select_expression(self) -> sqlalchemy.sql.select: - qry = Query( - model_cls=self.model_cls, - select_related=self._select_related, - filter_clauses=self.filter_clauses, - offset=self.query_offset, - limit_count=self.limit_count, - ) - exp, self._select_related = qry.build_select_expression() - return exp - - def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 - qryclause = QueryClause( - model_cls=self.model_cls, - select_related=self._select_related, - filter_clauses=self.filter_clauses, - ) - filter_clauses, select_related = qryclause.filter(**kwargs) - - return self.__class__( - model_cls=self.model_cls, - filter_clauses=filter_clauses, - select_related=select_related, - limit_count=self.limit_count, - offset=self.query_offset, - ) - - def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": - if not isinstance(related, (list, tuple)): - related = [related] - - related = list(self._select_related) + related - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=related, - limit_count=self.limit_count, - offset=self.query_offset, - ) - - async def exists(self) -> bool: - expr = self.build_select_expression() - expr = sqlalchemy.exists(expr).select() - return await self.database.fetch_val(expr) - - async def count(self) -> int: - expr = self.build_select_expression().alias("subquery_for_count") - expr = sqlalchemy.func.count().select().select_from(expr) - return await self.database.fetch_val(expr) - - def limit(self, limit_count: int) -> "QuerySet": - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=limit_count, - offset=self.query_offset, - ) - - def offset(self, offset: int) -> "QuerySet": - return self.__class__( - model_cls=self.model_cls, - filter_clauses=self.filter_clauses, - select_related=self._select_related, - limit_count=self.limit_count, - offset=offset, - ) - - async def first(self, **kwargs: Any) -> "Model": - if kwargs: - return await self.filter(**kwargs).first() - - rows = await self.limit(1).all() - if rows: - return rows[0] - - async def get(self, **kwargs: Any) -> "Model": - if kwargs: - return await self.filter(**kwargs).get() - - expr = self.build_select_expression().limit(2) - rows = await self.database.fetch_all(expr) - - if not rows: - raise NoMatch() - if len(rows) > 1: - raise MultipleMatches() - return self.model_cls.from_row(rows[0], select_related=self._select_related) - - async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 - if kwargs: - return await self.filter(**kwargs).all() - - expr = self.build_select_expression() - rows = await self.database.fetch_all(expr) - result_rows = [ - self.model_cls.from_row(row, select_related=self._select_related) - for row in rows - ] - - result_rows = self.model_cls.merge_instances_list(result_rows) - - return result_rows - - async def create(self, **kwargs: Any) -> "Model": - - new_kwargs = dict(**kwargs) - - # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] - if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) - ): - del new_kwargs[pkname] - - # substitute related models with their pk - for field in self.model_cls._extract_related_names(): - if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr( - new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, - ) - - # Build the insert expression. - expr = self.table.insert() - expr = expr.values(**new_kwargs) - - # Execute the insert, and return a new model instance. - instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) - return instance diff --git a/orm/queryset/__init__.py b/orm/queryset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/orm/queryset/clause.py b/orm/queryset/clause.py new file mode 100644 index 0000000..0cc2309 --- /dev/null +++ b/orm/queryset/clause.py @@ -0,0 +1,176 @@ +from typing import Type, List, Any, Tuple, Dict, Union, Optional, TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import text + +import orm +from orm.exceptions import QueryDefinitionError + +if TYPE_CHECKING: # pragma no cover + from orm import Model + +FILTER_OPERATORS = { + "exact": "__eq__", + "iexact": "ilike", + "contains": "like", + "icontains": "ilike", + "in": "in_", + "gt": "__gt__", + "gte": "__ge__", + "lt": "__lt__", + "lte": "__le__", +} +ESCAPE_CHARACTERS = ["%", "_"] + + +class QueryClause: + def __init__( + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + ) -> None: + + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.__table__ + + def filter( # noqa: A003 + self, **kwargs: Any + ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: + filter_clauses = self.filter_clauses + select_related = list(self._select_related) + + if kwargs.get("pk"): + pk_name = self.model_cls.__pkname__ + kwargs[pk_name] = kwargs.pop("pk") + + for key, value in kwargs.items(): + table_prefix = "" + if "__" in key: + parts = key.split("__") + + ( + op, + field_name, + related_parts, + ) = self._extract_operator_field_and_related(parts) + + model_cls = self.model_cls + if related_parts: + ( + select_related, + table_prefix, + model_cls, + ) = self._determine_filter_target_table( + related_parts, select_related + ) + + table = model_cls.__table__ + column = model_cls.__table__.columns[field_name] + + else: + op = "exact" + column = self.table.columns[key] + table = self.table + + value, has_escaped_character = self._escape_characters_in_clause(op, value) + + if isinstance(value, orm.Model): + value = value.pk + + op_attr = FILTER_OPERATORS[op] + clause = getattr(column, op_attr)(value) + clause = self._compile_clause( + clause, + column, + table, + table_prefix, + modifiers={"escape": "\\" if has_escaped_character else None}, + ) + filter_clauses.append(clause) + + return filter_clauses, select_related + + def _determine_filter_target_table( + self, related_parts: List[str], select_related: List[str] + ) -> Tuple[List[str], str, "Model"]: + + table_prefix = "" + model_cls = self.model_cls + select_related = [relation for relation in select_related] + + # Add any implied select_related + related_str = "__".join(related_parts) + if related_str not in select_related: + select_related.append(related_str) + + # Walk the relationships to the actual model class + # against which the comparison is being made. + previous_table = model_cls.__tablename__ + for part in related_parts: + current_table = model_cls.__model_fields__[part].to.__tablename__ + manager = model_cls._orm_relationship_manager + table_prefix = manager.resolve_relation_join(previous_table, current_table) + model_cls = model_cls.__model_fields__[part].to + previous_table = current_table + return select_related, table_prefix, model_cls + + def _compile_clause( + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, + ) -> sqlalchemy.sql.expression.TextClause: + for modifier, modifier_value in modifiers.items(): + clause.modifiers[modifier] = modifier_value + + clause_text = str( + clause.compile( + dialect=self.model_cls.__database__._backend._dialect, + compile_kwargs={"literal_binds": True}, + ) + ) + alias = f"{table_prefix}_" if table_prefix else "" + aliased_name = f"{alias}{table.name}.{column.name}" + clause_text = clause_text.replace(f"{table.name}.{column.name}", aliased_name) + clause = text(clause_text) + return clause + + @staticmethod + def _escape_characters_in_clause( + op: str, value: Union[str, "Model"] + ) -> Tuple[str, bool]: + has_escaped_character = False + + if op in ["contains", "icontains"]: + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) + + has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" + + return value, has_escaped_character + + @staticmethod + def _extract_operator_field_and_related( + parts: List[str], + ) -> Tuple[str, str, Optional[List]]: + if parts[-1] in FILTER_OPERATORS: + op = parts[-1] + field_name = parts[-2] + related_parts = parts[:-2] + else: + op = "exact" + field_name = parts[-1] + related_parts = parts[:-1] + + return op, field_name, related_parts diff --git a/orm/queryset/query.py b/orm/queryset/query.py new file mode 100644 index 0000000..22b2db1 --- /dev/null +++ b/orm/queryset/query.py @@ -0,0 +1,228 @@ +from typing import NamedTuple, Type, List, Tuple, TYPE_CHECKING + +import sqlalchemy +from sqlalchemy import text + +import orm +from orm import ForeignKey +from orm.fields import BaseField + +if TYPE_CHECKING: # pragma no cover + from orm import Model + + +class JoinParameters(NamedTuple): + prev_model: Type["Model"] + previous_alias: str + from_table: str + model_cls: Type["Model"] + + +class Query: + def __init__( + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, + ) -> None: + + self.query_offset = offset + self.limit_count = limit_count + self._select_related = select_related + self.filter_clauses = filter_clauses + + self.model_cls = model_cls + self.table = self.model_cls.__table__ + + self.auto_related = [] + self.used_aliases = [] + + self.select_from = None + self.columns = None + self.order_bys = None + + def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: + self.columns = list(self.table.columns) + self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.select_from = self.table + + for key in self.model_cls.__model_fields__: + if ( + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], + orm.fields.foreign_key.ForeignKey, + ) + and key not in self._select_related + ): + self._select_related = [key] + self._select_related + + start_params = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + self._extract_auto_required_relations(prev_model=start_params.prev_model) + self._include_auto_related_models() + self._select_related.sort(key=lambda item: (-len(item), item)) + + for item in self._select_related: + join_parameters = JoinParameters( + self.model_cls, "", self.table.name, self.model_cls + ) + + for part in item.split("__"): + join_parameters = self._build_join_parameters(part, join_parameters) + + expr = sqlalchemy.sql.select(self.columns) + expr = expr.select_from(self.select_from) + + expr = self._apply_expression_modifiers(expr) + + # print(expr.compile(compile_kwargs={"literal_binds": True})) + self._reset_query_parameters() + + return expr, self._select_related + + @staticmethod + def prefixed_columns(alias: str, table: sqlalchemy.Table) -> List[text]: + return [ + text(f"{alias}_{table.name}.{column.name} as {alias}_{column.name}") + for column in table.columns + ] + + @staticmethod + def prefixed_table_name(alias: str, name: str) -> text: + return text(f"{name} {alias}_{name}") + + @staticmethod + def _field_is_a_foreign_key_and_no_circular_reference( + field: BaseField, field_name: str, rel_part: str + ) -> bool: + return isinstance(field, ForeignKey) and field_name not in rel_part + + def _field_qualifies_to_deeper_search( + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + ) -> bool: + prev_part_of_related = "__".join(rel_part.split("__")[:-1]) + partial_match = any( + [x.startswith(prev_part_of_related) for x in self._select_related] + ) + already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + return ( + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested + + def on_clause( + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + ) -> text: + left_part = f"{alias}_{to_clause}" + right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" + return text(f"{left_part}={right_part}") + + def _build_join_parameters( + self, part: str, join_params: JoinParameters + ) -> JoinParameters: + model_cls = join_params.model_cls.__model_fields__[part].to + to_table = model_cls.__table__.name + + alias = model_cls._orm_relationship_manager.resolve_relation_join( + join_params.from_table, to_table + ) + if alias not in self.used_aliases: + if join_params.prev_model.__model_fields__[part].virtual: + to_key = next( + ( + v + for k, v in model_cls.__model_fields__.items() + if isinstance(v, ForeignKey) and v.to == join_params.prev_model + ), + None, + ).name + from_key = model_cls.__pkname__ + else: + to_key = model_cls.__pkname__ + from_key = part + + on_clause = self.on_clause( + previous_alias=join_params.previous_alias, + alias=alias, + from_clause=f"{join_params.from_table}.{from_key}", + to_clause=f"{to_table}.{to_key}", + ) + target_table = self.prefixed_table_name(alias, to_table) + self.select_from = sqlalchemy.sql.outerjoin( + self.select_from, target_table, on_clause + ) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.used_aliases.append(alias) + + previous_alias = alias + from_table = to_table + prev_model = model_cls + return JoinParameters(prev_model, previous_alias, from_table, model_cls) + + def _extract_auto_required_relations( + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, + ) -> None: + for field_name, field in prev_model.__model_fields__.items(): + if self._field_is_a_foreign_key_and_no_circular_reference( + field, field_name, rel_part + ): + rel_part = field_name if not rel_part else rel_part + "__" + field_name + if not field.nullable: + if rel_part not in self._select_related: + self.auto_related.append("__".join(rel_part.split("__")[:-1])) + rel_part = "" + elif self._field_qualifies_to_deeper_search( + field, parent_virtual, nested, rel_part + ): + self._extract_auto_required_relations( + prev_model=field.to, + rel_part=rel_part, + nested=True, + parent_virtual=field.virtual, + ) + else: + rel_part = "" + + def _include_auto_related_models(self) -> None: + if self.auto_related: + new_joins = [] + for join in self._select_related: + if not any([x.startswith(join) for x in self.auto_related]): + new_joins.append(join) + self._select_related = new_joins + self.auto_related + + def _apply_expression_modifiers( + self, expr: sqlalchemy.sql.select + ) -> sqlalchemy.sql.select: + if self.filter_clauses: + if len(self.filter_clauses) == 1: + clause = self.filter_clauses[0] + else: + clause = sqlalchemy.sql.and_(*self.filter_clauses) + expr = expr.where(clause) + + if self.limit_count: + expr = expr.limit(self.limit_count) + + if self.query_offset: + expr = expr.offset(self.query_offset) + + for order in self.order_bys: + expr = expr.order_by(order) + return expr + + def _reset_query_parameters(self) -> None: + self.select_from = None + self.columns = None + self.order_bys = None + self.auto_related = [] + self.used_aliases = [] diff --git a/orm/queryset/queryset.py b/orm/queryset/queryset.py new file mode 100644 index 0000000..a61a5fa --- /dev/null +++ b/orm/queryset/queryset.py @@ -0,0 +1,175 @@ +from typing import Type, List, Any, Union, Tuple, TYPE_CHECKING + +import databases +import sqlalchemy + +import orm # noqa I100 +from orm import NoMatch, MultipleMatches +from orm.queryset.clause import QueryClause +from orm.queryset.query import Query + +if TYPE_CHECKING: # pragma no cover + from orm import Model + + +class QuerySet: + def __init__( + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, + ) -> None: + self.model_cls = model_cls + self.filter_clauses = [] if filter_clauses is None else filter_clauses + self._select_related = [] if select_related is None else select_related + self.limit_count = limit_count + self.query_offset = offset + self.order_bys = None + + def __get__(self, instance: "QuerySet", owner: Type["Model"]) -> "QuerySet": + return self.__class__(model_cls=owner) + + @property + def database(self) -> databases.Database: + return self.model_cls.__database__ + + @property + def table(self) -> sqlalchemy.Table: + return self.model_cls.__table__ + + def build_select_expression(self) -> sqlalchemy.sql.select: + qry = Query( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + offset=self.query_offset, + limit_count=self.limit_count, + ) + exp, self._select_related = qry.build_select_expression() + return exp + + def filter(self, **kwargs: Any) -> "QuerySet": # noqa: A003 + qryclause = QueryClause( + model_cls=self.model_cls, + select_related=self._select_related, + filter_clauses=self.filter_clauses, + ) + filter_clauses, select_related = qryclause.filter(**kwargs) + + return self.__class__( + model_cls=self.model_cls, + filter_clauses=filter_clauses, + select_related=select_related, + limit_count=self.limit_count, + offset=self.query_offset, + ) + + def select_related(self, related: Union[List, Tuple, str]) -> "QuerySet": + if not isinstance(related, (list, tuple)): + related = [related] + + related = list(self._select_related) + related + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=related, + limit_count=self.limit_count, + offset=self.query_offset, + ) + + async def exists(self) -> bool: + expr = self.build_select_expression() + expr = sqlalchemy.exists(expr).select() + return await self.database.fetch_val(expr) + + async def count(self) -> int: + expr = self.build_select_expression().alias("subquery_for_count") + expr = sqlalchemy.func.count().select().select_from(expr) + return await self.database.fetch_val(expr) + + def limit(self, limit_count: int) -> "QuerySet": + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=limit_count, + offset=self.query_offset, + ) + + def offset(self, offset: int) -> "QuerySet": + return self.__class__( + model_cls=self.model_cls, + filter_clauses=self.filter_clauses, + select_related=self._select_related, + limit_count=self.limit_count, + offset=offset, + ) + + async def first(self, **kwargs: Any) -> "Model": + if kwargs: + return await self.filter(**kwargs).first() + + rows = await self.limit(1).all() + if rows: + return rows[0] + + async def get(self, **kwargs: Any) -> "Model": + if kwargs: + return await self.filter(**kwargs).get() + + expr = self.build_select_expression().limit(2) + rows = await self.database.fetch_all(expr) + + if not rows: + raise NoMatch() + if len(rows) > 1: + raise MultipleMatches() + return self.model_cls.from_row(rows[0], select_related=self._select_related) + + async def all(self, **kwargs: Any) -> List["Model"]: # noqa: A003 + if kwargs: + return await self.filter(**kwargs).all() + + expr = self.build_select_expression() + rows = await self.database.fetch_all(expr) + result_rows = [ + self.model_cls.from_row(row, select_related=self._select_related) + for row in rows + ] + + result_rows = self.model_cls.merge_instances_list(result_rows) + + return result_rows + + async def create(self, **kwargs: Any) -> "Model": + + new_kwargs = dict(**kwargs) + + # Remove primary key when None to prevent not null constraint in postgresql. + pkname = self.model_cls.__pkname__ + pk = self.model_cls.__model_fields__[pkname] + if ( + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) + ): + del new_kwargs[pkname] + + # substitute related models with their pk + for field in self.model_cls._extract_related_names(): + if field in new_kwargs and new_kwargs.get(field) is not None: + new_kwargs[field] = getattr( + new_kwargs.get(field), + self.model_cls.__model_fields__[field].to.__pkname__, + ) + + # Build the insert expression. + expr = self.table.insert() + expr = expr.values(**new_kwargs) + + # Execute the insert, and return a new model instance. + instance = self.model_cls(**kwargs) + instance.pk = await self.database.execute(expr) + return instance From 4aadc9fac62fa368505c011d65ba7b427b4fc2ed Mon Sep 17 00:00:00 2001 From: collerek Date: Tue, 11 Aug 2020 19:54:54 +0200 Subject: [PATCH 49/62] clean code --- .coverage | Bin 53248 -> 53248 bytes orm/models/__init__.py | 5 ++-- orm/models/fakepydantic.py | 28 ++++++++++++------ orm/models/metaclass.py | 9 ++++-- orm/models/model.py | 10 +++---- orm/queryset/__init__.py | 3 ++ orm/queryset/clause.py | 26 ++++++++--------- orm/queryset/query.py | 58 ++++++++++++++++++------------------- orm/queryset/queryset.py | 22 +++++++------- 9 files changed, 88 insertions(+), 73 deletions(-) diff --git a/.coverage b/.coverage index 2007c584c86b15bdaf4a5b56e47ec8e955810cb9..5fd9217ba91707e3a40b601b0d05cdd5b02396a6 100644 GIT binary patch delta 435 zcmZozz}&Ead4rigiJ?HFZ7WT%*{_p%_)unQ+frJn!J+oc(faf#7d<^QYxZ&MZH0gCGJlugh7`7!!m@BRNV^UmhIyDcKe4HUQI tzE>^C{Gs;k;@7{^Za?{V_wTj0fBC+}?Z5kvlZ^|gfRC$eb4!1X0|3RpoMZq1 delta 439 zcmZozz}&Ead4rigi?NlV&MQUn_7|xG{3kQXuRO$_pv4{0{YC!Kxs{Opa{PKb1IOn z!UcB%-(7N@jc=@ zwOOm7i!W3}h=q|;OQ`PXpWL1IU;d3R|6lcgo2r-~P*hK_Yzy6(e`^mq%f3Lm$`z>z&-G7`){6Kj({>i)hD+T7A#;{X7S C0F`C{ diff --git a/orm/models/__init__.py b/orm/models/__init__.py index 00fac17..b948f31 100644 --- a/orm/models/__init__.py +++ b/orm/models/__init__.py @@ -1,5 +1,4 @@ +from orm.models.fakepydantic import FakePydantic from orm.models.model import Model -__all__ = [ - "Model" -] +__all__ = ["FakePydantic", "Model"] diff --git a/orm/models/fakepydantic.py b/orm/models/fakepydantic.py index bbced86..55fe51d 100644 --- a/orm/models/fakepydantic.py +++ b/orm/models/fakepydantic.py @@ -1,19 +1,29 @@ import inspect import json import uuid -from typing import TYPE_CHECKING, Dict, TypeVar, Type, Any, Optional, Callable, Set, List +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Set, + TYPE_CHECKING, + Type, + TypeVar, +) import databases import pydantic import sqlalchemy from pydantic import BaseModel -import orm +import orm # noqa I100 from orm.fields import BaseField from orm.models.metaclass import ModelMetaclass from orm.relations import RelationshipManager -if TYPE_CHECKING: #pragma no cover +if TYPE_CHECKING: # pragma no cover from orm.models.model import Model @@ -74,9 +84,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): item = getattr(self.values, key, None) if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) + item is not None + and self._is_conversion_to_json_needed(key) + and isinstance(item, str) ): try: item = json.loads(item) @@ -92,7 +102,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.values is not None and other.values is not None and self.pk == other.pk ) def __repr__(self) -> str: # pragma no cover @@ -148,7 +158,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): related_names = set() for name, field in cls.__fields__.items(): if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + field.type_, pydantic.BaseModel ): related_names.add(name) return related_names @@ -180,7 +190,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): for field in one.__model_fields__.keys(): # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), orm.Model + getattr(one, field), orm.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), orm.Model): diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py index b3546bb..4ddb21e 100644 --- a/orm/models/metaclass.py +++ b/orm/models/metaclass.py @@ -1,14 +1,17 @@ import copy -from typing import Dict, Tuple, Type, Optional, List, Any +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type import sqlalchemy from pydantic import BaseConfig, create_model from pydantic.fields import ModelField -from orm import ForeignKey, ModelDefinitionError +from orm import ForeignKey, ModelDefinitionError # noqa I100 from orm.fields import BaseField from orm.relations import RelationshipManager +if TYPE_CHECKING: # pragma no cover + from orm import Model + relationship_manager = RelationshipManager() @@ -129,4 +132,4 @@ class ModelMetaclass(type): expand_reverse_relationships(new_model) - return new_model \ No newline at end of file + return new_model diff --git a/orm/models/model.py b/orm/models/model.py index d949d95..ab21bfc 100644 --- a/orm/models/model.py +++ b/orm/models/model.py @@ -1,15 +1,15 @@ -from typing import List, Any +from typing import Any, List import sqlalchemy -import orm.queryset.queryset -from orm.models.fakepydantic import FakePydantic +import orm.queryset # noqa I100 +from orm.models import FakePydantic # noqa I100 class Model(FakePydantic): __abstract__ = True - objects = orm.queryset.queryset.QuerySet() + objects = orm.queryset.QuerySet() @classmethod def from_row( @@ -90,4 +90,4 @@ class Model(FakePydantic): expr = self.__table__.select().where(self.pk_column == self.pk) row = await self.__database__.fetch_one(expr) self.from_dict(dict(row)) - return self \ No newline at end of file + return self diff --git a/orm/queryset/__init__.py b/orm/queryset/__init__.py index e69de29..30e112a 100644 --- a/orm/queryset/__init__.py +++ b/orm/queryset/__init__.py @@ -0,0 +1,3 @@ +from orm.queryset.queryset import QuerySet + +__all__ = ["QuerySet"] diff --git a/orm/queryset/clause.py b/orm/queryset/clause.py index 0cc2309..f587da6 100644 --- a/orm/queryset/clause.py +++ b/orm/queryset/clause.py @@ -1,9 +1,9 @@ -from typing import Type, List, Any, Tuple, Dict, Union, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import sqlalchemy from sqlalchemy import text -import orm +import orm # noqa I100 from orm.exceptions import QueryDefinitionError if TYPE_CHECKING: # pragma no cover @@ -25,7 +25,7 @@ ESCAPE_CHARACTERS = ["%", "_"] class QueryClause: def __init__( - self, model_cls: Type["Model"], filter_clauses: List, select_related: List, + self, model_cls: Type["Model"], filter_clauses: List, select_related: List, ) -> None: self._select_related = select_related @@ -35,7 +35,7 @@ class QueryClause: self.table = self.model_cls.__table__ def filter( # noqa: A003 - self, **kwargs: Any + self, **kwargs: Any ) -> Tuple[List[sqlalchemy.sql.expression.TextClause], List[str]]: filter_clauses = self.filter_clauses select_related = list(self._select_related) @@ -92,7 +92,7 @@ class QueryClause: return filter_clauses, select_related def _determine_filter_target_table( - self, related_parts: List[str], select_related: List[str] + self, related_parts: List[str], select_related: List[str] ) -> Tuple[List[str], str, "Model"]: table_prefix = "" @@ -116,12 +116,12 @@ class QueryClause: return select_related, table_prefix, model_cls def _compile_clause( - self, - clause: sqlalchemy.sql.expression.BinaryExpression, - column: sqlalchemy.Column, - table: sqlalchemy.Table, - table_prefix: str, - modifiers: Dict, + self, + clause: sqlalchemy.sql.expression.BinaryExpression, + column: sqlalchemy.Column, + table: sqlalchemy.Table, + table_prefix: str, + modifiers: Dict, ) -> sqlalchemy.sql.expression.TextClause: for modifier, modifier_value in modifiers.items(): clause.modifiers[modifier] = modifier_value @@ -140,7 +140,7 @@ class QueryClause: @staticmethod def _escape_characters_in_clause( - op: str, value: Union[str, "Model"] + op: str, value: Union[str, "Model"] ) -> Tuple[str, bool]: has_escaped_character = False @@ -162,7 +162,7 @@ class QueryClause: @staticmethod def _extract_operator_field_and_related( - parts: List[str], + parts: List[str], ) -> Tuple[str, str, Optional[List]]: if parts[-1] in FILTER_OPERATORS: op = parts[-1] diff --git a/orm/queryset/query.py b/orm/queryset/query.py index 22b2db1..9592561 100644 --- a/orm/queryset/query.py +++ b/orm/queryset/query.py @@ -1,9 +1,9 @@ -from typing import NamedTuple, Type, List, Tuple, TYPE_CHECKING +from typing import List, NamedTuple, TYPE_CHECKING, Tuple, Type import sqlalchemy from sqlalchemy import text -import orm +import orm # noqa I100 from orm import ForeignKey from orm.fields import BaseField @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -50,12 +50,12 @@ class Query: for key in self.model_cls.__model_fields__: if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], - orm.fields.foreign_key.ForeignKey, - ) - and key not in self._select_related + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], + orm.fields.foreign_key.ForeignKey, + ) + and key not in self._select_related ): self._select_related = [key] + self._select_related @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: BaseField, field_name: str, rel_part: str ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -110,19 +110,19 @@ class Query: ) already_checked = any([x.startswith(rel_part) for x in self.auto_related]) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to to_table = model_cls.__table__.name @@ -165,15 +165,15 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: for field_name, field in prev_model.__model_fields__.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: @@ -181,7 +181,7 @@ class Query: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): self._extract_auto_required_relations( prev_model=field.to, @@ -201,7 +201,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/orm/queryset/queryset.py b/orm/queryset/queryset.py index a61a5fa..5cbc3cd 100644 --- a/orm/queryset/queryset.py +++ b/orm/queryset/queryset.py @@ -1,10 +1,10 @@ -from typing import Type, List, Any, Union, Tuple, TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING, Tuple, Type, Union import databases import sqlalchemy import orm # noqa I100 -from orm import NoMatch, MultipleMatches +from orm import MultipleMatches, NoMatch from orm.queryset.clause import QueryClause from orm.queryset.query import Query @@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -151,9 +151,9 @@ class QuerySet: pkname = self.model_cls.__pkname__ pk = self.model_cls.__model_fields__[pkname] if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) ): del new_kwargs[pkname] From dd20fd9f017e6d79779bf1f7067c31955744b114 Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 12 Aug 2020 10:33:37 +0200 Subject: [PATCH 50/62] refactors in metaclass --- .coverage | Bin 53248 -> 53248 bytes orm/models/fakepydantic.py | 40 ++++++++++++++++++++----------------- orm/models/metaclass.py | 26 +++++++++++++----------- orm/relations.py | 3 +-- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/.coverage b/.coverage index 5fd9217ba91707e3a40b601b0d05cdd5b02396a6..367e452b938b04ad2398c4ae448b921f91a0197e 100644 GIT binary patch delta 173 zcmV;e08;;epaX!Q1F$eJ3Nj!uFgi0bIx;u2IxoynG9LgB`48$3-Vf6c$Pcv-p%0G_ zgb#cVS`R@FD-ReC5DxPW)ef-^sScYCat>GyHVz-N5fJAMlOm4z0Sc2Fk04C>{;#`# zy?@&h7Yqae2`~(Hmwd~Q-|}z#_m|7F+0A{M0w4 None: if key in self.__fields__: - if self._is_conversion_to_json_needed(key) and not isinstance(value, str): - try: - value = json.dumps(value) - except TypeError: # pragma no cover - pass - + value = self._convert_json(key, value, op="dumps") value = self.__model_fields__[key].expand_relationship(value, self) - relation_key = self.__class__.__name__.title() + "_" + key + relation_key = self.get_name(title=True) + "_" + key if not self._orm_relationship_manager.contains(relation_key, self): setattr(self.values, key, value) else: @@ -78,20 +74,12 @@ class FakePydantic(list, metaclass=ModelMetaclass): def __getattribute__(self, key: str) -> Any: if key != "__fields__" and key in self.__fields__: - relation_key = self.__class__.__name__.title() + "_" + key + relation_key = self.get_name(title=True) + "_" + key if self._orm_relationship_manager.contains(relation_key, self): return self._orm_relationship_manager.get(relation_key, self) item = getattr(self.values, key, None) - if ( - item is not None - and self._is_conversion_to_json_needed(key) - and isinstance(item, str) - ): - try: - item = json.loads(item) - except TypeError: # pragma no cover - pass + item = self._convert_json(key, item, op="loads") return item return super().__getattribute__(key) @@ -145,6 +133,23 @@ class FakePydantic(list, metaclass=ModelMetaclass): for key, value in value_dict.items(): setattr(self, key, value) + def _convert_json(self, column_name: str, value: Any, op: str) -> Union[str, dict]: + + if not self._is_conversion_to_json_needed(column_name): + return value + + condition = ( + isinstance(value, str) if op == "loads" else not isinstance(value, str) + ) + operand = json.loads if op == "loads" else json.dumps + + if condition: + try: + return operand(value) + except TypeError: # pragma no cover + pass + return value + def _is_conversion_to_json_needed(self, column_name: str) -> bool: return self.__model_fields__.get(column_name).__type__ == pydantic.Json @@ -188,7 +193,6 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.__model_fields__.keys(): - # print(field, one.dict(), other.dict()) if isinstance(getattr(one, field), list) and not isinstance( getattr(one, field), orm.Model ): diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py index 4ddb21e..d10806b 100644 --- a/orm/models/metaclass.py +++ b/orm/models/metaclass.py @@ -66,19 +66,21 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: - pkname: Optional[str] = None - columns: List[sqlalchemy.Column] = [] - model_fields: Dict[str, BaseField] = {} + columns = [] + pkname = None + model_fields = { + field_name: field + for field_name, field in object_dict.items() + if isinstance(field, BaseField) + } + for field_name, field in model_fields.items(): + if field.primary_key: + pkname = field_name + if not field.pydantic_only: + columns.append(field.get_column(field_name)) + if isinstance(field, ForeignKey): + register_relation_on_build(table_name, field, name) - for field_name, field in object_dict.items(): - if isinstance(field, BaseField): - model_fields[field_name] = field - if not field.pydantic_only: - if field.primary_key: - pkname = field_name - if isinstance(field, ForeignKey): - register_relation_on_build(table_name, field, name) - columns.append(field.get_column(field_name)) return pkname, columns, model_fields diff --git a/orm/relations.py b/orm/relations.py index 7cf5ecf..c7ef8b6 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -49,9 +49,8 @@ class RelationshipManager: ) def deregister(self, model: "FakePydantic") -> None: - # print(f'deregistering {model.__class__.__name__}, {model._orm_id}') for rel_type in self._relations.keys(): - if model.__class__.__name__.lower() in rel_type.lower(): + if model.get_name() in rel_type.lower(): if model._orm_id in self._relations[rel_type]: del self._relations[rel_type][model._orm_id] From 24eb0b30e725f957498d17bb3813691b339d12cf Mon Sep 17 00:00:00 2001 From: collerek Date: Wed, 12 Aug 2020 16:24:45 +0200 Subject: [PATCH 51/62] introduce docs -> models section mostly finished --- .coverage | Bin 53248 -> 53248 bytes docs/fastapi.md | 0 docs/fields.md | 0 docs/index.md | 17 ++++ docs/models.md | 171 +++++++++++++++++++++++++++++++++ docs/pydantic.md | 0 docs/queries.md | 0 docs/relations.md | 0 docs_src/models/docs001.py | 16 +++ docs_src/models/docs002.py | 19 ++++ docs_src/models/docs003.py | 33 +++++++ docs_src/models/docs004.py | 22 +++++ docs_src/models/docs005.py | 51 ++++++++++ docs_src/models/docs006.py | 41 ++++++++ docs_src/models/docs007.py | 22 +++++ mkdocs.yml | 29 ++++++ orm/fields/base.py | 19 ++-- orm/fields/decorators.py | 4 +- orm/models/metaclass.py | 5 +- orm/queryset/clause.py | 24 ++--- orm/queryset/query.py | 3 +- orm/relations.py | 51 +++------- tests/test_model_definition.py | 33 +++---- 23 files changed, 475 insertions(+), 85 deletions(-) create mode 100644 docs/fastapi.md create mode 100644 docs/fields.md create mode 100644 docs/index.md create mode 100644 docs/models.md create mode 100644 docs/pydantic.md create mode 100644 docs/queries.md create mode 100644 docs/relations.md create mode 100644 docs_src/models/docs001.py create mode 100644 docs_src/models/docs002.py create mode 100644 docs_src/models/docs003.py create mode 100644 docs_src/models/docs004.py create mode 100644 docs_src/models/docs005.py create mode 100644 docs_src/models/docs006.py create mode 100644 docs_src/models/docs007.py create mode 100644 mkdocs.yml diff --git a/.coverage b/.coverage index 367e452b938b04ad2398c4ae448b921f91a0197e..a41bf1cbb0cf75a863d6c855a07c6e51169c3201 100644 GIT binary patch delta 239 zcmVO>JQ!z(+|iGwhy8YkPn3qeGgj? zLJuzw91j!@_YT<(wGOQgpAL2oT@E`AB(o6^>O>JQ!z(+|iGwGW{Wj}L?od=FX= zK@Ter7!ME*^A6Pxu@0#Yn+|dgSPnK0AF~k<=M9r2jyW0_76btaG8Vea|Ih#Uob%u2 z<~e)z?Af#D50iF|O*GZ+Y^l$Ezwh_$ulC)&{{4OW`)#`)91#Qo2|5wl^1r<8yZ8Mw z*Z$R=<6r-t?|bU{cL9_4jyNwA4FmxRDh+A diff --git a/docs/fastapi.md b/docs/fastapi.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/fields.md b/docs/fields.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..000ea34 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,17 @@ +# Welcome to MkDocs + +For full documentation visit [mkdocs.org](https://www.mkdocs.org). + +## Commands + +* `mkdocs new [dir-name]` - Create a new project. +* `mkdocs serve` - Start the live-reloading docs server. +* `mkdocs build` - Build the documentation site. +* `mkdocs -h` - Print help message and exit. + +## Project layout + + mkdocs.yml # The configuration file. + docs/ + index.md # The documentation homepage. + ... # Other markdown pages, images and other files. diff --git a/docs/models.md b/docs/models.md new file mode 100644 index 0000000..e4f5a13 --- /dev/null +++ b/docs/models.md @@ -0,0 +1,171 @@ +# Models + +## Defining models +By defining an orm Model you get corresponding **Pydantic model** as well as **Sqlalchemy table** for free. +They are being managed in the background and you do not have to create them on your own. + +### Model Class +To build an ORM model you simply need to inherit a `orm.Model` class. + +```Python hl_lines="10" +--8<-- "../docs_src/models/docs001.py" +``` + +### Defining Fields +Next assign one or more of the [Fields][fields] as a class level variables. + +Each table **has to** have a primary key column, which you specify by setting `primary_key=True` on selected field. + +Only one primary key column is allowed. + +```Python hl_lines="14 15 16" +--8<-- "../docs_src/models/docs001.py" +``` + +!!! warning + Not assigning `primary_key` column or assigning more than one column per `Model` will raise `ModelDefinitionError` + exception. + +By default if you assign primary key to `Integer` field, the `autoincrement` option is set to true. + +You can disable by passing `autoincremant=False`. + +```Python +id = orm.Integer(primary_key=True, autoincrement=False) +``` + +Names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table. + +### Dependencies + +Since orm depends on [`databases`][databases] and [`sqlalchemy-core`][sqlalchemy-core] for database connection +and table creation you need to assign each `Model` with two special parameters. + +#### Databases +One is `Database` instance created with your database url in [sqlalchemy connection string][sqlalchemy connection string] format. + +Created instance needs to be passed to every `Model` with `__database__` parameter. + +```Python hl_lines="1 6 11" +--8<-- "../docs_src/models/docs001.py" +``` + +!!! tip + You need to create the `Database` instance **only once** and use it for all models. + You can create several ones if you want to use multiple databases. + +#### Sqlalchemy +Second dependency is sqlalchemy `MetaData` instance. + +Created instance needs to be passed to every `Model` with `__metadata__` parameter. + +```Python hl_lines="2 7 12" +--8<-- "../docs_src/models/docs001.py" +``` + +!!! tip + You need to create the `MetaData` instance **only once** and use it for all models. + You can create several ones if you want to use multiple databases. + +### Table Names + +By default table name is created from Model class name as lowercase name plus 's'. + +You can overwrite this parameter by providing `__tablename__` argument. + +```Python hl_lines="11 12 13" +--8<-- "../docs_src/models/docs002.py" +``` + +## Initialization + +There are two ways to create and persist the `Model` instance in the database. + +!!!tip + Use `ipython` to try this from the console, since it supports `await`. + +If you plan to modify the instance in the later execution of your program you can initiate your `Model` as a normal class and later await a `save()` call. + +```Python hl_lines="19 20" +--8<-- "../docs_src/models/docs007.py" +``` + +If you want to initiate your `Model` and at the same time save in in the database use a QuerySet's method `create()`. + +Each model has a `QuerySet` initialised as `objects` parameter + +```Python hl_lines="22" +--8<-- "../docs_src/models/docs007.py" +``` + +!!!info + To read more about `QuerySets` and available methods visit [queries][queries] + +## Attributes Delegation + +Each call to `Model` fields parameter under the hood is delegated to either the `pydantic` model +or other related `Model` in case of relations. + +The fields and relations are not stored on the `Model` itself + +```Python hl_lines="31 32 33 34 35 36 37 38 39 40 41" +--8<-- "../docs_src/models/docs006.py" +``` + +!!! warning + In example above model instances are created but not persisted that's why `id` of `department` is None! + +!!!info + To read more about `ForeignKeys` and `Model` relations visit [relations][relations] + +## Internals + +Apart from special parameters defined in the `Model` during definition (tablename, metadata etc.) the `Model` provides you with useful internals. + +### Pydantic Model +To access auto created pydantic model you can use `Model.__pydantic_model__` parameter + +For example to list model fields you can: + +```Python hl_lines="18" +--8<-- "../docs_src/models/docs003.py" +``` + +!!!tip + Note how the primary key `id` field is optional as `Integer` primary key by default has `autoincrement` set to `True`. + +!!!info + For more options visit official [pydantic][pydantic] documentation. + +### Sqlalchemy Table +To access auto created sqlalchemy table you can use `Model.__table__` parameter + +For example to list table columns you can: + +```Python hl_lines="18" +--8<-- "../docs_src/models/docs004.py" +``` + +!!!tip + You can access table primary key name by `Course.__pkname__` + +!!!info + For more options visit official [sqlalchemy-metadata][sqlalchemy-metadata] documentation. + +### Fields Definition +To access orm `Fields` you can use `Model.__model_fields__` parameter + +For example to list table model fields you can: + +```Python hl_lines="18" +--8<-- "../docs_src/models/docs005.py" +``` + +[fields]: ./fields.md +[relations]: ./relations.md +[queries]: ./queries.md +[pydantic]: https://pydantic-docs.helpmanual.io/ +[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ +[sqlalchemy-metadata]: https://docs.sqlalchemy.org/en/13/core/metadata.html +[databases]: https://github.com/encode/databases +[sqlalchemy connection string]: https://docs.sqlalchemy.org/en/13/core/engines.html#database-urls \ No newline at end of file diff --git a/docs/pydantic.md b/docs/pydantic.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/queries.md b/docs/queries.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/relations.md b/docs/relations.md new file mode 100644 index 0000000..e69de29 diff --git a/docs_src/models/docs001.py b/docs_src/models/docs001.py new file mode 100644 index 0000000..6aba0ed --- /dev/null +++ b/docs_src/models/docs001.py @@ -0,0 +1,16 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) diff --git a/docs_src/models/docs002.py b/docs_src/models/docs002.py new file mode 100644 index 0000000..0886dcc --- /dev/null +++ b/docs_src/models/docs002.py @@ -0,0 +1,19 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(orm.Model): + # if you omit this parameter it will be created automatically + # as class.__name__.lower()+'s' -> "courses" in this example + __tablename__ = "my_courses" + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) diff --git a/docs_src/models/docs003.py b/docs_src/models/docs003.py new file mode 100644 index 0000000..bd108d4 --- /dev/null +++ b/docs_src/models/docs003.py @@ -0,0 +1,33 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + +print(Course.__pydantic_model__.__fields__) +""" +Will produce: +{'completed': ModelField(name='completed', + type=bool, + required=False, + default=False), + 'id': ModelField(name='id', + type=Optional[int], + required=False, + default=None), + 'name': ModelField(name='name', + type=Optional[str], + required=False, + default=None)} +""" diff --git a/docs_src/models/docs004.py b/docs_src/models/docs004.py new file mode 100644 index 0000000..f2f06ba --- /dev/null +++ b/docs_src/models/docs004.py @@ -0,0 +1,22 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + +print(Course.__table__.columns) +""" +Will produce: +['courses.id', 'courses.name', 'courses.completed'] +""" diff --git a/docs_src/models/docs005.py b/docs_src/models/docs005.py new file mode 100644 index 0000000..cf33c9d --- /dev/null +++ b/docs_src/models/docs005.py @@ -0,0 +1,51 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + +print(Course.__model_fields__) +""" +Will produce: +{ +'id': {'name': 'id', + 'primary_key': True, + 'autoincrement': True, + 'nullable': False, + 'default': None, + 'server_default': None, + 'index': None, + 'unique': None, + 'pydantic_only': False}, +'name': {'name': 'name', + 'primary_key': False, + 'autoincrement': False, + 'nullable': True, + 'default': None, + 'server_default': None, + 'index': None, + 'unique': None, + 'pydantic_only': False, + 'length': 100}, +'completed': {'name': 'completed', + 'primary_key': False, + 'autoincrement': False, + 'nullable': True, + 'default': False, + 'server_default': None, + 'index': None, + 'unique': None, + 'pydantic_only': False} +} +""" diff --git a/docs_src/models/docs006.py b/docs_src/models/docs006.py new file mode 100644 index 0000000..fc0fbef --- /dev/null +++ b/docs_src/models/docs006.py @@ -0,0 +1,41 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Department(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + department = orm.ForeignKey(Department) + + +department = Department(name="Science") +course = Course(name="Math", completed=False, department=department) + +print('name' in course.__dict__) +# False <- property name is not stored on Course instance +print(course.name) +# Math <- value returned from underlying pydantic model +print('department' in course.__dict__) +# False <- related model is not stored on Course instance +print(course.department) +# Department(id=None, name='Science') <- Department model +# returned from RelationshipManager +print(course.department.name) +# Science \ No newline at end of file diff --git a/docs_src/models/docs007.py b/docs_src/models/docs007.py new file mode 100644 index 0000000..2bd7af1 --- /dev/null +++ b/docs_src/models/docs007.py @@ -0,0 +1,22 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + + +course = Course(name="Painting for dummies", completed=False) +await course.save() + +await Course.objects.create(name="Painting for dummies", completed=False) diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..d63f952 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,29 @@ +site_name: Async ORM +nav: + - Home: index.md + - Models: models.md + - Fields: fields.md + - Relations: relations.md + - Queries: queries.md + - Pydantic models: pydantic.md + - Use with Fastapi: fastapi.md +theme: + name: material + highlightjs: true + hljs_languages: + - python + palette: + primary: indigo +markdown_extensions: + - admonition + - pymdownx.superfences + - pymdownx.snippets: + base_path: docs + - pymdownx.inlinehilite + - pymdownx.highlight: + linenums: true +extra_javascript: + - https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js + - javascripts/config.js +extra_css: + - https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/default.min.css \ No newline at end of file diff --git a/orm/fields/base.py b/orm/fields/base.py index d405679..2a0feb5 100644 --- a/orm/fields/base.py +++ b/orm/fields/base.py @@ -11,18 +11,8 @@ if TYPE_CHECKING: # pragma no cover class BaseField: __type__ = None - def __init__(self, *args: Any, **kwargs: Any) -> None: - name = kwargs.pop("name", None) - args = list(args) - if args: - if isinstance(args[0], str): - if name is not None: - raise ModelDefinitionError( - "Column name cannot be passed positionally and as a keyword." - ) - name = args.pop(0) - - self.name = name + def __init__(self, **kwargs: Any) -> None: + self.name = None self._populate_from_kwargs(kwargs) def _populate_from_kwargs(self, kwargs: Dict) -> None: @@ -64,7 +54,7 @@ class BaseField: return False def get_column(self, name: str = None) -> sqlalchemy.Column: - self.name = self.name or name + self.name = name constraints = self.get_constraints() return sqlalchemy.Column( self.name, @@ -87,3 +77,6 @@ class BaseField: def expand_relationship(self, value: Any, child: "Model") -> Any: return value + + def __repr__(self): # pragma no cover + return str(self.__dict__) diff --git a/orm/fields/decorators.py b/orm/fields/decorators.py index 4deb597..ae4e498 100644 --- a/orm/fields/decorators.py +++ b/orm/fields/decorators.py @@ -14,8 +14,8 @@ class RequiredParams: old_init = model_field_class.__init__ model_field_class._old_init = old_init - def __init__(instance: "BaseField", *args: Any, **kwargs: Any) -> None: - super(instance.__class__, instance).__init__(*args, **kwargs) + def __init__(instance: "BaseField", **kwargs: Any) -> None: + super(instance.__class__, instance).__init__(**kwargs) for arg in self._required: if arg not in kwargs: raise ModelDefinitionError( diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py index d10806b..37f1108 100644 --- a/orm/models/metaclass.py +++ b/orm/models/metaclass.py @@ -75,6 +75,8 @@ def sqlalchemy_columns_from_model_fields( } for field_name, field in model_fields.items(): if field.primary_key: + if pkname is not None: + raise ModelDefinitionError("Only one primary key column is allowed.") pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) @@ -100,7 +102,8 @@ class ModelMetaclass(type): if attrs.get("__abstract__"): return new_model - tablename = attrs["__tablename__"] + tablename = attrs.get("__tablename__", name.lower() + "s") + attrs["__tablename__"] = tablename metadata = attrs["__metadata__"] # sqlalchemy table creation diff --git a/orm/queryset/clause.py b/orm/queryset/clause.py index f587da6..e70aaa1 100644 --- a/orm/queryset/clause.py +++ b/orm/queryset/clause.py @@ -144,19 +144,21 @@ class QueryClause: ) -> Tuple[str, bool]: has_escaped_character = False - if op in ["contains", "icontains"]: - if isinstance(value, orm.Model): - raise QueryDefinitionError( - "You cannot use contains and icontains with instance of the Model" - ) + if op not in ["contains", "icontains"]: + return value, has_escaped_character - has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + if isinstance(value, orm.Model): + raise QueryDefinitionError( + "You cannot use contains and icontains with instance of the Model" + ) - if has_escaped_character: - # enable escape modifier - for char in ESCAPE_CHARACTERS: - value = value.replace(char, f"\\{char}") - value = f"%{value}%" + has_escaped_character = any(c for c in ESCAPE_CHARACTERS if c in value) + + if has_escaped_character: + # enable escape modifier + for char in ESCAPE_CHARACTERS: + value = value.replace(char, f"\\{char}") + value = f"%{value}%" return value, has_escaped_character diff --git a/orm/queryset/query.py b/orm/queryset/query.py index 9592561..133d77f 100644 --- a/orm/queryset/query.py +++ b/orm/queryset/query.py @@ -52,8 +52,7 @@ class Query: if ( not self.model_cls.__model_fields__[key].nullable and isinstance( - self.model_cls.__model_fields__[key], - orm.fields.foreign_key.ForeignKey, + self.model_cls.__model_fields__[key], orm.fields.ForeignKey, ) and key not in self._select_related ): diff --git a/orm/relations.py b/orm/relations.py index c7ef8b6..269d027 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -2,7 +2,7 @@ import pprint import string import uuid from random import choices -from typing import Dict, List, TYPE_CHECKING, Union +from typing import List, TYPE_CHECKING, Union from weakref import proxy from orm import ForeignKey @@ -15,38 +15,20 @@ def get_table_alias() -> str: return "".join(choices(string.ascii_uppercase, k=2)) + uuid.uuid4().hex[:4] -def get_relation_config( - relation_type: str, table_name: str, field: ForeignKey -) -> Dict[str, str]: - alias = get_table_alias() - config = { - "type": relation_type, - "table_alias": alias, - "source_table": table_name - if relation_type == "primary" - else field.to.__tablename__, - "target_table": field.to.__tablename__ - if relation_type == "primary" - else table_name, - } - return config - - class RelationshipManager: def __init__(self) -> None: self._relations = dict() + self._aliases = dict() def add_relation_type( self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str ) -> None: if relations_key not in self._relations: - self._relations[relations_key] = get_relation_config( - "primary", table_name, field - ) + self._relations[relations_key] = {"type": "primary"} + self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() if reverse_key not in self._relations: - self._relations[reverse_key] = get_relation_config( - "reverse", table_name, field - ) + self._relations[reverse_key] = {"type": "reverse"} + self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() def deregister(self, model: "FakePydantic") -> None: for rel_type in self._relations.keys(): @@ -57,10 +39,11 @@ class RelationshipManager: def add_relation( self, parent: "FakePydantic", child: "FakePydantic", virtual: bool = False, ) -> None: - parent_id = parent._orm_id - child_id = child._orm_id - parent_name = parent.get_name() - child_name = child.get_name() + parent_id, child_id = parent._orm_id, child._orm_id + parent_name, child_name = ( + parent.get_name(title=True), + child.get_name(title=True), + ) if virtual: child_name, parent_name = parent_name, child_name child_id, parent_id = parent_id, child_id @@ -68,11 +51,11 @@ class RelationshipManager: else: child = proxy(child) - parent_relation_name = parent_name.lower().title() + "_" + child_name + "s" + parent_relation_name = parent_name + "_" + child_name.lower() + "s" parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) self.append_related_model(parents_list, child) - child_relation_name = child_name.lower().title() + "_" + parent_name + child_relation_name = child_name + "_" + parent_name.lower() children_list = self._relations[child_relation_name].setdefault(child_id, []) self.append_related_model(children_list, parent) @@ -102,13 +85,7 @@ class RelationshipManager: return self._relations[relations_key][instance._orm_id] def resolve_relation_join(self, from_table: str, to_table: str) -> str: - for relation_name, relation in self._relations.items(): - if ( - relation["source_table"] == from_table - and relation["target_table"] == to_table - ): - return self._relations[relation_name]["table_alias"] - return "" + return self._aliases.get(f"{from_table}_{to_table}", "") def __str__(self) -> str: # pragma no cover return pprint.pformat(self._relations, indent=4, width=1) diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index e7f8e0b..8ea9289 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -43,8 +43,8 @@ fields_to_check = [ class ExampleModel2(Model): __tablename__ = "example2" __metadata__ = metadata - test = fields.Integer(name="test12", primary_key=True) - test_string = fields.String("test_string2", length=250) + test = fields.Integer(primary_key=True) + test_string = fields.String(length=250) @pytest.fixture() @@ -93,49 +93,44 @@ def test_sqlalchemy_table_is_created(example): assert all([field in example.__table__.columns for field in fields_to_check]) -def test_double_column_name_in_model_definition(): - with pytest.raises(ModelDefinitionError): - - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String("test_string2", name="test_string2", length=250) - - def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata - test_string = fields.String(name="test_string2", length=250) + test_string = fields.String(length=250) + +def test_two_pks_in_model_definition(): + with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): + __tablename__ = "example3" + __metadata__ = metadata + id = fields.Integer(primary_key=True) + test_string = fields.String(length=250, primary_key=True) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.Integer(name="test12", primary_key=True, pydantic_only=True) + test = fields.Integer(primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.Decimal(name="test12", primary_key=True) + test = fields.Decimal(primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata - test = fields.String(name="test12", primary_key=True) + test = fields.String(primary_key=True) def test_json_conversion_in_model(): From 8c7051b07eb4860e1756aaf7bb20ec5d6dfdd26e Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 13 Aug 2020 12:54:42 +0200 Subject: [PATCH 52/62] finish fields docs intial ver,add test for related name, fix child_name(s) in reverse relations --- .coverage | Bin 53248 -> 53248 bytes docs/fields.md | 206 +++++++++++++++++++++++++++++++++++ docs/index.md | 214 ++++++++++++++++++++++++++++++++++--- docs_src/fields/docs001.py | 36 +++++++ docs_src/fields/docs002.py | 36 +++++++ docs_src/fields/docs003.py | 41 +++++++ orm/fields/base.py | 4 +- orm/fields/foreign_key.py | 19 ++-- orm/models/metaclass.py | 12 +-- orm/relations.py | 19 ++-- tests/test_foreign_keys.py | 26 ++++- 11 files changed, 572 insertions(+), 41 deletions(-) create mode 100644 docs_src/fields/docs001.py create mode 100644 docs_src/fields/docs002.py create mode 100644 docs_src/fields/docs003.py diff --git a/.coverage b/.coverage index a41bf1cbb0cf75a863d6c855a07c6e51169c3201..3afbc70f18fa0551be087317173b260cdce545d2 100644 GIT binary patch delta 226 zcmV<803H8;paX!Q1F!}l3I_lWLJu(y6%Plq5fJwdlMpW;A~PT{GCD9eIx;p70s|Wt zc4cyNX>V>dE;24Lfja>N9fE6EpTLFv;Hp#P% by default False. + +Sets the primary key column on a table, foreign keys always refer to the pk of the `Model`. + +Used in sql only. + +### autoincrement + +`autoincrement`: `bool` = `primary_key and type == int` -> defaults to True if column is a primary key and of type Integer, otherwise False. + +Can be only used with int fields. + +If a field has autoincrement it becomes optional. + +Used only in sql. + +### nullable + +`nullable`: `bool` = `not primary_key` -> defaults to False for primary key column, and True for all other. + +Specifies if field is optional or required, used both with sql and pydantic. + +!!!note + By default all `ForeignKeys` are also nullable, meaning the related `Model` is not required. + + If you change the `ForeignKey` column to `nullable`, it not only becomes required, it changes also the way in which data is loaded in queries. + + If you select `Model` without explicitly adding related `Model` assigned by not nullable `ForeignKey`, the `Model` is still gona be appended automatically, see example below. + +```Python hl_lines="24 32 33 34 35 37 38 39 40 41" +--8<-- "../docs_src/fields/docs003.py" +``` + +!!!info + If you want to know more about how you can preload related models during queries and how the relations work read the [queries][queries] and [relations][relations] sections. + + +### default + +`default`: `Any` = `None` -> defaults to None. + +A default value used if no other value is passed. + +In sql invoked on an insert, used during pydantic model definition. + +If the field has a default value it becomes optional. + +You can pass a static value or a Callable (function etc.) + +Used both in sql and pydantic. + +### server default + +`server_default`: `Any` = `None` -> defaults to None. + +A default value used if no other value is passed. + +In sql invoked on the server side so you can pass i.e. sql function (like now() wrapped in sqlalchemy text() clause). + +If the field has a server_default value it becomes optional. + +You can pass a static value or a Callable (function etc.) + +Used in sql only. + +### index + +`index`: `bool` = `False` -> by default False, + +Sets the index on a table's column. + +Used in sql only. + +### unique + +`unique`: `bool` = `False` + +Sets the unique constraint on a table's column. + +Used in sql only. + +## Fields Types + +### String + +`String(length)` has a required `length` parameter. + +* Sqlalchemy column: `sqlalchemy.String` +* Type (used for pydantic): `str` + +### Text + +`Text()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Text` +* Type (used for pydantic): `str` + +### Boolean + +`Boolean()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Boolean` +* Type (used for pydantic): `bool` + +### Integer + +`Integer()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Integer` +* Type (used for pydantic): `int` + +### BigInteger + +`BigInteger()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.BigInteger` +* Type (used for pydantic): `int` + +### Float + +`Float()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Float` +* Type (used for pydantic): `float` + +### Decimal + +`Decimal(lenght, precision)` has required `length` and `precision` parameters. + +* Sqlalchemy column: `sqlalchemy.DECIMAL` +* Type (used for pydantic): `decimal.Decimal` + +### Date + +`Date()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Date` +* Type (used for pydantic): `datetime.date` + +### Time + +`Time()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.Time` +* Type (used for pydantic): `datetime.time` + +### DateTime + +`DateTime()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.DateTime` +* Type (used for pydantic): `datetime.datetime` + +### JSON + +`JSON()` has no required parameters. + +* Sqlalchemy column: `sqlalchemy.JSON` +* Type (used for pydantic): `pydantic.Json` + +### ForeignKey + +`ForeignKey(to, related_name=None)` has required parameters `to` that takes target `Model` class. + +Sqlalchemy column and Type are automatically taken from target `Model`. + +* Sqlalchemy column: class of a target `Model` primary key column +* Type (used for pydantic): type of a target `Model` primary key column + +`ForeignKey` fields are automatically registering reverse side of the relation. + +By default it's child (source) `Model` name + s, like courses in snippet below: + +```Python hl_lines="25 31" +--8<-- "../docs_src/fields/docs001.py" +``` + +But you can overwrite this name by providing `related_name` parameter like below: + +```Python hl_lines="25 30" +--8<-- "../docs_src/fields/docs002.py" +``` + +!!!tip + Since related models are coming from Relationship Manager the reverse relation on access returns list of `wekref.proxy` to avoid circular references. + +!!!info + All relations are stored in lists, but when you access parent `Model` the ORM is unpacking the value for you. + Read more in [relations][relations]. + +[relations]: ./relations.md +[queries]: ./queries.md \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 000ea34..24756fe 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,17 +1,207 @@ -# Welcome to MkDocs +# Async-ORM -For full documentation visit [mkdocs.org](https://www.mkdocs.org). +

+ + Build Status + + + Coverage + + +CodeFactor + + +Codacy + +

-## Commands +The `async-orm` package is an async ORM for Python, with support for Postgres, +MySQL, and SQLite. ORM is built with: -* `mkdocs new [dir-name]` - Create a new project. -* `mkdocs serve` - Start the live-reloading docs server. -* `mkdocs build` - Build the documentation site. -* `mkdocs -h` - Print help message and exit. + * [`SQLAlchemy core`][sqlalchemy-core] for query building. + * [`databases`][databases] for cross-database async support. + * [`pydantic`][pydantic] for data validation. -## Project layout +Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide +database migrations. - mkdocs.yml # The configuration file. - docs/ - index.md # The documentation homepage. - ... # Other markdown pages, images and other files. +The goal was to create a simple orm that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +Initial work was inspired by [`encode/orm`][encode/orm]. +The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. + +**async-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1` + +**Note**: Use `ipython` to try this from the console, since it supports `await`. + +```python +import databases +import orm +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Note(orm.Model): + __tablename__ = "notes" + __database__ = database + __metadata__ = metadata + + # primary keys of type int by dafault are set to autoincrement + id = orm.Integer(primary_key=True) + text = orm.String(length=100) + completed = orm.Boolean(default=False) + +# Create the database +engine = sqlalchemy.create_engine(str(database.url)) +metadata.create_all(engine) + +# .create() +await Note.objects.create(text="Buy the groceries.", completed=False) +await Note.objects.create(text="Call Mum.", completed=True) +await Note.objects.create(text="Send invoices.", completed=True) + +# .all() +notes = await Note.objects.all() + +# .filter() +notes = await Note.objects.filter(completed=True).all() + +# exact, iexact, contains, icontains, lt, lte, gt, gte, in +notes = await Note.objects.filter(text__icontains="mum").all() + +# .get() +note = await Note.objects.get(id=1) + +# .update() +await note.update(completed=True) + +# .delete() +await note.delete() + +# 'pk' always refers to the primary key +note = await Note.objects.get(pk=2) +note.pk # 2 +``` + +ORM supports loading and filtering across foreign keys... + +```python +import databases +import orm +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album) + title = orm.String(length=100) + position = orm.Integer() + + +# Create some records to work with. +malibu = await Album.objects.create(name="Malibu") +await Track.objects.create(album=malibu, title="The Bird", position=1) +await Track.objects.create(album=malibu, title="Heart don't stand a chance", position=2) +await Track.objects.create(album=malibu, title="The Waters", position=3) + +fantasies = await Album.objects.create(name="Fantasies") +await Track.objects.create(album=fantasies, title="Help I'm Alive", position=1) +await Track.objects.create(album=fantasies, title="Sick Muse", position=2) + + +# Fetch an instance, without loading a foreign key relationship on it. +track = await Track.objects.get(title="The Bird") + +# We have an album instance, but it only has the primary key populated +print(track.album) # Album(id=1) [sparse] +print(track.album.pk) # 1 +print(track.album.name) # Raises AttributeError + +# Load the relationship from the database +await track.album.load() +assert track.album.name == "Malibu" + +# This time, fetch an instance, loading the foreign key relationship. +track = await Track.objects.select_related("album").get(title="The Bird") +assert track.album.name == "Malibu" + +# By default you also get a second side of the relation +# constructed as lowercase source model name +'s' (tracks in this case) +# you can also provide custom name with parameter related_name +album = await Album.objects.select_related("tracks").all() +assert len(album.tracks) == 3 + +# Fetch instances, with a filter across an FK relationship. +tracks = Track.objects.filter(album__name="Fantasies") +assert len(tracks) == 2 + +# Fetch instances, with a filter and operator across an FK relationship. +tracks = Track.objects.filter(album__name__iexact="fantasies") +assert len(tracks) == 2 + +# Limit a query +tracks = await Track.objects.limit(1).all() +assert len(tracks) == 1 +``` + +## Data types + +The following keyword arguments are supported on all field types. + + * `primary_key` + * `nullable` + * `default` + * `server_default` + * `index` + * `unique` + +## Model Fields + +### Common parameters + +All fields are required unless one of the following is set: + + * `nullable` - Creates a nullable column. Sets the default to `None`. + * `default` - Set a default value for the field. + * `server_default` - Set a default value for the field on server side (like sqlalchemy's `func.now()`). + * `primary key` - Set a primary key on a column. + * `autoincrement` - When a column is set to primary key and autoincrement is set on this column. + Autoincrement is set by default on int primary keys. + +### Fields Types + +* `orm.String(length)` +* `orm.Text()` +* `orm.Boolean()` +* `orm.Integer()` +* `orm.Float()` +* `orm.Date()` +* `orm.Time()` +* `orm.DateTime()` +* `orm.JSON()` +* `orm.BigInteger()` +* `orm.Decimal(lenght, precision)` + +[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ +[databases]: https://github.com/encode/databases +[pydantic]: https://pydantic-docs.helpmanual.io/ +[encode/orm]: https://github.com/encode/orm/ +[alembic]: https://alembic.sqlalchemy.org/en/latest/ +[fastapi]: https://fastapi.tiangolo.com/ \ No newline at end of file diff --git a/docs_src/fields/docs001.py b/docs_src/fields/docs001.py new file mode 100644 index 0000000..f28b9d0 --- /dev/null +++ b/docs_src/fields/docs001.py @@ -0,0 +1,36 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Department(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + department = orm.ForeignKey(Department) + + +department = Department(name='Science') +course = Course(name='Math', completed=False, department=department) + +print(department.courses[0]) +# Will produce: +# Course(id=None, +# name='Math', +# completed=False, +# department=Department(id=None, name='Science')) diff --git a/docs_src/fields/docs002.py b/docs_src/fields/docs002.py new file mode 100644 index 0000000..5c9f4a9 --- /dev/null +++ b/docs_src/fields/docs002.py @@ -0,0 +1,36 @@ +import databases +import sqlalchemy + +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Department(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Course(orm.Model): + __database__ = database + __metadata__ = metadata + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + completed = orm.Boolean(default=False) + department = orm.ForeignKey(Department, related_name="my_courses") + +department = Department(name='Science') +course = Course(name='Math', completed=False, department=department) + +print(department.my_courses[0]) +# Will produce: +# Course(id=None, +# name='Math', +# completed=False, +# department=Department(id=None, name='Science')) + diff --git a/docs_src/fields/docs003.py b/docs_src/fields/docs003.py new file mode 100644 index 0000000..c3df501 --- /dev/null +++ b/docs_src/fields/docs003.py @@ -0,0 +1,41 @@ +import orm +import databases +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album, nullable=False) + title = orm.String(length=100) + position = orm.Integer() + + +album = await Album.objects.create(name="Brooklyn") +await Track.objects.create(album=album, title="The Bird", position=1) + +# explicit preload of related Album Model +track = await Track.objects.select_related("album").get(title="The Bird") +assert track.album.name == 'Brooklyn' +# Will produce: True + +# even without explicit select_related if ForeignKey is not nullable, +# the Album Model is still preloaded. +track2 = await Track.objects.get(title="The Bird") +assert track2.album.name == 'Brooklyn' +# Will produce: True diff --git a/orm/fields/base.py b/orm/fields/base.py index 2a0feb5..73b3131 100644 --- a/orm/fields/base.py +++ b/orm/fields/base.py @@ -35,12 +35,12 @@ class BaseField: @property def is_required(self) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not self.nullable and not self.has_default and not self.is_auto_primary_key ) @property def default_value(self) -> Any: - default = self.default if self.default is not None else self.server_default + default = self.default return default() if callable(default) else default @property diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index b32a887..81398e2 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -50,7 +50,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def _extract_model_from_sequence( - self, value: List, child: "Model" + self, value: List, child: "Model" ) -> Union["Model", List["Model"]]: return [self.expand_relationship(val, child) for val in value] @@ -75,10 +75,11 @@ class ForeignKey(BaseField): return model def register_relation(self, model: "Model", child: "Model") -> None: - model._orm_relationship_manager.add_relation(model, child, virtual=self.virtual) + child_model_name = self.related_name or child.get_name() + model._orm_relationship_manager.add_relation(model, child, child_model_name, virtual=self.virtual) def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py index 37f1108..13a237c 100644 --- a/orm/models/metaclass.py +++ b/orm/models/metaclass.py @@ -28,8 +28,8 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: - child_relation_name = field.to.get_name(title=True) + "_" + name.lower() + "s" - reverse_name = field.related_name or child_relation_name + child_relation_name = field.to.get_name(title=True) + "_" + (field.related_name or (name.lower() + "s")) + reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() relationship_manager.add_relation_type( relation_name, reverse_name, field, table_name @@ -43,14 +43,14 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: model.__fields__[child_model_name] = ModelField( name=child_model_name, @@ -64,7 +64,7 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None diff --git a/orm/relations.py b/orm/relations.py index 269d027..59ec981 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -37,25 +37,28 @@ class RelationshipManager: del self._relations[rel_type][model._orm_id] def add_relation( - self, parent: "FakePydantic", child: "FakePydantic", virtual: bool = False, + self, + parent: "FakePydantic", + child: "FakePydantic", + child_model_name: str, + virtual: bool = False, ) -> None: parent_id, child_id = parent._orm_id, child._orm_id - parent_name, child_name = ( - parent.get_name(title=True), - child.get_name(title=True), - ) + parent_name =parent.get_name(title=True) + child_name = child_model_name if child.get_name() != child_model_name else child.get_name()+'s' if virtual: - child_name, parent_name = parent_name, child_name + child_name, parent_name = parent_name, child.get_name() child_id, parent_id = parent_id, child_id child, parent = parent, proxy(child) + child_name = child_name.lower()+'s' else: child = proxy(child) - parent_relation_name = parent_name + "_" + child_name.lower() + "s" + parent_relation_name = parent_name.title() + "_" + child_name parents_list = self._relations[parent_relation_name].setdefault(parent_id, []) self.append_related_model(parents_list, child) - child_relation_name = child_name + "_" + parent_name.lower() + child_relation_name = child.get_name(title=True) + "_" + parent_name.lower() children_list = self._relations[child_relation_name].setdefault(child_id, []) self.append_related_model(children_list, parent) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 7121e5a..aaea6ff 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -3,7 +3,6 @@ import pytest import sqlalchemy import orm -import orm.fields.foreign_key from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from tests.settings import DATABASE_URL @@ -26,11 +25,21 @@ class Track(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - album = orm.fields.foreign_key.ForeignKey(Album) + album = orm.ForeignKey(Album) title = orm.String(length=100) position = orm.Integer() +class Cover(orm.Model): + __tablename__ = "covers" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album, related_name='cover_pictures') + title = orm.String(length=100) + + class Organisation(orm.Model): __tablename__ = "org" __metadata__ = metadata @@ -46,7 +55,7 @@ class Team(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - org = orm.fields.foreign_key.ForeignKey(Organisation) + org = orm.ForeignKey(Organisation) name = orm.String(length=100) @@ -56,7 +65,7 @@ class Member(orm.Model): __database__ = database id = orm.Integer(primary_key=True) - team = orm.fields.foreign_key.ForeignKey(Team) + team = orm.ForeignKey(Team) email = orm.String(length=100) @@ -81,6 +90,15 @@ async def test_setting_explicitly_empty_relation(): assert track.album is None +@pytest.mark.asyncio +async def test_related_name(): + async with database: + album = await Album.objects.create(name="Vanilla") + await Cover.objects.create(album=album, title="The cover file") + + assert len(album.cover_pictures) == 1 + + @pytest.mark.asyncio async def test_model_crud(): async with database: From 6b0cfdbfd36772664f68d6db5351c2c9cf3e2c0f Mon Sep 17 00:00:00 2001 From: collerek Date: Thu, 13 Aug 2020 17:10:13 +0200 Subject: [PATCH 53/62] work on relations docs --- docs/relations.md | 96 +++++++++++++++++++++++++++++++++++ docs_src/relations/docs001.py | 26 ++++++++++ docs_src/relations/docs002.py | 39 ++++++++++++++ docs_src/relations/docs003.py | 44 ++++++++++++++++ 4 files changed, 205 insertions(+) create mode 100644 docs_src/relations/docs001.py create mode 100644 docs_src/relations/docs002.py create mode 100644 docs_src/relations/docs003.py diff --git a/docs/relations.md b/docs/relations.md index e69de29..2fbe479 100644 --- a/docs/relations.md +++ b/docs/relations.md @@ -0,0 +1,96 @@ +# Relations + +## Defining a relationship + +### Foreign Key + +To define a relationship you simply need to create a ForeignKey field on one `Model` and point it to another `Model`. + +```Python hl_lines="24" +--8<-- "../docs_src/relations/docs001.py" +``` + +It automatically creates an sql foreign key constraint on a underlying table as well as nested pydantic model in the definition. + + +```Python hl_lines="29 33" +--8<-- "../docs_src/relations/docs002.py" +``` + +Of course it's handled for you so you don't have to delve deep into this but you can. + +### Reverse Relation + +At the same time the reverse relationship is registered automatically on parent model (target of `ForeignKey`). + +By default it's child (source) `Model` name + 's', like courses in snippet below: + +```Python hl_lines="25 31" +--8<-- "../docs_src/fields/docs001.py" +``` + +But you can overwrite this name by providing `related_name` parameter like below: + +```Python hl_lines="25 30" +--8<-- "../docs_src/fields/docs002.py" +``` + +!!!tip + Since related models are coming from Relationship Manager the reverse relation on access returns list of `wekref.proxy` to avoid circular references. + +## Relationship Manager + +Since orm uses Sqlalchemy core under the hood to prepare the queries, +the orm needs a way to uniquely identify each relationship between to tables to construct working queries. + +Imagine that you have models as following: + +```Python +--8<-- "../docs_src/relations/docs003.py" +``` + +Now imagine that you want to go from school class to student and his category and to teacher and his category. + +```Python +classes = await SchoolClass.objects.select_related( +["teachers__category", "students__category"]).all() +``` + +!!!note + To select related models use `select_related` method from `Model` `QuerySet`. + + Note that you use relation (`ForeignKey`) names and not the table names. + +Since you join two times to the same table it won't work by default -> you would need to use aliases for category tables and columns. + +But don't worry - orm can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships. + +Each class is registered with the same instance of the RelationshipManager that you can access like this: + +```python +SchoolClass._orm_relationship_manager +``` + +It's the same object for all `Models` + +```python +print(Teacher._orm_relationship_manager == Student._orm_relationship_manager) +# will produce: True +``` + +You can even preview the alias used for any relation by passing two tables names. + +```python +print(Teacher._orm_relationship_manager.resolve_relation_join( +'students', 'categories')) +# will produce: KId1c6 (sample value) + +print(Teacher._orm_relationship_manager.resolve_relation_join( +'categories', 'students')) +# will produce: EFccd5 (sample value) +``` + +!!!note + The order that you pass the names matters -> as those are 2 different relationships depending on join order. + + As aliases are produced randomly you can be presented with different results. diff --git a/docs_src/relations/docs001.py b/docs_src/relations/docs001.py new file mode 100644 index 0000000..4a028c3 --- /dev/null +++ b/docs_src/relations/docs001.py @@ -0,0 +1,26 @@ +import orm +import databases +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album) + title = orm.String(length=100) + position = orm.Integer() \ No newline at end of file diff --git a/docs_src/relations/docs002.py b/docs_src/relations/docs002.py new file mode 100644 index 0000000..9e3dfda --- /dev/null +++ b/docs_src/relations/docs002.py @@ -0,0 +1,39 @@ +import orm +import databases +import sqlalchemy + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class Album(orm.Model): + __tablename__ = "album" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Track(orm.Model): + __tablename__ = "track" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + album = orm.ForeignKey(Album) + title = orm.String(length=100) + position = orm.Integer() + + +print(Track.__table__.columns['album'].__repr__()) +# Will produce: +# Column('album', Integer(), ForeignKey('album.id'), table=) + +print(Track.__pydantic_model__.__fields__['album']) +# Will produce: +# ModelField( +# name='album' +# type=Optional[Album] +# required=False +# default=None) diff --git a/docs_src/relations/docs003.py b/docs_src/relations/docs003.py new file mode 100644 index 0000000..aa54c29 --- /dev/null +++ b/docs_src/relations/docs003.py @@ -0,0 +1,44 @@ +import databases +import sqlalchemy +import orm + +database = databases.Database("sqlite:///db.sqlite") +metadata = sqlalchemy.MetaData() + + +class SchoolClass(orm.Model): + __tablename__ = "schoolclasses" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Category(orm.Model): + __tablename__ = "categories" + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + + +class Student(orm.Model): + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + schoolclass = orm.ForeignKey(SchoolClass) + category = orm.ForeignKey(Category) + + +class Teacher(orm.Model): + __metadata__ = metadata + __database__ = database + + id = orm.Integer(primary_key=True) + name = orm.String(length=100) + schoolclass = orm.ForeignKey(SchoolClass) + category = orm.ForeignKey(Category) From 002f27f21e4fb3a4ffe161bc3868c5395edb1208 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 14 Aug 2020 14:35:42 +0200 Subject: [PATCH 54/62] fix bug with infinite relation auto extraction, finish initial relations docs --- .coverage | Bin 53248 -> 53248 bytes docs/relations.md | 114 ++++++++++++++++++++++++++++++++- orm/queryset/query.py | 57 +++++++++-------- tests/test_same_table_joins.py | 2 +- 4 files changed, 143 insertions(+), 30 deletions(-) diff --git a/.coverage b/.coverage index 3afbc70f18fa0551be087317173b260cdce545d2..d661345fe2932c280b47cc65dad053a17e94bea7 100644 GIT binary patch delta 55 zcmV-70LcGgeZTMb?XULTz5e}u N_V?R%f3w(*C_$9K9Kiqp delta 55 zcmV-70LcG you would need to use aliases for category tables and columns. +Since you join two times to the same table (categories) it won't work by default -> you would need to use aliases for category tables and columns. But don't worry - orm can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships. @@ -78,6 +89,8 @@ print(Teacher._orm_relationship_manager == Student._orm_relationship_manager) # will produce: True ``` +### Table aliases + You can even preview the alias used for any relation by passing two tables names. ```python @@ -94,3 +107,100 @@ print(Teacher._orm_relationship_manager.resolve_relation_join( The order that you pass the names matters -> as those are 2 different relationships depending on join order. As aliases are produced randomly you can be presented with different results. + +### Query automatic construction + +Orm is using those aliases during queries to both construct a meaningful and valid sql, +as well as later use it to extract proper columns for proper nested models. + +Running a previously mentioned query to select school classes and related teachers and students: + +```Python +classes = await SchoolClass.objects.select_related( +["teachers__category", "students__category"]).all() +``` + +Will result in a query like this (run under the hood): + +```sql +SELECT schoolclasses.id, + schoolclasses.name, + schoolclasses.department, + NZc8e2_students.id as NZc8e2_id, + NZc8e2_students.name as NZc8e2_name, + NZc8e2_students.schoolclass as NZc8e2_schoolclass, + NZc8e2_students.category as NZc8e2_category, + MYfe53_categories.id as MYfe53_id, + MYfe53_categories.name as MYfe53_name, + WA49a3_teachers.id as WA49a3_id, + WA49a3_teachers.name as WA49a3_name, + WA49a3_teachers.schoolclass as WA49a3_schoolclass, + WA49a3_teachers.category as WA49a3_category, + WZa13b_categories.id as WZa13b_id, + WZa13b_categories.name as WZa13b_name +FROM schoolclasses + LEFT OUTER JOIN students NZc8e2_students ON NZc8e2_students.schoolclass = schoolclasses.id + LEFT OUTER JOIN categories MYfe53_categories ON MYfe53_categories.id = NZc8e2_students.category + LEFT OUTER JOIN teachers WA49a3_teachers ON WA49a3_teachers.schoolclass = schoolclasses.id + LEFT OUTER JOIN categories WZa13b_categories ON WZa13b_categories.id = WA49a3_teachers.category +ORDER BY schoolclasses.id, NZc8e2_students.id, MYfe53_categories.id, WA49a3_teachers.id, WZa13b_categories.id +``` + +!!!note + As mentioned before the aliases are produced dynamically so the actual result might differ. + + Note that aliases are assigned to relations and not the tables, therefore the first table is always without an alias. + +### Returning related Models + +Each object in Relationship Manager is identified by orm_id which you can preview like this + +```python +category = Category(name='Math') +print(category._orm_id) +# will produce: c76046d9410c4582a656bf12a44c892c (sample value) +``` + +Each call to related `Model` is actually coming through the Manager which stores all +the relations in a dictionary and returns related `Models` by relation type (name) and by object _orm_id. + +Since we register both sides of the relation the side registering the relation +is always registering the other side as concrete model, +while the reverse relation is a weakref.proxy to avoid circular references. + +Sounds complicated but in reality it means something like this: + +```python +test_class = await SchoolClass.objects.create(name='Test') +student = await Student.objects.create(name='John', schoolclass=test_class) +# the relation to schoolsclass from student (i.e. when you call student.schoolclass) +# is a concrete one, meaning directy relating the schoolclass `Model` object +# On the other side calling test_class.students will result in a list of wekref.proxy objects +``` + +!!!tip + To learn more about queries and available methods please review [queries][queries] section. + +All relations are kept in lists, meaning that when you access related object the Relationship Manager is +searching itself for related models and get a list of them. + +But since child to parent relation is a many to one type, +the Manager is unpacking the first (and only) related model from a list and you get an actual `Model` instance instead of a list. + +Coming from parent to child relation (one to many) you always get a list of results. + +Translating this into concrete sample, the same as above: + +```python +test_class = await SchoolClass.objects.create(name='Test') +student = await Student.objects.create(name='John', schoolclass=test_class) + +student.schoolclass # return a test_class instance extracted from relationship list +test_class.students # return a list of related wekref.proxy refering related students `Models` + +``` + +!!!tip + You can preview all relations currently registered by accessing Relationship Manager on any class/instance `Student._orm_relationship_manager._relations` + +[queries]: ./queries.md \ No newline at end of file diff --git a/orm/queryset/query.py b/orm/queryset/query.py index 133d77f..75f41f3 100644 --- a/orm/queryset/query.py +++ b/orm/queryset/query.py @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -38,6 +38,7 @@ class Query: self.auto_related = [] self.used_aliases = [] + self.already_checked = [] self.select_from = None self.columns = None @@ -50,11 +51,11 @@ class Query: for key in self.model_cls.__model_fields__: if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], orm.fields.ForeignKey, - ) - and key not in self._select_related + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], orm.fields.ForeignKey, + ) + and key not in self._select_related ): self._select_related = [key] + self._select_related @@ -96,32 +97,32 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: BaseField, field_name: str, rel_part: str ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( [x.startswith(prev_part_of_related) for x in self._select_related] ) - already_checked = any([x.startswith(rel_part) for x in self.auto_related]) + already_checked = any([x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to to_table = model_cls.__table__.name @@ -164,15 +165,15 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: for field_name, field in prev_model.__model_fields__.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: @@ -180,7 +181,7 @@ class Query: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): self._extract_auto_required_relations( prev_model=field.to, @@ -189,6 +190,7 @@ class Query: parent_virtual=field.virtual, ) else: + self.already_checked.append(rel_part) rel_part = "" def _include_auto_related_models(self) -> None: @@ -200,7 +202,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: @@ -225,3 +227,4 @@ class Query: self.order_bys = None self.auto_related = [] self.used_aliases = [] + self.already_checked = [] diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 4c27446..b1dd02a 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -120,7 +120,7 @@ async def test_right_tables_join(): async def test_multiple_reverse_related_objects(): async with database: classes = await SchoolClass.objects.select_related( - ["teachers__category", "students"] + ["teachers__category", "students__category"] ).all() assert classes[0].name == "Math" assert classes[0].students[1].name == "Jack" From c6b4f69c4db0278826ab5b7223d107914f2a73ed Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 14 Aug 2020 14:35:57 +0200 Subject: [PATCH 55/62] fix bug with infinite relation auto extraction, finish initial relations docs --- orm/fields/base.py | 2 +- orm/fields/foreign_key.py | 20 +++++++------- orm/models/metaclass.py | 14 ++++++---- orm/queryset/query.py | 56 ++++++++++++++++++++------------------- orm/relations.py | 18 ++++++++----- 5 files changed, 61 insertions(+), 49 deletions(-) diff --git a/orm/fields/base.py b/orm/fields/base.py index 73b3131..003d14e 100644 --- a/orm/fields/base.py +++ b/orm/fields/base.py @@ -35,7 +35,7 @@ class BaseField: @property def is_required(self) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not self.nullable and not self.has_default and not self.is_auto_primary_key ) @property diff --git a/orm/fields/foreign_key.py b/orm/fields/foreign_key.py index 81398e2..839ca95 100644 --- a/orm/fields/foreign_key.py +++ b/orm/fields/foreign_key.py @@ -25,12 +25,12 @@ def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": class ForeignKey(BaseField): def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, + self, + to: Type["Model"], + name: str = None, + related_name: str = None, + nullable: bool = True, + virtual: bool = False, ) -> None: super().__init__(nullable=nullable, name=name) self.virtual = virtual @@ -50,7 +50,7 @@ class ForeignKey(BaseField): return to_column.get_column_type() def _extract_model_from_sequence( - self, value: List, child: "Model" + self, value: List, child: "Model" ) -> Union["Model", List["Model"]]: return [self.expand_relationship(val, child) for val in value] @@ -76,10 +76,12 @@ class ForeignKey(BaseField): def register_relation(self, model: "Model", child: "Model") -> None: child_model_name = self.related_name or child.get_name() - model._orm_relationship_manager.add_relation(model, child, child_model_name, virtual=self.virtual) + model._orm_relationship_manager.add_relation( + model, child, child_model_name, virtual=self.virtual + ) def expand_relationship( - self, value: Any, child: "Model" + self, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: diff --git a/orm/models/metaclass.py b/orm/models/metaclass.py index 13a237c..f519791 100644 --- a/orm/models/metaclass.py +++ b/orm/models/metaclass.py @@ -28,7 +28,11 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: - child_relation_name = field.to.get_name(title=True) + "_" + (field.related_name or (name.lower() + "s")) + child_relation_name = ( + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) + ) reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() relationship_manager.add_relation_type( @@ -43,14 +47,14 @@ def expand_reverse_relationships(model: Type["Model"]) -> None: parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: model.__fields__[child_model_name] = ModelField( name=child_model_name, @@ -64,7 +68,7 @@ def register_reverse_model_fields( def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None diff --git a/orm/queryset/query.py b/orm/queryset/query.py index 75f41f3..bf6234f 100644 --- a/orm/queryset/query.py +++ b/orm/queryset/query.py @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -51,11 +51,11 @@ class Query: for key in self.model_cls.__model_fields__: if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], orm.fields.ForeignKey, - ) - and key not in self._select_related + not self.model_cls.__model_fields__[key].nullable + and isinstance( + self.model_cls.__model_fields__[key], orm.fields.ForeignKey, + ) + and key not in self._select_related ): self._select_related = [key] + self._select_related @@ -97,32 +97,34 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: BaseField, field_name: str, rel_part: str ) -> bool: return isinstance(field, ForeignKey) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( [x.startswith(prev_part_of_related) for x in self._select_related] ) - already_checked = any([x.startswith(rel_part) for x in (self.auto_related + self.already_checked)]) + already_checked = any( + [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] + ) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: model_cls = join_params.model_cls.__model_fields__[part].to to_table = model_cls.__table__.name @@ -165,15 +167,15 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: for field_name, field in prev_model.__model_fields__.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: @@ -181,7 +183,7 @@ class Query: self.auto_related.append("__".join(rel_part.split("__")[:-1])) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): self._extract_auto_required_relations( prev_model=field.to, @@ -202,7 +204,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/orm/relations.py b/orm/relations.py index 59ec981..9df7ee7 100644 --- a/orm/relations.py +++ b/orm/relations.py @@ -38,19 +38,23 @@ class RelationshipManager: def add_relation( self, - parent: "FakePydantic", - child: "FakePydantic", - child_model_name: str, - virtual: bool = False, + parent: "FakePydantic", + child: "FakePydantic", + child_model_name: str, + virtual: bool = False, ) -> None: parent_id, child_id = parent._orm_id, child._orm_id - parent_name =parent.get_name(title=True) - child_name = child_model_name if child.get_name() != child_model_name else child.get_name()+'s' + parent_name = parent.get_name(title=True) + child_name = ( + child_model_name + if child.get_name() != child_model_name + else child.get_name() + "s" + ) if virtual: child_name, parent_name = parent_name, child.get_name() child_id, parent_id = parent_id, child_id child, parent = parent, proxy(child) - child_name = child_name.lower()+'s' + child_name = child_name.lower() + "s" else: child = proxy(child) From 0ebecc8610f4d5467e3db3186c95f109c7fc97a2 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 14 Aug 2020 15:24:44 +0200 Subject: [PATCH 56/62] finish initial queries docs --- docs/queries.md | 156 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/docs/queries.md b/docs/queries.md index e69de29..b931809 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -0,0 +1,156 @@ +# Queries + +## QuerySet + +Each Model is auto registered with a QuerySet that represents the underlaying query and it's options. + +Given the Models like this + +```Python +--8<-- "../docs_src/relations/docs001.py" +``` + +we can demonstrate available methods to fetch and save the data into the database. + +### create(**kwargs) + +Creates the model instance, saves it in a database and returns the updates model (with pk populated). +The allowed kwargs are `Model` fields names and proper value types. + +```python +malibu = await Album.objects.create(name="Malibu") +await Track.objects.create(album=malibu, title="The Bird", position=1) +``` + +The alternative is a split creation and persistence of the `Model`. +```python +malibu = Album(name="Malibu") +await malibu.save() +``` + +### load() + +By default when you query a table without prefetching related models, the orm will still construct +your related models, but populate them only with the pk value. + +```python +track = await Track.objects.get(name='The Bird') +track.album.pk # will return malibu album pk (1) +track.album.name # will return None + +# you need to actually load the data first +await track.album.load() +track.album.name # will return 'Malibu' +``` + +### get(**kwargs) + +Get's the first row from the db meeting the criteria set by kwargs. + +If no criteria set it will return the first row in db. + +Passing a criteria is actually calling filter(**kwargs) method described below. + +```python +track = await Track.objects.get(name='The Bird') +track2 = track = await Track.objects.get() +track == track2 # True since it's the only row in db +``` + +### all() + +Returns all rows from a database for given model + +```python +tracks = await Track.objects.select_related("album").all() +# will return a list of all Tracks +``` + +### filter(**kwargs) + +Allows you to filter by any `Model` attribute/field +as well as to fetch instances, with a filter across an FK relationship. + +```python +track = Track.objects.filter(name="The Bird").get() +# will return a track with name equal to 'The Bird' + +tracks = Track.objects.filter(album__name="Fantasies").all() +# will return all tracks where the related album name = 'Fantasies' +``` + +You can use special filter suffix to change the filter operands: + +* exact - like `album__name__exact='Malibu'` (exact match) +* iexact - like `album__name__iexact='malibu'` (exact match case insensitive) +* contains - like `album__name__conatins='Mal'` (sql like) +* icontains - like `album__name__icontains='mal'` (sql like case insensitive) +* in - like `album__name__in=['Malibu', 'Barclay']` (sql in) +* gt - like `position__gt=3` (sql >) +* gte - like `position__gte=3` (sql >=) +* lt - like `position__lt=3` (sql <) +* lte - like `position__lte=3` (sql <=) + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### select_related(*args) + +Allows to prefetch related models. + +To fetch related model use `ForeignKey` names. + +To chain related `Models` relation use double underscore. + +```python +album = await Album.objects.select_related("tracks").all() +# will return album will all related tracks +``` + +You can provide a string or a list of strings + +```python +classes = await SchoolClass.objects.select_related( +["teachers__category", "students"]).all() +# will return classes with teachers and teachers categories +# as well as classes students +``` + +!!!warning + If you set `ForeignKey` field as not nullable (so required) during + all queries the not nullable `Models` will be auto prefetched, even if you do not include them in select_related. + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### limit(int) + +You can limit the results to desired number of rows. + +```python +tracks = await Track.objects.limit(1).all() +# will return just one Track +``` + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` + +### offset(int) + +You can also offset the results by desired number of rows. + +```python +tracks = await Track.objects.offset(1).limit(1).all() +# will return just one Track, but this time the second one +``` + +!!!note + `filter()`, `select_related()`, `limit()` and `offset()` returns a QueySet instance so you can chain them together. + + Something like `Track.object.select_related("album").filter(album__name="Malibu").offset(1).limit(1).all()` \ No newline at end of file From 062d35168f373bdbdd1cfecac201244aaf26258f Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 14 Aug 2020 19:36:50 +0200 Subject: [PATCH 57/62] renames etc. --- .coverage | Bin 53248 -> 53248 bytes .gitignore | 3 +- README.md | 32 ++++++------ docs/fields.md | 2 +- docs/index.md | 64 ++++++++++++------------ docs/models.md | 18 +++++-- docs/queries.md | 2 +- docs/relations.md | 4 +- docs_src/fields/docs001.py | 18 +++---- docs_src/fields/docs002.py | 18 +++---- docs_src/fields/docs003.py | 18 +++---- docs_src/models/docs001.py | 10 ++-- docs_src/models/docs002.py | 10 ++-- docs_src/models/docs003.py | 10 ++-- docs_src/models/docs004.py | 10 ++-- docs_src/models/docs005.py | 10 ++-- docs_src/models/docs006.py | 18 +++---- docs_src/models/docs007.py | 10 ++-- docs_src/relations/docs001.py | 18 +++---- docs_src/relations/docs002.py | 18 +++---- docs_src/relations/docs003.py | 34 ++++++------- orm/models/__init__.py | 4 -- orm/queryset/__init__.py | 3 -- {orm => ormar}/__init__.py | 8 +-- {orm => ormar}/exceptions.py | 0 {orm => ormar}/fields/__init__.py | 6 +-- {orm => ormar}/fields/base.py | 4 +- {orm => ormar}/fields/decorators.py | 4 +- {orm => ormar}/fields/foreign_key.py | 8 +-- {orm => ormar}/fields/model_fields.py | 4 +- ormar/models/__init__.py | 4 ++ {orm => ormar}/models/fakepydantic.py | 14 +++--- {orm => ormar}/models/metaclass.py | 8 +-- {orm => ormar}/models/model.py | 6 +-- ormar/queryset/__init__.py | 3 ++ {orm => ormar}/queryset/clause.py | 10 ++-- {orm => ormar}/queryset/query.py | 10 ++-- {orm => ormar}/queryset/queryset.py | 10 ++-- {orm => ormar}/relations.py | 4 +- scripts/clean.sh | 2 +- scripts/publish.sh | 23 +++++++++ scripts/test.sh | 2 +- setup.cfg | 2 + setup.py | 67 ++++++++++++++++++++++++++ tests/test_columns.py | 18 +++---- tests/test_fastapi_usage.py | 17 +++---- tests/test_foreign_keys.py | 50 +++++++++---------- tests/test_model_definition.py | 6 +-- tests/test_models.py | 30 ++++++------ tests/test_same_table_joins.py | 43 ++++++++--------- 50 files changed, 398 insertions(+), 299 deletions(-) delete mode 100644 orm/models/__init__.py delete mode 100644 orm/queryset/__init__.py rename {orm => ormar}/__init__.py (72%) rename {orm => ormar}/exceptions.py (100%) rename {orm => ormar}/fields/__init__.py (72%) rename {orm => ormar}/fields/base.py (96%) rename {orm => ormar}/fields/decorators.py (91%) rename {orm => ormar}/fields/foreign_key.py (95%) rename {orm => ormar}/fields/model_fields.py (94%) create mode 100644 ormar/models/__init__.py rename {orm => ormar}/models/fakepydantic.py (95%) rename {orm => ormar}/models/metaclass.py (96%) rename {orm => ormar}/models/model.py (95%) create mode 100644 ormar/queryset/__init__.py rename {orm => ormar}/queryset/clause.py (96%) rename {orm => ormar}/queryset/query.py (97%) rename {orm => ormar}/queryset/queryset.py (96%) rename {orm => ormar}/relations.py (97%) mode change 100644 => 100755 scripts/clean.sh create mode 100644 scripts/publish.sh create mode 100644 setup.cfg create mode 100644 setup.py diff --git a/.coverage b/.coverage index d661345fe2932c280b47cc65dad053a17e94bea7..07f559488504b35b95fbb9a1f36d877952526add 100644 GIT binary patch delta 1216 zcmZuwSxgf_7~V46yE{L$$WbU_d{gNK2%4B^B3_{2lP0um)^@>$(rvdzXaujsS2URi zjT*0LqQ<)k#CY(eF)`|c8WmBZiFhHXiK&`sohhvp^=UTq9sfW7_wSCQXva}>F|#5U zbFdofpbT&%KkpYNS(dr87cI8dF0Qsh8Y(T?qG&;n>Z>u^&9_A@ z))M%LyP-~(Vbvq9og8Z{DU>xwpi9xhK_xWxVfKe@`6Sj@R+I|4ymD7ip;%QO`LwpI zBo#~Q84CZPlH~DuLXt#LGbV?pd^=UQ;te|D2LJmLa>;caq<;t#uJU&!O46k9HP3>)3Db%wO69M)hNb!+n|UU$&p zR$Qtkhg2glVSDl48l~puOQ`%Hb@wEG-4kb((m6_y7LrI~kbxt!j z$?@UYGFUGu%9PN+YV1JX(Ue7@o8lV`z!z*zF_u>Na zxp-P^6pMrlLZ9$T?1e|ddN?VR@?-otej8s;a~JuD(o&~GHq=gZPU9?EBAMEs)b+&X zG|kwDKcVN7NvEm!9}(Rd`rbSSviYua79tCBu>uClO~%EdrZh}Gx0={}YQAC>>R@~} zk|aMRUSS~$FqpXut;MXtc#y3HtR~xu|D~9dP?@Ijze3EeK`Qf2*@KxGW|@)92rP?4 ziA02%tO>UMiR+XEGb)k9MEMk*5nwtAIE)$bg^28zd=>U`4!*;CcmWUKHVnXV*bBR% z4_aUuRDu!o_$z))|L8OL5ccB@xCWQg497%L(Vq=>4|cz}_GZbE*!|j`8~3_NZNFM| zpmy@mh&9uYuFJuO+_`#suF?P6$RDJeSN1pV% y=6w7xAx)0ziv=pM3tRT}-idyzjNRHEd;Mj6eC)Q;3M{p)3oEk~_J~?P$ZfJGUteBkzvf&MWt zp*5uFD0CR{C=2O=X{jM7Lv!)2%37YUs zFUgr z$fwE^4F$61L{cm(5Qu z>m$eIjC{GFAUA3wZtS;_nLb6TpaIpPR20P3R$Dvv1ROTE-)ScXbINAm*f3QRJ~^7Q z)rnmWpVfiwlGo;!yphiO&CW47QZGh7NN0-a-d~3|mDtWQ9%`LX$0MYkr0oAPN znQ^_v7I`LfQY=zVTu?r8!a%M=n+%LXz4k6l*WXwkq*`iH5yW3|GL`OX#4Hcm6d_*` z;zcUXOeG+g1r88hUT@9;>X`q7O^L?aiEDDKQ;{-hXR`VVK)?!Ip~|wIt*#6Zv;vo{ za>VFr)yZQ0Rj-pIpSI>@DP`O6CYAs~H+sL|u;?~e3>5l}zM&=b4!uT`=mENkuAx38 zkY9ue<)aL^20y~5@Ch7+Uf2Yyp@AU&%F0E9dUUsXWhyY=bMVi7X|;D`MXnl@ELGCd zf#)`DMxrJeW~A&+C+pGlPp{f$p1lk$EV~}h_sy(?zK^eco(^yHt%Vl5ugD(;?S`4< r@UO5@!!g8P&UD}DeKYjsr?egtP4jQv%Zr(^FmxWc=Rt6C=-8pZkTtJf diff --git a/.gitignore b/.gitignore index a9b621f..22d9e75 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ p38venv .pytest_cache *.pyc *.log -test.db \ No newline at end of file +test.db +dist \ No newline at end of file diff --git a/README.md b/README.md index 3186786..d72c80b 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,8 @@ Coverage - -CodeFactor + +CodeFactor Codacy @@ -35,22 +35,22 @@ The encode package was too simple (i.e. no ability to join two times to the same ```python import databases -import orm +import ormar import sqlalchemy database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Note(orm.Model): +class Note(ormar.Model): __tablename__ = "notes" __database__ = database __metadata__ = metadata # primary keys of type int by dafault are set to autoincrement - id = orm.Integer(primary_key=True) - text = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + text = ormar.String(length=100) + completed = ormar.Boolean(default=False) # Create the database engine = sqlalchemy.create_engine(str(database.url)) @@ -88,31 +88,31 @@ ORM supports loading and filtering across foreign keys... ```python import databases -import orm +import ormar import sqlalchemy database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Album(orm.Model): +class Album(ormar.Model): __tablename__ = "album" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Track(orm.Model): +class Track(ormar.Model): __tablename__ = "track" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(length=100) - position = orm.Integer() + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() # Create some records to work with. diff --git a/docs/fields.md b/docs/fields.md index eed37dc..66d4dd7 100644 --- a/docs/fields.md +++ b/docs/fields.md @@ -199,7 +199,7 @@ But you can overwrite this name by providing `related_name` parameter like below Since related models are coming from Relationship Manager the reverse relation on access returns list of `wekref.proxy` to avoid circular references. !!!info - All relations are stored in lists, but when you access parent `Model` the ORM is unpacking the value for you. + All relations are stored in lists, but when you access parent `Model` the ormar is unpacking the value for you. Read more in [relations][relations]. [relations]: ./relations.md diff --git a/docs/index.md b/docs/index.md index 24756fe..cf66bba 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,4 @@ -# Async-ORM +# ORMar

@@ -15,42 +15,42 @@

-The `async-orm` package is an async ORM for Python, with support for Postgres, -MySQL, and SQLite. ORM is built with: +The `ormar` package is an async ORM for Python, with support for Postgres, +MySQL, and SQLite. Ormar is built with: * [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. -Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide +Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide database migrations. -The goal was to create a simple orm that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +The goal was to create a simple ORM that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. Initial work was inspired by [`encode/orm`][encode/orm]. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. -**async-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1` +**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.0.1` **Note**: Use `ipython` to try this from the console, since it supports `await`. ```python import databases -import orm +import ormar import sqlalchemy database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Note(orm.Model): +class Note(ormar.Model): __tablename__ = "notes" __database__ = database __metadata__ = metadata # primary keys of type int by dafault are set to autoincrement - id = orm.Integer(primary_key=True) - text = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + text = ormar.String(length=100) + completed = ormar.Boolean(default=False) # Create the database engine = sqlalchemy.create_engine(str(database.url)) @@ -84,35 +84,35 @@ note = await Note.objects.get(pk=2) note.pk # 2 ``` -ORM supports loading and filtering across foreign keys... +Ormar supports loading and filtering across foreign keys... ```python import databases -import orm +import ormar import sqlalchemy database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Album(orm.Model): +class Album(ormar.Model): __tablename__ = "album" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Track(orm.Model): +class Track(ormar.Model): __tablename__ = "track" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(length=100) - position = orm.Integer() + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() # Create some records to work with. @@ -187,17 +187,17 @@ All fields are required unless one of the following is set: ### Fields Types -* `orm.String(length)` -* `orm.Text()` -* `orm.Boolean()` -* `orm.Integer()` -* `orm.Float()` -* `orm.Date()` -* `orm.Time()` -* `orm.DateTime()` -* `orm.JSON()` -* `orm.BigInteger()` -* `orm.Decimal(lenght, precision)` +* `String(length)` +* `Text()` +* `Boolean()` +* `Integer()` +* `Float()` +* `Date()` +* `Time()` +* `DateTime()` +* `JSON()` +* `BigInteger()` +* `Decimal(lenght, precision)` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases diff --git a/docs/models.md b/docs/models.md index e4f5a13..5fd3cd6 100644 --- a/docs/models.md +++ b/docs/models.md @@ -1,17 +1,20 @@ # Models ## Defining models -By defining an orm Model you get corresponding **Pydantic model** as well as **Sqlalchemy table** for free. + +By defining an ormar Model you get corresponding **Pydantic model** as well as **Sqlalchemy table** for free. They are being managed in the background and you do not have to create them on your own. ### Model Class -To build an ORM model you simply need to inherit a `orm.Model` class. + +To build an ormar model you simply need to inherit a `ormar.Model` class. ```Python hl_lines="10" --8<-- "../docs_src/models/docs001.py" ``` ### Defining Fields + Next assign one or more of the [Fields][fields] as a class level variables. Each table **has to** have a primary key column, which you specify by setting `primary_key=True` on selected field. @@ -31,17 +34,18 @@ By default if you assign primary key to `Integer` field, the `autoincrement` opt You can disable by passing `autoincremant=False`. ```Python -id = orm.Integer(primary_key=True, autoincrement=False) +id = ormar.Integer(primary_key=True, autoincrement=False) ``` Names of the fields will be used for both the underlying `pydantic` model and `sqlalchemy` table. ### Dependencies -Since orm depends on [`databases`][databases] and [`sqlalchemy-core`][sqlalchemy-core] for database connection +Since ormar depends on [`databases`][databases] and [`sqlalchemy-core`][sqlalchemy-core] for database connection and table creation you need to assign each `Model` with two special parameters. #### Databases + One is `Database` instance created with your database url in [sqlalchemy connection string][sqlalchemy connection string] format. Created instance needs to be passed to every `Model` with `__database__` parameter. @@ -55,6 +59,7 @@ Created instance needs to be passed to every `Model` with `__database__` paramet You can create several ones if you want to use multiple databases. #### Sqlalchemy + Second dependency is sqlalchemy `MetaData` instance. Created instance needs to be passed to every `Model` with `__metadata__` parameter. @@ -123,6 +128,7 @@ The fields and relations are not stored on the `Model` itself Apart from special parameters defined in the `Model` during definition (tablename, metadata etc.) the `Model` provides you with useful internals. ### Pydantic Model + To access auto created pydantic model you can use `Model.__pydantic_model__` parameter For example to list model fields you can: @@ -138,6 +144,7 @@ For example to list model fields you can: For more options visit official [pydantic][pydantic] documentation. ### Sqlalchemy Table + To access auto created sqlalchemy table you can use `Model.__table__` parameter For example to list table columns you can: @@ -153,7 +160,8 @@ For example to list table columns you can: For more options visit official [sqlalchemy-metadata][sqlalchemy-metadata] documentation. ### Fields Definition -To access orm `Fields` you can use `Model.__model_fields__` parameter + +To access ormar `Fields` you can use `Model.__model_fields__` parameter For example to list table model fields you can: diff --git a/docs/queries.md b/docs/queries.md index b931809..474a98e 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -30,7 +30,7 @@ await malibu.save() ### load() -By default when you query a table without prefetching related models, the orm will still construct +By default when you query a table without prefetching related models, the ormar will still construct your related models, but populate them only with the pk value. ```python diff --git a/docs/relations.md b/docs/relations.md index c671252..43a6e20 100644 --- a/docs/relations.md +++ b/docs/relations.md @@ -74,7 +74,7 @@ classes = await SchoolClass.objects.select_related( Since you join two times to the same table (categories) it won't work by default -> you would need to use aliases for category tables and columns. -But don't worry - orm can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships. +But don't worry - ormar can handle situations like this, as it uses the Relationship Manager which has it's aliases defined for all relationships. Each class is registered with the same instance of the RelationshipManager that you can access like this: @@ -110,7 +110,7 @@ print(Teacher._orm_relationship_manager.resolve_relation_join( ### Query automatic construction -Orm is using those aliases during queries to both construct a meaningful and valid sql, +Ormar is using those aliases during queries to both construct a meaningful and valid sql, as well as later use it to extract proper columns for proper nested models. Running a previously mentioned query to select school classes and related teachers and students: diff --git a/docs_src/fields/docs001.py b/docs_src/fields/docs001.py index f28b9d0..047690d 100644 --- a/docs_src/fields/docs001.py +++ b/docs_src/fields/docs001.py @@ -1,28 +1,28 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Department(orm.Model): +class Department(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) - department = orm.ForeignKey(Department) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + department = ormar.ForeignKey(Department) department = Department(name='Science') diff --git a/docs_src/fields/docs002.py b/docs_src/fields/docs002.py index 5c9f4a9..7fa6ccd 100644 --- a/docs_src/fields/docs002.py +++ b/docs_src/fields/docs002.py @@ -1,28 +1,28 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Department(orm.Model): +class Department(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) - department = orm.ForeignKey(Department, related_name="my_courses") + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + department = ormar.ForeignKey(Department, related_name="my_courses") department = Department(name='Science') course = Course(name='Math', completed=False, department=department) diff --git a/docs_src/fields/docs003.py b/docs_src/fields/docs003.py index c3df501..32ff68c 100644 --- a/docs_src/fields/docs003.py +++ b/docs_src/fields/docs003.py @@ -1,4 +1,4 @@ -import orm +import ormar import databases import sqlalchemy @@ -6,24 +6,24 @@ database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Album(orm.Model): +class Album(ormar.Model): __tablename__ = "album" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Track(orm.Model): +class Track(ormar.Model): __tablename__ = "track" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album, nullable=False) - title = orm.String(length=100) - position = orm.Integer() + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album, nullable=False) + title = ormar.String(length=100) + position = ormar.Integer() album = await Album.objects.create(name="Brooklyn") diff --git a/docs_src/models/docs001.py b/docs_src/models/docs001.py index 6aba0ed..9c8f8f1 100644 --- a/docs_src/models/docs001.py +++ b/docs_src/models/docs001.py @@ -1,16 +1,16 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) diff --git a/docs_src/models/docs002.py b/docs_src/models/docs002.py index 0886dcc..1d63371 100644 --- a/docs_src/models/docs002.py +++ b/docs_src/models/docs002.py @@ -1,19 +1,19 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Course(orm.Model): +class Course(ormar.Model): # if you omit this parameter it will be created automatically # as class.__name__.lower()+'s' -> "courses" in this example __tablename__ = "my_courses" __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) diff --git a/docs_src/models/docs003.py b/docs_src/models/docs003.py index bd108d4..754f6d4 100644 --- a/docs_src/models/docs003.py +++ b/docs_src/models/docs003.py @@ -1,19 +1,19 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) print(Course.__pydantic_model__.__fields__) """ diff --git a/docs_src/models/docs004.py b/docs_src/models/docs004.py index f2f06ba..36e5da0 100644 --- a/docs_src/models/docs004.py +++ b/docs_src/models/docs004.py @@ -1,19 +1,19 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) print(Course.__table__.columns) """ diff --git a/docs_src/models/docs005.py b/docs_src/models/docs005.py index cf33c9d..e1a85e8 100644 --- a/docs_src/models/docs005.py +++ b/docs_src/models/docs005.py @@ -1,19 +1,19 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) print(Course.__model_fields__) """ diff --git a/docs_src/models/docs006.py b/docs_src/models/docs006.py index fc0fbef..b232979 100644 --- a/docs_src/models/docs006.py +++ b/docs_src/models/docs006.py @@ -1,28 +1,28 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Department(orm.Model): +class Department(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) - department = orm.ForeignKey(Department) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) + department = ormar.ForeignKey(Department) department = Department(name="Science") diff --git a/docs_src/models/docs007.py b/docs_src/models/docs007.py index 2bd7af1..f98e62a 100644 --- a/docs_src/models/docs007.py +++ b/docs_src/models/docs007.py @@ -1,19 +1,19 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Course(orm.Model): +class Course(ormar.Model): __database__ = database __metadata__ = metadata - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - completed = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + completed = ormar.Boolean(default=False) course = Course(name="Painting for dummies", completed=False) diff --git a/docs_src/relations/docs001.py b/docs_src/relations/docs001.py index 4a028c3..53e2f5c 100644 --- a/docs_src/relations/docs001.py +++ b/docs_src/relations/docs001.py @@ -1,4 +1,4 @@ -import orm +import ormar import databases import sqlalchemy @@ -6,21 +6,21 @@ database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Album(orm.Model): +class Album(ormar.Model): __tablename__ = "album" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Track(orm.Model): +class Track(ormar.Model): __tablename__ = "track" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(length=100) - position = orm.Integer() \ No newline at end of file + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() \ No newline at end of file diff --git a/docs_src/relations/docs002.py b/docs_src/relations/docs002.py index 9e3dfda..ef67093 100644 --- a/docs_src/relations/docs002.py +++ b/docs_src/relations/docs002.py @@ -1,4 +1,4 @@ -import orm +import ormar import databases import sqlalchemy @@ -6,24 +6,24 @@ database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class Album(orm.Model): +class Album(ormar.Model): __tablename__ = "album" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Track(orm.Model): +class Track(ormar.Model): __tablename__ = "track" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(length=100) - position = orm.Integer() + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() print(Track.__table__.columns['album'].__repr__()) diff --git a/docs_src/relations/docs003.py b/docs_src/relations/docs003.py index aa54c29..e319e42 100644 --- a/docs_src/relations/docs003.py +++ b/docs_src/relations/docs003.py @@ -1,44 +1,44 @@ import databases import sqlalchemy -import orm +import ormar database = databases.Database("sqlite:///db.sqlite") metadata = sqlalchemy.MetaData() -class SchoolClass(orm.Model): +class SchoolClass(ormar.Model): __tablename__ = "schoolclasses" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Category(orm.Model): +class Category(ormar.Model): __tablename__ = "categories" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Student(orm.Model): +class Student(ormar.Model): __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - schoolclass = orm.ForeignKey(SchoolClass) - category = orm.ForeignKey(Category) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + schoolclass = ormar.ForeignKey(SchoolClass) + category = ormar.ForeignKey(Category) -class Teacher(orm.Model): +class Teacher(ormar.Model): __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - schoolclass = orm.ForeignKey(SchoolClass) - category = orm.ForeignKey(Category) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + schoolclass = ormar.ForeignKey(SchoolClass) + category = ormar.ForeignKey(Category) diff --git a/orm/models/__init__.py b/orm/models/__init__.py deleted file mode 100644 index b948f31..0000000 --- a/orm/models/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from orm.models.fakepydantic import FakePydantic -from orm.models.model import Model - -__all__ = ["FakePydantic", "Model"] diff --git a/orm/queryset/__init__.py b/orm/queryset/__init__.py deleted file mode 100644 index 30e112a..0000000 --- a/orm/queryset/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from orm.queryset.queryset import QuerySet - -__all__ = ["QuerySet"] diff --git a/orm/__init__.py b/ormar/__init__.py similarity index 72% rename from orm/__init__.py rename to ormar/__init__.py index 2cf9401..1d9f65c 100644 --- a/orm/__init__.py +++ b/ormar/__init__.py @@ -1,5 +1,5 @@ -from orm.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch -from orm.fields import ( +from ormar.exceptions import ModelDefinitionError, ModelNotSet, MultipleMatches, NoMatch +from ormar.fields import ( BigInteger, Boolean, Date, @@ -13,9 +13,9 @@ from orm.fields import ( Text, Time, ) -from orm.models import Model +from ormar.models import Model -__version__ = "0.0.1" +__version__ = "0.1.0" __all__ = [ "Integer", "BigInteger", diff --git a/orm/exceptions.py b/ormar/exceptions.py similarity index 100% rename from orm/exceptions.py rename to ormar/exceptions.py diff --git a/orm/fields/__init__.py b/ormar/fields/__init__.py similarity index 72% rename from orm/fields/__init__.py rename to ormar/fields/__init__.py index 6355c38..f6c4dc9 100644 --- a/orm/fields/__init__.py +++ b/ormar/fields/__init__.py @@ -1,6 +1,6 @@ -from orm.fields.base import BaseField -from orm.fields.foreign_key import ForeignKey -from orm.fields.model_fields import ( +from ormar.fields.base import BaseField +from ormar.fields.foreign_key import ForeignKey +from ormar.fields.model_fields import ( BigInteger, Boolean, Date, diff --git a/orm/fields/base.py b/ormar/fields/base.py similarity index 96% rename from orm/fields/base.py rename to ormar/fields/base.py index 003d14e..cef323f 100644 --- a/orm/fields/base.py +++ b/ormar/fields/base.py @@ -2,10 +2,10 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING import sqlalchemy -from orm import ModelDefinitionError # noqa I101 +from ormar import ModelDefinitionError # noqa I101 if TYPE_CHECKING: # pragma no cover - from orm.models import Model + from ormar.models import Model class BaseField: diff --git a/orm/fields/decorators.py b/ormar/fields/decorators.py similarity index 91% rename from orm/fields/decorators.py rename to ormar/fields/decorators.py index ae4e498..842e864 100644 --- a/orm/fields/decorators.py +++ b/ormar/fields/decorators.py @@ -1,9 +1,9 @@ from typing import Any, TYPE_CHECKING, Type -from orm import ModelDefinitionError +from ormar import ModelDefinitionError if TYPE_CHECKING: # pragma no cover - from orm.fields import BaseField + from ormar.fields import BaseField class RequiredParams: diff --git a/orm/fields/foreign_key.py b/ormar/fields/foreign_key.py similarity index 95% rename from orm/fields/foreign_key.py rename to ormar/fields/foreign_key.py index 839ca95..77c2da7 100644 --- a/orm/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -3,12 +3,12 @@ from typing import Any, List, Optional, TYPE_CHECKING, Type, Union import sqlalchemy from pydantic import BaseModel -import orm # noqa I101 -from orm.exceptions import RelationshipInstanceError -from orm.fields.base import BaseField +import ormar # noqa I101 +from ormar.exceptions import RelationshipInstanceError +from ormar.fields.base import BaseField if TYPE_CHECKING: # pragma no cover - from orm.models import Model + from ormar.models import Model def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": diff --git a/orm/fields/model_fields.py b/ormar/fields/model_fields.py similarity index 94% rename from orm/fields/model_fields.py rename to ormar/fields/model_fields.py index f14391e..4f9be11 100644 --- a/orm/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -4,8 +4,8 @@ import decimal import sqlalchemy from pydantic import Json -from orm.fields.base import BaseField # noqa I101 -from orm.fields.decorators import RequiredParams +from ormar.fields.base import BaseField # noqa I101 +from ormar.fields.decorators import RequiredParams @RequiredParams("length") diff --git a/ormar/models/__init__.py b/ormar/models/__init__.py new file mode 100644 index 0000000..b70c515 --- /dev/null +++ b/ormar/models/__init__.py @@ -0,0 +1,4 @@ +from ormar.models.fakepydantic import FakePydantic +from ormar.models.model import Model + +__all__ = ["FakePydantic", "Model"] diff --git a/orm/models/fakepydantic.py b/ormar/models/fakepydantic.py similarity index 95% rename from orm/models/fakepydantic.py rename to ormar/models/fakepydantic.py index 5e1d6b4..c7a1bc8 100644 --- a/orm/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -19,13 +19,13 @@ import pydantic import sqlalchemy from pydantic import BaseModel -import orm # noqa I100 -from orm.fields import BaseField -from orm.models.metaclass import ModelMetaclass -from orm.relations import RelationshipManager +import ormar # noqa I100 +from ormar.fields import BaseField +from ormar.models.metaclass import ModelMetaclass +from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover - from orm.models.model import Model + from ormar.models.model import Model class FakePydantic(list, metaclass=ModelMetaclass): @@ -194,10 +194,10 @@ class FakePydantic(list, metaclass=ModelMetaclass): def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": for field in one.__model_fields__.keys(): if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), orm.Model + getattr(one, field), ormar.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) - elif isinstance(getattr(one, field), orm.Model): + elif isinstance(getattr(one, field), ormar.Model): if getattr(one, field).pk == getattr(other, field).pk: setattr( other, diff --git a/orm/models/metaclass.py b/ormar/models/metaclass.py similarity index 96% rename from orm/models/metaclass.py rename to ormar/models/metaclass.py index f519791..7ea2fe7 100644 --- a/orm/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -5,12 +5,12 @@ import sqlalchemy from pydantic import BaseConfig, create_model from pydantic.fields import ModelField -from orm import ForeignKey, ModelDefinitionError # noqa I100 -from orm.fields import BaseField -from orm.relations import RelationshipManager +from ormar import ForeignKey, ModelDefinitionError # noqa I100 +from ormar.fields import BaseField +from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover - from orm import Model + from ormar import Model relationship_manager = RelationshipManager() diff --git a/orm/models/model.py b/ormar/models/model.py similarity index 95% rename from orm/models/model.py rename to ormar/models/model.py index ab21bfc..fed4224 100644 --- a/orm/models/model.py +++ b/ormar/models/model.py @@ -2,14 +2,14 @@ from typing import Any, List import sqlalchemy -import orm.queryset # noqa I100 -from orm.models import FakePydantic # noqa I100 +import ormar.queryset # noqa I100 +from ormar.models import FakePydantic # noqa I100 class Model(FakePydantic): __abstract__ = True - objects = orm.queryset.QuerySet() + objects = ormar.queryset.QuerySet() @classmethod def from_row( diff --git a/ormar/queryset/__init__.py b/ormar/queryset/__init__.py new file mode 100644 index 0000000..7bf6fc6 --- /dev/null +++ b/ormar/queryset/__init__.py @@ -0,0 +1,3 @@ +from ormar.queryset.queryset import QuerySet + +__all__ = ["QuerySet"] diff --git a/orm/queryset/clause.py b/ormar/queryset/clause.py similarity index 96% rename from orm/queryset/clause.py rename to ormar/queryset/clause.py index e70aaa1..3d5d14b 100644 --- a/orm/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union import sqlalchemy from sqlalchemy import text -import orm # noqa I100 -from orm.exceptions import QueryDefinitionError +import ormar # noqa I100 +from ormar.exceptions import QueryDefinitionError if TYPE_CHECKING: # pragma no cover - from orm import Model + from ormar import Model FILTER_OPERATORS = { "exact": "__eq__", @@ -75,7 +75,7 @@ class QueryClause: value, has_escaped_character = self._escape_characters_in_clause(op, value) - if isinstance(value, orm.Model): + if isinstance(value, ormar.Model): value = value.pk op_attr = FILTER_OPERATORS[op] @@ -147,7 +147,7 @@ class QueryClause: if op not in ["contains", "icontains"]: return value, has_escaped_character - if isinstance(value, orm.Model): + if isinstance(value, ormar.Model): raise QueryDefinitionError( "You cannot use contains and icontains with instance of the Model" ) diff --git a/orm/queryset/query.py b/ormar/queryset/query.py similarity index 97% rename from orm/queryset/query.py rename to ormar/queryset/query.py index bf6234f..202b249 100644 --- a/orm/queryset/query.py +++ b/ormar/queryset/query.py @@ -3,12 +3,12 @@ from typing import List, NamedTuple, TYPE_CHECKING, Tuple, Type import sqlalchemy from sqlalchemy import text -import orm # noqa I100 -from orm import ForeignKey -from orm.fields import BaseField +import ormar # noqa I100 +from ormar import ForeignKey +from ormar.fields import BaseField if TYPE_CHECKING: # pragma no cover - from orm import Model + from ormar import Model class JoinParameters(NamedTuple): @@ -53,7 +53,7 @@ class Query: if ( not self.model_cls.__model_fields__[key].nullable and isinstance( - self.model_cls.__model_fields__[key], orm.fields.ForeignKey, + self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, ) and key not in self._select_related ): diff --git a/orm/queryset/queryset.py b/ormar/queryset/queryset.py similarity index 96% rename from orm/queryset/queryset.py rename to ormar/queryset/queryset.py index 5cbc3cd..3190036 100644 --- a/orm/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -3,13 +3,13 @@ from typing import Any, List, TYPE_CHECKING, Tuple, Type, Union import databases import sqlalchemy -import orm # noqa I100 -from orm import MultipleMatches, NoMatch -from orm.queryset.clause import QueryClause -from orm.queryset.query import Query +import ormar # noqa I100 +from ormar import MultipleMatches, NoMatch +from ormar.queryset.clause import QueryClause +from ormar.queryset.query import Query if TYPE_CHECKING: # pragma no cover - from orm import Model + from ormar import Model class QuerySet: diff --git a/orm/relations.py b/ormar/relations.py similarity index 97% rename from orm/relations.py rename to ormar/relations.py index 9df7ee7..45c3ddd 100644 --- a/orm/relations.py +++ b/ormar/relations.py @@ -5,10 +5,10 @@ from random import choices from typing import List, TYPE_CHECKING, Union from weakref import proxy -from orm import ForeignKey +from ormar import ForeignKey if TYPE_CHECKING: # pragma no cover - from orm.models import FakePydantic, Model + from ormar.models import FakePydantic, Model def get_table_alias() -> str: diff --git a/scripts/clean.sh b/scripts/clean.sh old mode 100644 new mode 100755 index 38b8eb7..9b30e14 --- a/scripts/clean.sh +++ b/scripts/clean.sh @@ -1,5 +1,5 @@ #!/bin/sh -e -PACKAGE="orm" +PACKAGE="ormar" if [ -d 'dist' ] ; then rm -r dist fi diff --git a/scripts/publish.sh b/scripts/publish.sh new file mode 100644 index 0000000..419fa30 --- /dev/null +++ b/scripts/publish.sh @@ -0,0 +1,23 @@ +#!/bin/sh -e + +PACKAGE="ormar" + +PREFIX="" +if [ -d 'venv' ] ; then + PREFIX="venv/bin/" +fi + +VERSION=`cat ${PACKAGE}/__init__.py | grep __version__ | sed "s/__version__ = //" | sed "s/'//g"` + +set -x + +scripts/clean.sh + +${PREFIX}python setup.py sdist +${PREFIX}twine upload dist/* + +echo "You probably want to also tag the version now:" +echo "git tag -a ${VERSION} -m 'version ${VERSION}'" +echo "git push --tags" + +scripts/clean.sh \ No newline at end of file diff --git a/scripts/test.sh b/scripts/test.sh index 11c09f6..c911ffb 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -1,6 +1,6 @@ #!/bin/sh -e -PACKAGE="orm" +PACKAGE="ormar" PREFIX="" if [ -d 'venv' ] ; then diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..224a779 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +description-file = README.md \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4b88c2d --- /dev/null +++ b/setup.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import re + +from setuptools import setup + +PACKAGE = "ormar" +URL = "https://github.com/collerek/ormar" + + +def get_version(package): + """ + Return package version as listed in `__version__` in `init.py`. + """ + with open(os.path.join(package, "__init__.py")) as f: + return re.search("__version__ = ['\"]([^'\"]+)['\"]", f.read()).group(1) + + +def get_long_description(): + """ + Return the README. + """ + with open("README.md", encoding="utf8") as f: + return f.read() + + +def get_packages(package): + """ + Return root package and all sub-packages. + """ + return [ + dirpath + for dirpath, dirnames, filenames in os.walk(package) + if os.path.exists(os.path.join(dirpath, "__init__.py")) + ] + + +setup( + name=PACKAGE, + version=get_version(PACKAGE), + url=URL, + license="MIT", + description="An simple async ORM with Fastapi in mind.", + long_description=get_long_description(), + long_description_content_type="text/markdown", + keywords=['ORM', 'sqlalchemy', 'fastapi', 'pydantic', 'databases'], + author="Radosław Drążkiewicz", + author_email="collerek@gmail.com", + packages=get_packages(PACKAGE), + package_data={PACKAGE: ["py.typed"]}, + data_files=[("", ["LICENSE.md"])], + install_requires=["databases", "pydantic", "sqlalchemy"], + classifiers=[ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Internet :: WWW/HTTP", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + ], +) diff --git a/tests/test_columns.py b/tests/test_columns.py index edee116..e75cb0a 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -4,7 +4,7 @@ import databases import pytest import sqlalchemy -import orm +import ormar from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) @@ -15,18 +15,18 @@ def time(): return datetime.datetime.now().time() -class Example(orm.Model): +class Example(ormar.Model): __tablename__ = "example" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - created = orm.DateTime(default=datetime.datetime.now) - created_day = orm.Date(default=datetime.date.today) - created_time = orm.Time(default=time) - description = orm.Text(nullable=True) - value = orm.Float(nullable=True) - data = orm.JSON(default={}) + id = ormar.Integer(primary_key=True) + created = ormar.DateTime(default=datetime.datetime.now) + created_day = ormar.Date(default=datetime.date.today) + created_time = ormar.Time(default=time) + description = ormar.Text(nullable=True) + value = ormar.Float(nullable=True) + data = ormar.JSON(default={}) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index bef1607..25b1310 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -3,8 +3,7 @@ import sqlalchemy from fastapi import FastAPI from fastapi.testclient import TestClient -import orm -import orm.fields.foreign_key +import ormar from tests.settings import DATABASE_URL app = FastAPI() @@ -13,23 +12,23 @@ database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class Category(orm.Model): +class Category(ormar.Model): __tablename__ = "categories" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Item(orm.Model): +class Item(ormar.Model): __tablename__ = "items" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - category = orm.fields.foreign_key.ForeignKey(Category, nullable=True) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + category = ormar.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index aaea6ff..45f4359 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -2,71 +2,71 @@ import databases import pytest import sqlalchemy -import orm -from orm.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError +import ormar +from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class Album(orm.Model): +class Album(ormar.Model): __tablename__ = "album" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Track(orm.Model): +class Track(ormar.Model): __tablename__ = "track" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album) - title = orm.String(length=100) - position = orm.Integer() + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album) + title = ormar.String(length=100) + position = ormar.Integer() -class Cover(orm.Model): +class Cover(ormar.Model): __tablename__ = "covers" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - album = orm.ForeignKey(Album, related_name='cover_pictures') - title = orm.String(length=100) + id = ormar.Integer(primary_key=True) + album = ormar.ForeignKey(Album, related_name='cover_pictures') + title = ormar.String(length=100) -class Organisation(orm.Model): +class Organisation(ormar.Model): __tablename__ = "org" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - ident = orm.String(length=100) + id = ormar.Integer(primary_key=True) + ident = ormar.String(length=100) -class Team(orm.Model): +class Team(ormar.Model): __tablename__ = "team" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - org = orm.ForeignKey(Organisation) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + org = ormar.ForeignKey(Organisation) + name = ormar.String(length=100) -class Member(orm.Model): +class Member(ormar.Model): __tablename__ = "member" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - team = orm.ForeignKey(Team) - email = orm.String(length=100) + id = ormar.Integer(primary_key=True) + team = ormar.ForeignKey(Team) + email = ormar.String(length=100) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 8ea9289..253ae4f 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -4,9 +4,9 @@ import pydantic import pytest import sqlalchemy -import orm.fields as fields -from orm.exceptions import ModelDefinitionError -from orm.models import Model +import ormar.fields as fields +from ormar.exceptions import ModelDefinitionError +from ormar.models import Model metadata = sqlalchemy.MetaData() diff --git a/tests/test_models.py b/tests/test_models.py index 8b0bf37..69d421c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,32 +2,32 @@ import databases import pytest import sqlalchemy -import orm -from orm.exceptions import QueryDefinitionError +import ormar +from ormar.exceptions import QueryDefinitionError from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class User(orm.Model): +class User(ormar.Model): __tablename__ = "users" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Product(orm.Model): +class Product(ormar.Model): __tablename__ = "product" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - rating = orm.Integer(minimum=1, maximum=5) - in_stock = orm.Boolean(default=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + rating = ormar.Integer(minimum=1, maximum=5) + in_stock = ormar.Boolean(default=False) @pytest.fixture(autouse=True, scope="module") @@ -40,9 +40,9 @@ def create_test_database(): def test_model_class(): assert list(User.__model_fields__.keys()) == ["id", "name"] - assert isinstance(User.__model_fields__["id"], orm.Integer) + assert isinstance(User.__model_fields__["id"], ormar.Integer) assert User.__model_fields__["id"].primary_key is True - assert isinstance(User.__model_fields__["name"], orm.String) + assert isinstance(User.__model_fields__["name"], ormar.String) assert User.__model_fields__["name"].length == 100 assert isinstance(User.__table__, sqlalchemy.Table) @@ -82,7 +82,7 @@ async def test_model_crud(): @pytest.mark.asyncio async def test_model_get(): async with database: - with pytest.raises(orm.NoMatch): + with pytest.raises(ormar.NoMatch): await User.objects.get() user = await User.objects.create(name="Tom") @@ -90,7 +90,7 @@ async def test_model_get(): assert lookup == user user = await User.objects.create(name="Jane") - with pytest.raises(orm.MultipleMatches): + with pytest.raises(ormar.MultipleMatches): await User.objects.get() same_user = await User.objects.get(pk=user.id) @@ -108,7 +108,7 @@ async def test_model_filter(): user = await User.objects.get(name="Lucy") assert user.name == "Lucy" - with pytest.raises(orm.NoMatch): + with pytest.raises(ormar.NoMatch): await User.objects.get(name="Jim") await Product.objects.create(name="T-Shirt", rating=5, in_stock=True) diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index b1dd02a..69f5dd2 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -4,62 +4,61 @@ import databases import pytest import sqlalchemy -import orm -import orm.fields.foreign_key +import ormar from tests.settings import DATABASE_URL database = databases.Database(DATABASE_URL, force_rollback=True) metadata = sqlalchemy.MetaData() -class Department(orm.Model): +class Department(ormar.Model): __tablename__ = "departments" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True, autoincrement=False) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True, autoincrement=False) + name = ormar.String(length=100) -class SchoolClass(orm.Model): +class SchoolClass(ormar.Model): __tablename__ = "schoolclasses" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - department = orm.fields.foreign_key.ForeignKey(Department, nullable=False) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + department = ormar.ForeignKey(Department, nullable=False) -class Category(orm.Model): +class Category(ormar.Model): __tablename__ = "categories" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) -class Student(orm.Model): +class Student(ormar.Model): __tablename__ = "students" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - schoolclass = orm.fields.foreign_key.ForeignKey(SchoolClass) - category = orm.fields.foreign_key.ForeignKey(Category, nullable=True) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + schoolclass = ormar.ForeignKey(SchoolClass) + category = ormar.ForeignKey(Category, nullable=True) -class Teacher(orm.Model): +class Teacher(ormar.Model): __tablename__ = "teachers" __metadata__ = metadata __database__ = database - id = orm.Integer(primary_key=True) - name = orm.String(length=100) - schoolclass = orm.fields.foreign_key.ForeignKey(SchoolClass) - category = orm.fields.foreign_key.ForeignKey(Category, nullable=True) + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + schoolclass = ormar.ForeignKey(SchoolClass) + category = ormar.ForeignKey(Category, nullable=True) @pytest.fixture(scope='module') From b3cc2ba86b517f60114833f629e888dffacd81ce Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 14 Aug 2020 19:39:47 +0200 Subject: [PATCH 58/62] renames in readme --- .gitignore | 3 ++- README.md | 38 +++++++++++++++++++------------------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/.gitignore b/.gitignore index 22d9e75..f58fa36 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ p38venv *.pyc *.log test.db -dist \ No newline at end of file +dist +/ormar.egg-info/ diff --git a/README.md b/README.md index d72c80b..0eca573 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@ -# Async-ORM +# ORMar

Build Status - Coverage + Coverage CodeFactor @@ -15,21 +15,21 @@

-The `async-orm` package is an async ORM for Python, with support for Postgres, -MySQL, and SQLite. ORM is built with: +The `ormar` package is an async ORM for Python, with support for Postgres, +MySQL, and SQLite. Ormar is built with: * [`SQLAlchemy core`][sqlalchemy-core] for query building. * [`databases`][databases] for cross-database async support. * [`pydantic`][pydantic] for data validation. -Because ORM is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide +Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide database migrations. -The goal was to create a simple orm that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +The goal was to create a simple ormar that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. Initial work was inspired by [`encode/orm`][encode/orm]. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. -**async-orm is still under development:** We recommend pinning any dependencies with `aorm~=0.0.1` +**ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.1.1` **Note**: Use `ipython` to try this from the console, since it supports `await`. @@ -84,7 +84,7 @@ note = await Note.objects.get(pk=2) note.pk # 2 ``` -ORM supports loading and filtering across foreign keys... +Ormar supports loading and filtering across foreign keys... ```python import databases @@ -181,17 +181,17 @@ All fields are required unless one of the following is set: Autoincrement is set by default on int primary keys. Available Model Fields: -* `orm.String(length)` -* `orm.Text()` -* `orm.Boolean()` -* `orm.Integer()` -* `orm.Float()` -* `orm.Date()` -* `orm.Time()` -* `orm.DateTime()` -* `orm.JSON()` -* `orm.BigInteger()` -* `orm.Decimal(lenght, precision)` +* `String(length)` +* `Text()` +* `Boolean()` +* `Integer()` +* `Float()` +* `Date()` +* `Time()` +* `DateTime()` +* `JSON()` +* `BigInteger()` +* `Decimal(lenght, precision)` [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [databases]: https://github.com/encode/databases From 3232c99fcad694b3e23a8d52cd2e267ec3132094 Mon Sep 17 00:00:00 2001 From: collerek Date: Fri, 14 Aug 2020 19:40:09 +0200 Subject: [PATCH 59/62] bump version --- ormar/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ormar/__init__.py b/ormar/__init__.py index 1d9f65c..3d7e820 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -15,7 +15,7 @@ from ormar.fields import ( ) from ormar.models import Model -__version__ = "0.1.0" +__version__ = "0.1.1" __all__ = [ "Integer", "BigInteger", From a0ad85811bb21b7d0514c8ba3980602738d1a5c2 Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 15 Aug 2020 12:37:48 +0200 Subject: [PATCH 60/62] fix nested dicts, add more real life fastapi tests --- .coverage | Bin 53248 -> 53248 bytes README.md | 14 ++-- ormar/models/fakepydantic.py | 12 ++- ormar/models/model.py | 4 +- ormar/queryset/queryset.py | 13 +++- tests/test_foreign_keys.py | 18 ++--- tests/test_model_definition.py | 6 ++ tests/test_more_reallife_fastapi.py | 116 ++++++++++++++++++++++++++++ tests/test_same_table_joins.py | 4 +- 9 files changed, 161 insertions(+), 26 deletions(-) create mode 100644 tests/test_more_reallife_fastapi.py diff --git a/.coverage b/.coverage index 07f559488504b35b95fbb9a1f36d877952526add..bf048747f459554fda48d6da57377636b1799559 100644 GIT binary patch delta 387 zcmZozz}&Eac>|jQ7b^q51HV4MFhA>NL4mJ)llk-&luQ*2jjW8#tV~V$nV6-;%TkMq zGxPII^^Eik8*Ld`WSjCQ2k5Wo&{0s=)h$R#+Wb$SO@LXFn{6_?g8{QWHyhBTGh9sO zl9ShYt4z-F7G|}PU}mVFywO`(xF|KgC^a!BCo?TIJ}t4hB(WfKbCq`&3s5`zWcC09 zW(RI|p!Ty|lU<`NSR`$k>nArxt6Q%`c&+1cgy_B^+p|Ns2<2S2UuGm1(8#dRcpeXFs{zxT6p-t)?T zW!K*_PB!g3ASWXLl(Q1J)%|z>-M?>NKmWJq`}6a6-|qIG8eeC`K3S+gp6CCyt);hq zzujB+`yc1z+WvNCN&eu;>=O)`_=7h)PLShgl9b#WHGh=?P`r6E`vIVMGtfPr{H)B9 M%nZDfJI+@C0POvR%>V!Z delta 273 zcmZozz}&Eac>|jQ7as$^C%*;13_ssyL4kjKlUel@luQ&1Ev*d9txOE~nV6-;%TkMq zGxPII^^Eik8%-ElWSe3qyX&vt{9m6{fLVguWiq>i0iz30I}bOLxy0mk-YS!IyoEQ{ zdbhIxrClep2N*EA0;PGmC$mReY_5%NS7Damlbg)mpdk2_?>gTmzEym4_zL*K`Aqrb zfa;#{ii=9HFmmch>^eKU{NA4Dw)6j=H?P=tu6(jd-vL=^0ic|Pz^(4T`|tjJ`}+C6 zJ>Q?7zx#H#|I~Q9XOo5c<9YU%=Fa_ASH1rJzIvv~wf*hP68uq<*(VqR%~s;y95sKH T0#Kx5GW!9LNC5xjj`I}&5p`YW diff --git a/README.md b/README.md index 0eca573..c144d4f 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,16 @@ # ORMar

- - Build Status + + Build Status - - Coverage + + Coverage CodeFactor - + Codacy

@@ -25,10 +25,12 @@ MySQL, and SQLite. Ormar is built with: Because ormar is built on SQLAlchemy core, you can use [`alembic`][alembic] to provide database migrations. -The goal was to create a simple ormar that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. +The goal was to create a simple ORM that can be used directly with [`fastapi`][fastapi] that bases it's data validation on pydantic. Initial work was inspired by [`encode/orm`][encode/orm]. The encode package was too simple (i.e. no ability to join two times to the same table) and used typesystem for data checks. +To avoid too high coupling with pydantic and sqlalchemy ormar uses them by **composition** rather than by **inheritance**. + **ormar is still under development:** We recommend pinning any dependencies with `ormar~=0.1.1` **Note**: Use `ipython` to try this from the console, since it supports `await`. diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index c7a1bc8..d5109b4 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -117,15 +117,19 @@ class FakePydantic(list, metaclass=ModelMetaclass): def pk_type(cls) -> Any: return cls.__model_fields__[cls.__pkname__].__type__ - def dict(self) -> Dict: # noqa: A003 + def dict(self, nested=False) -> Dict: # noqa: A003 dict_instance = self.values.dict() for field in self._extract_related_names(): nested_model = getattr(self, field) - if isinstance(nested_model, list): - dict_instance[field] = [x.dict() for x in nested_model] + if self.__model_fields__[field].virtual and nested: + continue + if isinstance(nested_model, list) and not isinstance( + nested_model, ormar.Model + ): + dict_instance[field] = [x.dict(nested=True) for x in nested_model] else: dict_instance[field] = ( - nested_model.dict() if nested_model is not None else {} + nested_model.dict(nested=True) if nested_model is not None else {} ) return dict_instance diff --git a/ormar/models/model.py b/ormar/models/model.py index fed4224..b16d3e7 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -55,7 +55,7 @@ class Model(FakePydantic): def pk(self, value: Any) -> None: setattr(self.values, self.__pkname__, value) - async def save(self) -> int: + async def save(self) -> "Model": self_fields = self._extract_model_db_fields() if self.__model_fields__.get(self.__pkname__).autoincrement: self_fields.pop(self.__pkname__, None) @@ -63,7 +63,7 @@ class Model(FakePydantic): expr = expr.values(**self_fields) item_id = await self.__database__.execute(expr) self.pk = item_id - return item_id + return self async def update(self, **kwargs: Any) -> int: if kwargs: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 3190036..65192b8 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -160,10 +160,15 @@ class QuerySet: # substitute related models with their pk for field in self.model_cls._extract_related_names(): if field in new_kwargs and new_kwargs.get(field) is not None: - new_kwargs[field] = getattr( - new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, - ) + if isinstance(new_kwargs.get(field), ormar.Model): + new_kwargs[field] = getattr( + new_kwargs.get(field), + self.model_cls.__model_fields__[field].to.__pkname__, + ) + else: + new_kwargs[field] = new_kwargs.get(field).get( + self.model_cls.__model_fields__[field].to.__pkname__ + ) # Build the insert expression. expr = self.table.insert() diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index 45f4359..bffa783 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -36,7 +36,7 @@ class Cover(ormar.Model): __database__ = database id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album, related_name='cover_pictures') + album = ormar.ForeignKey(Album, related_name="cover_pictures") title = ormar.String(length=100) @@ -171,8 +171,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -180,8 +180,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -223,8 +223,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -243,8 +243,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index 253ae4f..f6dc722 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -95,13 +95,16 @@ def test_sqlalchemy_table_is_created(example): def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata test_string = fields.String(length=250) + def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example3" __metadata__ = metadata @@ -111,6 +114,7 @@ def test_two_pks_in_model_definition(): def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata @@ -119,6 +123,7 @@ def test_setting_pk_column_as_pydantic_only_in_model_definition(): def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata @@ -127,6 +132,7 @@ def test_decimal_error_in_model_definition(): def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): + class ExampleModel2(Model): __tablename__ = "example4" __metadata__ = metadata diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py new file mode 100644 index 0000000..d7670cc --- /dev/null +++ b/tests/test_more_reallife_fastapi.py @@ -0,0 +1,116 @@ +from typing import List + +import databases +import pytest +import sqlalchemy +from fastapi import FastAPI +from starlette.testclient import TestClient + +import ormar +from tests.settings import DATABASE_URL + +app = FastAPI() +metadata = sqlalchemy.MetaData() +database = databases.Database(DATABASE_URL, force_rollback=True) +app.state.database = database + + +@app.on_event("startup") +async def startup() -> None: + database_ = app.state.database + if not database_.is_connected: + await database_.connect() + + +@app.on_event("shutdown") +async def shutdown() -> None: + database_ = app.state.database + if database_.is_connected: + await database_.disconnect() + + +class Category(ormar.Model): + __tablename__ = "categories" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + + +class Item(ormar.Model): + __tablename__ = "items" + __metadata__ = metadata + __database__ = database + + id = ormar.Integer(primary_key=True) + name = ormar.String(length=100) + category = ormar.ForeignKey(Category, nullable=True) + + +@pytest.fixture(autouse=True, scope="module") +def create_test_database(): + engine = sqlalchemy.create_engine(DATABASE_URL) + metadata.create_all(engine) + yield + metadata.drop_all(engine) + + +@app.get("/items/", response_model=List[Item]) +async def get_items(): + items = await Item.objects.select_related("category").all() + return [item.dict() for item in items] + + +@app.post("/items/", response_model=Item) +async def create_item(item: Item): + item = await Item.objects.create(**item.dict()) + return item.dict() + + +@app.post("/categories/", response_model=Category) +async def create_category(category: Category): + category = await Category.objects.create(**category.dict()) + return category.dict() + + +@app.put("/items/{item_id}") +async def get_item(item_id: int, item: Item): + item_db = await Item.objects.get(pk=item_id) + return {"updated_rows": await item_db.update(**item.dict())} + + +@app.delete("/items/{item_id}") +async def delete_item(item_id: int, item: Item): + item_db = await Item.objects.get(pk=item_id) + return {"deleted_rows": await item_db.delete()} + + +def test_all_endpoints(): + client = TestClient(app) + with client as client: + response = client.post("/categories/", json={"name": "test cat"}) + category = response.json() + response = client.post( + "/items/", json={"name": "test", "id": 1, "category": category} + ) + item = Item(**response.json()) + assert item.pk is not None + + response = client.get("/items/") + items = [Item(**item) for item in response.json()] + assert items[0] == item + + item.name = "New name" + response = client.put(f"/items/{item.pk}", json=item.dict()) + assert response.json().get("updated_rows") == 1 + + response = client.get("/items/") + items = [Item(**item) for item in response.json()] + assert items[0].name == "New name" + + response = client.delete(f"/items/{item.pk}", json=item.dict()) + assert response.json().get("deleted_rows") == 1 + response = client.get("/items/") + items = response.json() + assert len(items) == 0 diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 69f5dd2..60ddddc 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -61,7 +61,7 @@ class Teacher(ormar.Model): category = ormar.ForeignKey(Category, nullable=True) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def event_loop(): loop = asyncio.get_event_loop() yield loop @@ -92,6 +92,8 @@ async def test_model_multiple_instances_of_same_table_in_schema(): assert classes[0].name == "Math" assert classes[0].students[0].name == "Jane" + assert len(classes[0].dict().get("students")) == 2 + # related fields of main model are only populated by pk # unless there is a required foreign key somewhere along the way # since department is required for schoolclass it was pre loaded (again) From b69ad226e622fe18e980edad8ed3c532f8271f4a Mon Sep 17 00:00:00 2001 From: collerek Date: Sat, 15 Aug 2020 12:51:01 +0200 Subject: [PATCH 61/62] update config --- README.md | 7 ++++++- docs/index.md | 20 +++++++++++++------- ormar/__init__.py | 2 +- setup.py | 2 +- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index c144d4f..c3ae7af 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,11 @@ # ORMar -

+ + Pypi version + + + Pypi version + Build Status diff --git a/docs/index.md b/docs/index.md index cf66bba..bf035bf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,16 +1,22 @@ # ORMar

- - Build Status + + Pypi version - - Coverage + + Pypi version - -CodeFactor + + Build Status - + + Coverage + + +CodeFactor + + Codacy

diff --git a/ormar/__init__.py b/ormar/__init__.py index 3d7e820..098adf2 100644 --- a/ormar/__init__.py +++ b/ormar/__init__.py @@ -15,7 +15,7 @@ from ormar.fields import ( ) from ormar.models import Model -__version__ = "0.1.1" +__version__ = "0.1.3" __all__ = [ "Integer", "BigInteger", diff --git a/setup.py b/setup.py index 4b88c2d..0f768da 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ setup( long_description=get_long_description(), long_description_content_type="text/markdown", keywords=['ORM', 'sqlalchemy', 'fastapi', 'pydantic', 'databases'], - author="Radosław Drążkiewicz", + author="collerek", author_email="collerek@gmail.com", packages=get_packages(PACKAGE), package_data={PACKAGE: ["py.typed"]}, From a39179bc645567fa447983ead4f0fae0ff867a5b Mon Sep 17 00:00:00 2001 From: collerek Date: Sun, 16 Aug 2020 22:27:39 +0200 Subject: [PATCH 62/62] mostly working --- .coverage | Bin 53248 -> 0 bytes ormar/fields/base.py | 108 ++++---- ormar/fields/foreign_key.py | 120 +++++---- ormar/fields/model_fields.py | 394 ++++++++++++++++++++++++---- ormar/models/fakepydantic.py | 182 +++++++++---- ormar/models/metaclass.py | 143 ++++++---- ormar/models/model.py | 62 ++--- ormar/queryset/clause.py | 18 +- ormar/queryset/query.py | 92 +++---- ormar/queryset/queryset.py | 33 +-- ormar/relations.py | 7 +- tests/test_columns.py | 22 +- tests/test_fastapi_usage.py | 24 +- tests/test_foreign_keys.py | 93 ++++--- tests/test_model_definition.py | 98 +++---- tests/test_models.py | 39 +-- tests/test_more_reallife_fastapi.py | 24 +- tests/test_same_table_joins.py | 65 ++--- 18 files changed, 988 insertions(+), 536 deletions(-) delete mode 100644 .coverage diff --git a/.coverage b/.coverage deleted file mode 100644 index bf048747f459554fda48d6da57377636b1799559..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53248 zcmeI4dvF`Y8Nly!vSht({g$1O=B5u!63emdBtXEJ(55Yq4&^aSr!hUw(#cUG=@i|` zjwgQ9JciCNo$0jGGHocdol=@Upe+L#Lco8tblNh(Qz(x(iQUp6rX{wggu#RmyWj5J zSyE!z`j3hEZlvA2z5RB-{q47V`gC`D+buW64ONP1nWSP!ZCnM%^W3$P#BrPlz5)2! zn;UkV_5s9v-hQ!N57)d&3XnTEzv~Aa`9|PtB;@9})k_n3Q^CjlpLLjp(u z2_OL^zz7WA;&%rd8u+^h45cff8W|;`X3Q|~`CGSdyM4QK`?l+D+Af)C(q&!=wopjg zCS|lt*GU;Qre@SsMAc0#E{jLkdDq=D^T=TtV?1;}@0lWE#HAYoQ)XvHLp$r_r~*pmBZk-p5ex1c+1ty_Mze z;M%qP{d<_9P&{&1+|UR%@fmf(}wwLhCBDHD)oo3Ihb@0y|qQn~Kn(qSH1+v>r91bgRwQ zS#w(WxD5$RC+u+rx34W`ItwNshZlCe!woKXaP?~bZVxjftGBGkU0_0XPcAZz+NYAL zqGvPqM6%Q*X-nj0NqdY~XP`MDw=1}9b+P7H`Q)&*Q$Fl=x`VZ~{IEG9=)6R~Wd)6( zDY+qV-i4ra_0x*6Q@T)hnhc6Jmbrpg)fQ8nsvw7H=(^`t2UGceH>00~NuPd&0+ne> zPURcbzEF>nfY~Ii#4|8tVGcI0020j@RUos{rDY9fzx0rqC%yLM^fe2}9oy_7*E*IR zhBs8CZ?zAk#+`~zCm+30>e94?s-(V28g(<`AsaMgZNJfQ|a0y-4C@j_cw3Fm8 z4XKx0PJhS@ujSpr4IB93tThJB=|$GzTAoyFwXH)(fo+_3GG`_?!#oQ=a`MfCvdhCk z56v3roM!f5mr23SW&OnQuA>eIE2+f>n8f<=GNx$a69i@bGK$c|FEO+hOM2Q z>5R5ZjTm}or=s_zA|ZIV9unKG=uuNWot=qzS7#?YpQnb>N@O=YZ3Ow!eTJ$Vot?6r zRN^UFh9?;so215&-qXY45IhdhPM^`wsO9Nj8G=t(&0GIlIDEnykNt-Qz5n0fb8+M!az77z<4GACtB!C2v01`j~ zNB{|3qy#Q+5V$Pdxm&n#m%DGr?y_~#P}boUZkfHvEAN6u8+5oycXhqM4J|8}gcSs2 zc#AlZh{seprs#%}j)Qb4C~!&%(or=*E4#d|6fKh)WS67xhJ1>x0Rn<;wE~x1wjOKw z4;fw+r?qXW5jbsmA+s@ZR;SA*r2E!a3*3&Sq;%Dw>#@)pTB-!Dd-1&NR%a)ytdZ#h z0dof=!j%FSFG(VjP+*xC5NWF@+V;gNf!e4LDOE_8DQ0%VwpkRo$WqF&p3ys51eI(E z2;7}Z<NsOu2Y%KK7rd)f=o==t)}~;O3H{wKrHNC z2C`LEUent=%kYZD)kIY9jM4>d6hq4}^~#q~&k|zQ$>ss5x5-`9F3k?L#PZA;q(ZKu zL#;rnOJP=djG${=jKQsxslY1{}I0=J}LH#HGva>1A&_YLH`r}0sotUe)6jS7V@Zn zweO_wx4uE&wIG5U5Eav~$ zHw#KB(w6@pEJ;h#T`KVZ>oy6=) z=9+^}Rjoq2G>HZOzp{AS7puhd|0_xs%5$a2Mj^7Ka<+?3|Db`!h)u$sOXkn_{{wBp z?j0uQe@8a|9#6rwyMf&x_3EVmb+Zw|9h5G&k`~X+B^Vk<;Cr?z?082 zXOMCiA8G|smjCZsj=_cge_2S7m!w_b|2vD16H}i&xpX zPaK@-@0~n&;#mLw(NS&i%v<|s_+TC6XsJ8&;{53R2M?W|IC=8Qkq+2~*4M&`J8G%u zc;D%luId<_d~tr@s`**%^b@_)$B*rs>z(|l_ubyfsnfi_2J)_|8UNAcGY=haaaF^S zSnd5x6~tez8hfF8n0of+@5_U&ZWo~y z5pSgk@n^+duZ_R|di#lUvx8@^sR}^SrGdd04$nV7H}%f6JTp;M?}xNjf6s&cFV21t zn*a0A{9EsyI(72-ukex&&RpmF6gU1r{WoWJ&wMz(<(Z-0`PZJxj-8!)Zt(1!$K{0_ zqPNcDfw;!AetIxaQ4YJGE`R!&qlqp1kHnrldc4XFDVMs(UU=lxQ19$?`}8yIq{2mW zxqkIV`@wzV1M?k69iB2;M48CV4SAe&*BOh&TwVco>jW2fZhmdR0sEIY4jnn%(J?Vo zR>nI$9?`+x|0lO`fCP{L5Yz5kE?|H|)Y zY#|ar0!RP}AOR$R1dsp{Kmter34F2%u=oGTJ@oni_sRdrDKbgkBq!h None: - self.name = None - self._populate_from_kwargs(kwargs) + column_type: sqlalchemy.Column + constraints: List = [] - def _populate_from_kwargs(self, kwargs: Dict) -> None: - self.primary_key = kwargs.pop("primary_key", False) - self.autoincrement = kwargs.pop( - "autoincrement", self.primary_key and self.__type__ == int - ) + primary_key: bool + autoincrement: bool + nullable: bool + index: bool + unique: bool + pydantic_only: bool - self.nullable = kwargs.pop("nullable", not self.primary_key) - self.default = kwargs.pop("default", None) - self.server_default = kwargs.pop("server_default", None) + default: Any + server_default: Any - self.index = kwargs.pop("index", None) - self.unique = kwargs.pop("unique", None) - - self.pydantic_only = kwargs.pop("pydantic_only", False) - if self.pydantic_only and self.primary_key: - raise ModelDefinitionError("Primary key column cannot be pydantic only.") - - @property - def is_required(self) -> bool: + @classmethod + def is_required(cls) -> bool: return ( - not self.nullable and not self.has_default and not self.is_auto_primary_key + not cls.nullable and not cls.has_default() and not cls.is_auto_primary_key() ) - @property - def default_value(self) -> Any: - default = self.default - return default() if callable(default) else default + @classmethod + def default_value(cls): + if cls.is_auto_primary_key(): + return Field(default=None) + if cls.has_default(): + default = cls.default if cls.default is not None else cls.server_default + if callable(default): + return Field(default_factory=default) + else: + return Field(default=default) + return None - @property - def has_default(self) -> bool: - return self.default is not None or self.server_default is not None + @classmethod + def has_default(cls): + return cls.default is not None or cls.server_default is not None - @property - def is_auto_primary_key(self) -> bool: - if self.primary_key: - return self.autoincrement + @classmethod + def is_auto_primary_key(cls) -> bool: + if cls.primary_key: + return cls.autoincrement return False - def get_column(self, name: str = None) -> sqlalchemy.Column: - self.name = name - constraints = self.get_constraints() + @classmethod + def get_column(cls, name: str) -> sqlalchemy.Column: return sqlalchemy.Column( - self.name, - self.get_column_type(), - *constraints, - primary_key=self.primary_key, - autoincrement=self.autoincrement, - nullable=self.nullable, - index=self.index, - unique=self.unique, - default=self.default, - server_default=self.server_default, + name, + cls.column_type, + *cls.constraints, + primary_key=cls.primary_key, + nullable=cls.nullable and not cls.primary_key, + index=cls.index, + unique=cls.unique, + default=cls.default, + server_default=cls.server_default, ) - def get_column_type(self) -> sqlalchemy.types.TypeEngine: - raise NotImplementedError() # pragma: no cover - - def get_constraints(self) -> Optional[List]: - return [] - - def expand_relationship(self, value: Any, child: "Model") -> Any: + @classmethod + def expand_relationship(cls, value: Any, child: "Model") -> Any: return value - - def __repr__(self): # pragma no cover - return str(self.__dict__) diff --git a/ormar/fields/foreign_key.py b/ormar/fields/foreign_key.py index 77c2da7..87a4f4d 100644 --- a/ormar/fields/foreign_key.py +++ b/ormar/fields/foreign_key.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, TYPE_CHECKING, Type, Union +from typing import Any, List, Optional, TYPE_CHECKING, Type, Union, Callable import sqlalchemy from pydantic import BaseModel @@ -13,87 +13,115 @@ if TYPE_CHECKING: # pragma no cover def create_dummy_instance(fk: Type["Model"], pk: Any = None) -> "Model": init_dict = { - **{fk.__pkname__: pk or -1}, + **{fk.Meta.pkname: pk or -1, + '__pk_only__': True}, **{ k: create_dummy_instance(v.to) - for k, v in fk.__model_fields__.items() - if isinstance(v, ForeignKey) and not v.nullable and not v.virtual + for k, v in fk.Meta.model_fields.items() + if isinstance(v, ForeignKeyField) and not v.nullable and not v.virtual }, } return fk(**init_dict) -class ForeignKey(BaseField): - def __init__( - self, - to: Type["Model"], - name: str = None, - related_name: str = None, - nullable: bool = True, - virtual: bool = False, - ) -> None: - super().__init__(nullable=nullable, name=name) - self.virtual = virtual - self.related_name = related_name - self.to = to +def ForeignKey(to, *, name: str = None, unique: bool = False, nullable: bool = True, + related_name: str = None, + virtual: bool = False, + ) -> Type[object]: + fk_string = to.Meta.tablename + "." + to.Meta.pkname + to_field = to.__fields__[to.Meta.pkname] + namespace = dict( + to=to, + name=name, + nullable=nullable, + constraints=[sqlalchemy.schema.ForeignKey(fk_string)], + unique=unique, + column_type=to_field.type_.column_type, + related_name=related_name, + virtual=virtual, + primary_key=False, + index=False, + pydantic_only=False, + default=None, + server_default=None + ) + + return type("ForeignKey", (ForeignKeyField, BaseField), namespace) + + +class ForeignKeyField(BaseField): + to: Type["Model"] + related_name: str + virtual: bool + + @classmethod + def __get_validators__(cls) -> Callable: + yield cls.validate + + @classmethod + def validate(cls, v: Any) -> Any: + return v @property def __type__(self) -> Type[BaseModel]: return self.to.__pydantic_model__ - def get_constraints(self) -> List[sqlalchemy.schema.ForeignKey]: - fk_string = self.to.__tablename__ + "." + self.to.__pkname__ - return [sqlalchemy.schema.ForeignKey(fk_string)] - - def get_column_type(self) -> sqlalchemy.Column: - to_column = self.to.__model_fields__[self.to.__pkname__] - return to_column.get_column_type() + @classmethod + def get_column_type(cls) -> sqlalchemy.Column: + to_column = cls.to.Meta.model_fields[cls.to.Meta.pkname] + return to_column.column_type + @classmethod def _extract_model_from_sequence( - self, value: List, child: "Model" + cls, value: List, child: "Model" ) -> Union["Model", List["Model"]]: - return [self.expand_relationship(val, child) for val in value] + return [cls.expand_relationship(val, child) for val in value] - def _register_existing_model(self, value: "Model", child: "Model") -> "Model": - self.register_relation(value, child) + @classmethod + def _register_existing_model(cls, value: "Model", child: "Model") -> "Model": + cls.register_relation(value, child) return value - def _construct_model_from_dict(self, value: dict, child: "Model") -> "Model": - model = self.to(**value) - self.register_relation(model, child) + @classmethod + def _construct_model_from_dict(cls, value: dict, child: "Model") -> "Model": + model = cls.to(**value) + cls.register_relation(model, child) return model - def _construct_model_from_pk(self, value: Any, child: "Model") -> "Model": - if not isinstance(value, self.to.pk_type()): + @classmethod + def _construct_model_from_pk(cls, value: Any, child: "Model") -> "Model": + if not isinstance(value, cls.to.pk_type()): raise RelationshipInstanceError( - f"Relationship error - ForeignKey {self.to.__name__} " - f"is of type {self.to.pk_type()} " + f"Relationship error - ForeignKey {cls.to.__name__} " + f"is of type {cls.to.pk_type()} " f"while {type(value)} passed as a parameter." ) - model = create_dummy_instance(fk=self.to, pk=value) - self.register_relation(model, child) + model = create_dummy_instance(fk=cls.to, pk=value) + cls.register_relation(model, child) return model - def register_relation(self, model: "Model", child: "Model") -> None: - child_model_name = self.related_name or child.get_name() - model._orm_relationship_manager.add_relation( - model, child, child_model_name, virtual=self.virtual + @classmethod + def register_relation(cls, model: "Model", child: "Model") -> None: + child_model_name = cls.related_name or child.get_name() + model.Meta._orm_relationship_manager.add_relation( + model, child, child_model_name, virtual=cls.virtual ) + @classmethod def expand_relationship( - self, value: Any, child: "Model" + cls, value: Any, child: "Model" ) -> Optional[Union["Model", List["Model"]]]: if value is None: return None constructors = { - f"{self.to.__name__}": self._register_existing_model, - "dict": self._construct_model_from_dict, - "list": self._extract_model_from_sequence, + f"{cls.to.__name__}": cls._register_existing_model, + "dict": cls._construct_model_from_dict, + "list": cls._extract_model_from_sequence, } model = constructors.get( - value.__class__.__name__, self._construct_model_from_pk + value.__class__.__name__, cls._construct_model_from_pk )(value, child) return model diff --git a/ormar/fields/model_fields.py b/ormar/fields/model_fields.py index 4f9be11..30946bf 100644 --- a/ormar/fields/model_fields.py +++ b/ormar/fields/model_fields.py @@ -1,87 +1,373 @@ import datetime import decimal +import re +from typing import Type, Any, Optional +import pydantic import sqlalchemy from pydantic import Json +from ormar import ModelDefinitionError from ormar.fields.base import BaseField # noqa I101 -from ormar.fields.decorators import RequiredParams -@RequiredParams("length") -class String(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.String(self.length) +def is_field_nullable(nullable: Optional[bool], default: Any, server_default: Any) -> bool: + if nullable is None: + return default is not None or server_default is not None + return False -class Integer(BaseField): - __type__ = int +def String( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + if max_length is None or max_length <= 0: + raise ModelDefinitionError(f'Parameter max_length is required for field String') - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Integer() + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + regex=regex and re.compile(regex), + column_type=sqlalchemy.String(length=max_length), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + + return type("String", (pydantic.ConstrainedStr, BaseField), namespace) -class Text(BaseField): - __type__ = str - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Text() +def Integer( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Integer(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("Integer", (pydantic.ConstrainedInt, BaseField), namespace) -class Float(BaseField): - __type__ = float +def Text( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + allow_blank: bool = False, + strip_whitespace: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[str]: + namespace = dict( + __type__=str, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + allow_blank=allow_blank, + strip_whitespace=strip_whitespace, + column_type=sqlalchemy.Text(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Float() + return type("Text", (pydantic.ConstrainedStr, BaseField), namespace) -class Boolean(BaseField): - __type__ = bool - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Boolean() +def Float( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=float, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.Float(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Float", (pydantic.ConstrainedFloat, BaseField), namespace) -class DateTime(BaseField): - __type__ = datetime.datetime - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DateTime() +def Boolean( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[bool]: + namespace = dict( + __type__=bool, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Boolean(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Boolean", (int, BaseField), namespace) -class Date(BaseField): - __type__ = datetime.date - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Date() +def DateTime( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.datetime]: + namespace = dict( + __type__=datetime.datetime, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.DateTime(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("DateTime", (datetime.datetime, BaseField), namespace) -class Time(BaseField): - __type__ = datetime.time - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.Time() +def Date( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.date]: + namespace = dict( + __type__=datetime.date, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Date(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Date", (datetime.date, BaseField), namespace) -class JSON(BaseField): - __type__ = Json - - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.JSON() +def Time( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[datetime.time]: + namespace = dict( + __type__=datetime.time, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.Time(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Time", (datetime.time, BaseField), namespace) -class BigInteger(BaseField): - __type__ = int +def JSON( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[Json]: + namespace = dict( + __type__=pydantic.Json, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + column_type=sqlalchemy.JSON(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.BigInteger() + return type("JSON", (pydantic.Json, BaseField), namespace) -@RequiredParams("length", "precision") -class Decimal(BaseField): - __type__ = decimal.Decimal +def BigInteger( + *, + name: str = None, + primary_key: bool = False, + autoincrement: bool = None, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: int = None, + maximum: int = None, + multiple_of: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +) -> Type[int]: + namespace = dict( + __type__=int, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.BigInteger(), + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=autoincrement if autoincrement is not None else primary_key + ) + return type("BigInteger", (pydantic.ConstrainedInt, BaseField), namespace) - def get_column_type(self) -> sqlalchemy.Column: - return sqlalchemy.DECIMAL(self.length, self.precision) + +def Decimal( + *, + name: str = None, + primary_key: bool = False, + nullable: bool = None, + index: bool = False, + unique: bool = False, + minimum: float = None, + maximum: float = None, + multiple_of: int = None, + precision: int = None, + scale: int = None, + max_digits: int = None, + decimal_places: int = None, + pydantic_only: bool = False, + default: Any = None, + server_default: Any = None +): + if precision is None or precision < 0 or scale is None or scale < 0: + raise ModelDefinitionError(f'Parameters scale and precision are required for field Decimal') + + namespace = dict( + __type__=decimal.Decimal, + name=name, + primary_key=primary_key, + nullable=is_field_nullable(nullable, default, server_default), + index=index, + unique=unique, + ge=minimum, + le=maximum, + multiple_of=multiple_of, + column_type=sqlalchemy.types.DECIMAL(precision=precision, scale=scale), + precision=precision, + scale=scale, + max_digits=max_digits, + decimal_places=decimal_places, + pydantic_only=pydantic_only, + default=default, + server_default=server_default, + autoincrement=False + ) + return type("Decimal", (pydantic.ConstrainedDecimal, BaseField), namespace) diff --git a/ormar/models/fakepydantic.py b/ormar/models/fakepydantic.py index d5109b4..e31f3f7 100644 --- a/ormar/models/fakepydantic.py +++ b/ormar/models/fakepydantic.py @@ -11,7 +11,7 @@ from typing import ( TYPE_CHECKING, Type, TypeVar, - Union, + Union, AbstractSet, Mapping, ) import databases @@ -20,19 +20,27 @@ import sqlalchemy from pydantic import BaseModel import ormar # noqa I100 +from ormar import ForeignKey from ormar.fields import BaseField -from ormar.models.metaclass import ModelMetaclass +from ormar.fields.foreign_key import ForeignKeyField +from ormar.models.metaclass import ModelMetaclass, ModelMeta from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover from ormar.models.model import Model + IntStr = Union[int, str] + DictStrAny = Dict[str, Any] + AbstractSetIntStr = AbstractSet[IntStr] + MappingIntStrAny = Mapping[IntStr, Any] -class FakePydantic(list, metaclass=ModelMetaclass): + +class FakePydantic(pydantic.BaseModel, metaclass=ModelMetaclass): # FakePydantic inherits from list in order to be treated as # request.Body parameter in fastapi routes, # inheriting from pydantic.BaseModel causes metaclass conflicts - __abstract__ = True + __slots__ = ('_orm_id', '_orm_saved') + if TYPE_CHECKING: # pragma no cover __model_fields__: Dict[str, TypeVar[BaseField]] __table__: sqlalchemy.Table @@ -43,62 +51,88 @@ class FakePydantic(list, metaclass=ModelMetaclass): __metadata__: sqlalchemy.MetaData __database__: databases.Database _orm_relationship_manager: RelationshipManager + Meta: ModelMeta + # noinspection PyMissingConstructor def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._orm_id: str = uuid.uuid4().hex - self._orm_saved: bool = False - self.values: Optional[BaseModel] = None + object.__setattr__(self, "_orm_id", uuid.uuid4().hex) + object.__setattr__(self, "_orm_saved", False) + + pk_only = kwargs.pop("__pk_only__", False) if "pk" in kwargs: - kwargs[self.__pkname__] = kwargs.pop("pk") + kwargs[self.Meta.pkname] = kwargs.pop("pk") kwargs = { - k: self.__model_fields__[k].expand_relationship(v, self) + k: self.Meta.model_fields[k].expand_relationship(v, self) for k, v in kwargs.items() } - self.values = self.__pydantic_model__(**kwargs) + + values, fields_set, validation_error = pydantic.validate_model( + self, kwargs + ) + if validation_error and not pk_only: + raise validation_error + + object.__setattr__(self, '__dict__', values) + object.__setattr__(self, '__fields_set__', fields_set) + + # super().__init__(**kwargs) + # self.values = self.__pydantic_model__(**kwargs) def __del__(self) -> None: - self._orm_relationship_manager.deregister(self) + self.Meta._orm_relationship_manager.deregister(self) - def __setattr__(self, key: str, value: Any) -> None: - if key in self.__fields__: - value = self._convert_json(key, value, op="dumps") - value = self.__model_fields__[key].expand_relationship(value, self) + def __setattr__(self, name, value): + if name in self.__slots__: + object.__setattr__(self, name, value) + elif name == 'pk': + object.__setattr__(self, self.Meta.pkname, value) + relation_key = self.get_name(title=True) + "_" + name + if self.Meta._orm_relationship_manager.contains(relation_key, self): + self.Meta.model_fields[name].expand_relationship(value, self) + return + super().__setattr__(name, value) - relation_key = self.get_name(title=True) + "_" + key - if not self._orm_relationship_manager.contains(relation_key, self): - setattr(self.values, key, value) - else: - super().__setattr__(key, value) + def __getattr__(self, item): + relation_key = self.get_name(title=True) + "_" + item + if self.Meta._orm_relationship_manager.contains(relation_key, self): + return self.Meta._orm_relationship_manager.get(relation_key, self) - def __getattribute__(self, key: str) -> Any: - if key != "__fields__" and key in self.__fields__: - relation_key = self.get_name(title=True) + "_" + key - if self._orm_relationship_manager.contains(relation_key, self): - return self._orm_relationship_manager.get(relation_key, self) + # def __setattr__(self, key: str, value: Any) -> None: + # if key in ('_orm_id', '_orm_relationship_manager', '_orm_saved', 'objects', '__model_fields__'): + # return setattr(self, key, value) + # # elif key in self._extract_related_names(): + # # value = self._convert_json(key, value, op="dumps") + # # value = self.Meta.model_fields[key].expand_relationship(value, self) + # # relation_key = self.get_name(title=True) + "_" + key + # # if not self.Meta._orm_relationship_manager.contains(relation_key, self): + # # setattr(self.values, key, value) + # else: + # super().__setattr__(key, value) - item = getattr(self.values, key, None) - item = self._convert_json(key, item, op="loads") - return item - return super().__getattribute__(key) - - def __eq__(self, other: "Model") -> bool: - return self.values.dict() == other.values.dict() + # def __getattribute__(self, key: str) -> Any: + # if key != 'Meta' and key in self.Meta.model_fields: + # relation_key = self.get_name(title=True) + "_" + key + # if self.Meta._orm_relationship_manager.contains(relation_key, self): + # return self.Meta._orm_relationship_manager.get(relation_key, self) + # item = getattr(self.__fields__, key, None) + # item = self._convert_json(key, item, op="loads") + # return item + # return super().__getattribute__(key) def __same__(self, other: "Model") -> bool: if self.__class__ != other.__class__: # pragma no cover return False return self._orm_id == other._orm_id or ( - self.values is not None and other.values is not None and self.pk == other.pk + self.__dict__ is not None and other.__dict__ is not None and self.pk == other.pk ) - def __repr__(self) -> str: # pragma no cover - return self.values.__repr__() + # def __repr__(self) -> str: # pragma no cover + # return self.values.__repr__() - @classmethod - def __get_validators__(cls) -> Callable: # pragma no cover - yield cls.__pydantic_model__.validate + # @classmethod + # def __get_validators__(cls) -> Callable: # pragma no cover + # yield cls.__pydantic_model__.validate @classmethod def get_name(cls, title: bool = False, lower: bool = True) -> str: @@ -109,25 +143,57 @@ class FakePydantic(list, metaclass=ModelMetaclass): name = name.title() return name + @property + def pk(self) -> Any: + return getattr(self, self.Meta.pkname) + + @pk.setter + def pk(self, value: Any) -> None: + setattr(self, self.Meta.pkname, value) + @property def pk_column(self) -> sqlalchemy.Column: - return self.__table__.primary_key.columns.values()[0] + return self.Meta.table.primary_key.columns.values()[0] @classmethod def pk_type(cls) -> Any: - return cls.__model_fields__[cls.__pkname__].__type__ + return cls.Meta.model_fields[cls.Meta.pkname].__type__ - def dict(self, nested=False) -> Dict: # noqa: A003 - dict_instance = self.values.dict() + def dict( + self, + *, + include: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + exclude: Union['AbstractSetIntStr', 'MappingIntStrAny'] = None, + by_alias: bool = False, + skip_defaults: bool = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + nested: bool = False + ) -> 'DictStrAny': # noqa: A003 + print('callin super', self.__class__) + print('to exclude', self._exclude_related_names_not_required(nested)) + dict_instance = super().dict(include=include, + exclude=self._exclude_related_names_not_required(nested), + by_alias=by_alias, + skip_defaults=skip_defaults, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none) + print('after super') for field in self._extract_related_names(): + print(self.__class__, field, nested) nested_model = getattr(self, field) - if self.__model_fields__[field].virtual and nested: + + if self.Meta.model_fields[field].virtual and nested: continue if isinstance(nested_model, list) and not isinstance( - nested_model, ormar.Model + nested_model, ormar.Model ): + print('nested list') dict_instance[field] = [x.dict(nested=True) for x in nested_model] else: + print('instance') dict_instance[field] = ( nested_model.dict(nested=True) if nested_model is not None else {} ) @@ -155,7 +221,7 @@ class FakePydantic(list, metaclass=ModelMetaclass): return value def _is_conversion_to_json_needed(self, column_name: str) -> bool: - return self.__model_fields__.get(column_name).__type__ == pydantic.Json + return self.Meta.model_fields.get(column_name).__type__ == pydantic.Json def _extract_own_model_fields(self) -> Dict: related_names = self._extract_related_names() @@ -165,22 +231,32 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def _extract_related_names(cls) -> Set: related_names = set() - for name, field in cls.__fields__.items(): - if inspect.isclass(field.type_) and issubclass( - field.type_, pydantic.BaseModel + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass( + field, ForeignKeyField ): related_names.add(name) return related_names + @classmethod + def _exclude_related_names_not_required(cls, nested:bool=False) -> Set: + if nested: + return cls._extract_related_names() + related_names = set() + for name, field in cls.Meta.model_fields.items(): + if inspect.isclass(field) and issubclass(field, ForeignKeyField) and field.nullable: + related_names.add(name) + return related_names + def _extract_model_db_fields(self) -> Dict: self_fields = self._extract_own_model_fields() self_fields = { - k: v for k, v in self_fields.items() if k in self.__table__.columns + k: v for k, v in self_fields.items() if k in self.Meta.table.columns } for field in self._extract_related_names(): if getattr(self, field) is not None: self_fields[field] = getattr( - getattr(self, field), self.__model_fields__[field].to.__pkname__ + getattr(self, field), self.Meta.model_fields[field].to.Meta.pkname ) return self_fields @@ -196,9 +272,9 @@ class FakePydantic(list, metaclass=ModelMetaclass): @classmethod def merge_two_instances(cls, one: "Model", other: "Model") -> "Model": - for field in one.__model_fields__.keys(): + for field in one.Meta.model_fields.keys(): if isinstance(getattr(one, field), list) and not isinstance( - getattr(one, field), ormar.Model + getattr(one, field), ormar.Model ): setattr(other, field, getattr(one, field) + getattr(other, field)) elif isinstance(getattr(one, field), ormar.Model): diff --git a/ormar/models/metaclass.py b/ormar/models/metaclass.py index 7ea2fe7..76088cc 100644 --- a/ormar/models/metaclass.py +++ b/ormar/models/metaclass.py @@ -1,12 +1,15 @@ -import copy -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, Union +import databases +import pydantic import sqlalchemy -from pydantic import BaseConfig, create_model -from pydantic.fields import ModelField +from pydantic import BaseConfig, create_model, Extra +from pydantic.fields import ModelField, FieldInfo from ormar import ForeignKey, ModelDefinitionError # noqa I100 from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField +from ormar.queryset import QuerySet from ormar.relations import RelationshipManager if TYPE_CHECKING: # pragma no cover @@ -15,6 +18,17 @@ if TYPE_CHECKING: # pragma no cover relationship_manager = RelationshipManager() +class ModelMeta: + tablename: str + table: sqlalchemy.Table + metadata: sqlalchemy.MetaData + database: databases.Database + columns: List[sqlalchemy.Column] + pkname: str + model_fields: Dict[str, Union[BaseField, ForeignKey]] + _orm_relationship_manager: RelationshipManager + + def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple]: pydantic_fields = { field_name: ( @@ -29,9 +43,9 @@ def parse_pydantic_field_from_model_fields(object_dict: dict) -> Dict[str, Tuple def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> None: child_relation_name = ( - field.to.get_name(title=True) - + "_" - + (field.related_name or (name.lower() + "s")) + field.to.get_name(title=True) + + "_" + + (field.related_name or (name.lower() + "s")) ) reverse_name = child_relation_name relation_name = name.lower().title() + "_" + field.to.get_name() @@ -41,104 +55,125 @@ def register_relation_on_build(table_name: str, field: ForeignKey, name: str) -> def expand_reverse_relationships(model: Type["Model"]) -> None: - for model_field in model.__model_fields__.values(): - if isinstance(model_field, ForeignKey): + for model_field in model.Meta.model_fields.values(): + if issubclass(model_field, ForeignKeyField): child_model_name = model_field.related_name or model.get_name() + "s" parent_model = model_field.to child = model if ( - child_model_name not in parent_model.__fields__ - and child.get_name() not in parent_model.__fields__ + child_model_name not in parent_model.__fields__ + and child.get_name() not in parent_model.__fields__ ): register_reverse_model_fields(parent_model, child, child_model_name) def register_reverse_model_fields( - model: Type["Model"], child: Type["Model"], child_model_name: str + model: Type["Model"], child: Type["Model"], child_model_name: str ) -> None: - model.__fields__[child_model_name] = ModelField( - name=child_model_name, - type_=Optional[child.__pydantic_model__], - model_config=child.__pydantic_model__.__config__, - class_validators=child.__pydantic_model__.__validators__, - ) - model.__model_fields__[child_model_name] = ForeignKey( + # model.__fields__[child_model_name] = ModelField( + # name=child_model_name, + # type_=Optional[Union[List[child], child]], + # model_config=child.__config__, + # class_validators=child.__validators__, + # ) + model.Meta.model_fields[child_model_name] = ForeignKey( child, name=child_model_name, virtual=True ) def sqlalchemy_columns_from_model_fields( - name: str, object_dict: Dict, table_name: str + name: str, object_dict: Dict, table_name: str ) -> Tuple[Optional[str], List[sqlalchemy.Column], Dict[str, BaseField]]: columns = [] pkname = None model_fields = { field_name: field - for field_name, field in object_dict.items() - if isinstance(field, BaseField) + for field_name, field in object_dict['__annotations__'].items() + if issubclass(field, BaseField) } for field_name, field in model_fields.items(): if field.primary_key: if pkname is not None: raise ModelDefinitionError("Only one primary key column is allowed.") + if field.pydantic_only: + raise ModelDefinitionError('Primary key column cannot be pydantic only') pkname = field_name if not field.pydantic_only: columns.append(field.get_column(field_name)) - if isinstance(field, ForeignKey): + if issubclass(field, ForeignKeyField): register_relation_on_build(table_name, field, name) return pkname, columns, model_fields +def populate_pydantic_default_values(attrs: Dict) -> Dict: + for field, type_ in attrs['__annotations__'].items(): + if issubclass(type_, BaseField): + if type_.name is None: + type_.name = field + def_value = type_.default_value() + curr_def_value = attrs.get(field, 'NONE') + if curr_def_value == 'NONE' and isinstance(def_value, FieldInfo): + attrs[field] = def_value + elif curr_def_value == 'NONE' and type_.nullable: + attrs[field] = FieldInfo(default=None) + return attrs + + def get_pydantic_base_orm_config() -> Type[BaseConfig]: class Config(BaseConfig): orm_mode = True + arbitrary_types_allowed = True + # extra = Extra.allow return Config -class ModelMetaclass(type): +class ModelMetaclass(pydantic.main.ModelMetaclass): def __new__(mcs: type, name: str, bases: Any, attrs: dict) -> type: + + attrs['Config'] = get_pydantic_base_orm_config() new_model = super().__new__( # type: ignore mcs, name, bases, attrs ) - if attrs.get("__abstract__"): - return new_model + if hasattr(new_model, 'Meta'): - tablename = attrs.get("__tablename__", name.lower() + "s") - attrs["__tablename__"] = tablename - metadata = attrs["__metadata__"] + if attrs.get("__abstract__"): + return new_model - # sqlalchemy table creation - pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( - name, attrs, tablename - ) - attrs["__table__"] = sqlalchemy.Table(tablename, metadata, *columns) - attrs["__columns__"] = columns - attrs["__pkname__"] = pkname + attrs = populate_pydantic_default_values(attrs) - if not pkname: - raise ModelDefinitionError("Table has to have a primary key.") + tablename = name.lower() + "s" + new_model.Meta.tablename = new_model.Meta.tablename or tablename - # pydantic model creation - pydantic_fields = parse_pydantic_field_from_model_fields(attrs) - pydantic_model = create_model( - name, __config__=get_pydantic_base_orm_config(), **pydantic_fields - ) - attrs["__pydantic_fields__"] = pydantic_fields - attrs["__pydantic_model__"] = pydantic_model - attrs["__fields__"] = copy.deepcopy(pydantic_model.__fields__) - attrs["__signature__"] = copy.deepcopy(pydantic_model.__signature__) - attrs["__annotations__"] = copy.deepcopy(pydantic_model.__annotations__) + # sqlalchemy table creation + pkname, columns, model_fields = sqlalchemy_columns_from_model_fields( + name, attrs, new_model.Meta.tablename + ) + new_model.Meta.table = sqlalchemy.Table(new_model.Meta.tablename, new_model.Meta.metadata, *columns) + new_model.Meta.columns = columns + new_model.Meta.pkname = pkname - attrs["__model_fields__"] = model_fields - attrs["_orm_relationship_manager"] = relationship_manager + if not pkname: + breakpoint() + raise ModelDefinitionError("Table has to have a primary key.") - new_model = super().__new__( # type: ignore - mcs, name, bases, attrs - ) + # pydantic model creation + new_model.Meta.pydantic_fields = parse_pydantic_field_from_model_fields(attrs) + new_model.Meta.pydantic_model = create_model( + name, __config__=get_pydantic_base_orm_config(), **new_model.Meta.pydantic_fields + ) - expand_reverse_relationships(new_model) + new_model.Meta.model_fields = model_fields + new_model = super().__new__( # type: ignore + mcs, name, bases, attrs + ) + expand_reverse_relationships(new_model) + + new_model.Meta._orm_relationship_manager = relationship_manager + new_model.objects = QuerySet(new_model) + + # breakpoint() return new_model diff --git a/ormar/models/model.py b/ormar/models/model.py index b16d3e7..4651aa1 100644 --- a/ormar/models/model.py +++ b/ormar/models/model.py @@ -7,39 +7,39 @@ from ormar.models import FakePydantic # noqa I100 class Model(FakePydantic): - __abstract__ = True + __abstract__ = False - objects = ormar.queryset.QuerySet() + # objects = ormar.queryset.QuerySet() @classmethod def from_row( - cls, - row: sqlalchemy.engine.ResultProxy, - select_related: List = None, - previous_table: str = None, + cls, + row: sqlalchemy.engine.ResultProxy, + select_related: List = None, + previous_table: str = None, ) -> "Model": item = {} select_related = select_related or [] - table_prefix = cls._orm_relationship_manager.resolve_relation_join( - previous_table, cls.__table__.name + table_prefix = cls.Meta._orm_relationship_manager.resolve_relation_join( + previous_table, cls.Meta.table.name ) - previous_table = cls.__table__.name + previous_table = cls.Meta.table.name for related in select_related: if "__" in related: first_part, remainder = related.split("__", 1) - model_cls = cls.__model_fields__[first_part].to + model_cls = cls.Meta.model_fields[first_part].to child = model_cls.from_row( row, select_related=[remainder], previous_table=previous_table ) item[first_part] = child else: - model_cls = cls.__model_fields__[related].to + model_cls = cls.Meta.model_fields[related].to child = model_cls.from_row(row, previous_table=previous_table) item[related] = child - for column in cls.__table__.columns: + for column in cls.Meta.table.columns: if column.name not in item: item[column.name] = row[ f'{table_prefix + "_" if table_prefix else ""}{column.name}' @@ -47,22 +47,14 @@ class Model(FakePydantic): return cls(**item) - @property - def pk(self) -> str: - return getattr(self.values, self.__pkname__) - - @pk.setter - def pk(self, value: Any) -> None: - setattr(self.values, self.__pkname__, value) - async def save(self) -> "Model": self_fields = self._extract_model_db_fields() - if self.__model_fields__.get(self.__pkname__).autoincrement: - self_fields.pop(self.__pkname__, None) - expr = self.__table__.insert() + if self.Meta.model_fields.get(self.Meta.pkname).autoincrement: + self_fields.pop(self.Meta.pkname, None) + expr = self.Meta.table.insert() expr = expr.values(**self_fields) - item_id = await self.__database__.execute(expr) - self.pk = item_id + item_id = await self.Meta.database.execute(expr) + setattr(self, self.Meta.pkname, item_id) return self async def update(self, **kwargs: Any) -> int: @@ -71,23 +63,23 @@ class Model(FakePydantic): self.from_dict(new_values) self_fields = self._extract_model_db_fields() - self_fields.pop(self.__pkname__) + self_fields.pop(self.Meta.pkname) expr = ( - self.__table__.update() - .values(**self_fields) - .where(self.pk_column == getattr(self, self.__pkname__)) + self.Meta.table.update() + .values(**self_fields) + .where(self.pk_column == getattr(self, self.Meta.pkname)) ) - result = await self.__database__.execute(expr) + result = await self.Meta.database.execute(expr) return result async def delete(self) -> int: - expr = self.__table__.delete() - expr = expr.where(self.pk_column == (getattr(self, self.__pkname__))) - result = await self.__database__.execute(expr) + expr = self.Meta.table.delete() + expr = expr.where(self.pk_column == (getattr(self, self.Meta.pkname))) + result = await self.Meta.database.execute(expr) return result async def load(self) -> "Model": - expr = self.__table__.select().where(self.pk_column == self.pk) - row = await self.__database__.fetch_one(expr) + expr = self.Meta.table.select().where(self.pk_column == self.pk) + row = await self.Meta.database.fetch_one(expr) self.from_dict(dict(row)) return self diff --git a/ormar/queryset/clause.py b/ormar/queryset/clause.py index 3d5d14b..dc94e6f 100644 --- a/ormar/queryset/clause.py +++ b/ormar/queryset/clause.py @@ -32,7 +32,7 @@ class QueryClause: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table def filter( # noqa: A003 self, **kwargs: Any @@ -41,7 +41,7 @@ class QueryClause: select_related = list(self._select_related) if kwargs.get("pk"): - pk_name = self.model_cls.__pkname__ + pk_name = self.model_cls.Meta.pkname kwargs[pk_name] = kwargs.pop("pk") for key, value in kwargs.items(): @@ -65,8 +65,8 @@ class QueryClause: related_parts, select_related ) - table = model_cls.__table__ - column = model_cls.__table__.columns[field_name] + table = model_cls.Meta.table + column = model_cls.Meta.table.columns[field_name] else: op = "exact" @@ -106,12 +106,12 @@ class QueryClause: # Walk the relationships to the actual model class # against which the comparison is being made. - previous_table = model_cls.__tablename__ + previous_table = model_cls.Meta.tablename for part in related_parts: - current_table = model_cls.__model_fields__[part].to.__tablename__ - manager = model_cls._orm_relationship_manager + current_table = model_cls.Meta.model_fields[part].to.Meta.tablename + manager = model_cls.Meta._orm_relationship_manager table_prefix = manager.resolve_relation_join(previous_table, current_table) - model_cls = model_cls.__model_fields__[part].to + model_cls = model_cls.Meta.model_fields[part].to previous_table = current_table return select_related, table_prefix, model_cls @@ -128,7 +128,7 @@ class QueryClause: clause_text = str( clause.compile( - dialect=self.model_cls.__database__._backend._dialect, + dialect=self.model_cls.Meta.database._backend._dialect, compile_kwargs={"literal_binds": True}, ) ) diff --git a/ormar/queryset/query.py b/ormar/queryset/query.py index 202b249..798502a 100644 --- a/ormar/queryset/query.py +++ b/ormar/queryset/query.py @@ -4,8 +4,8 @@ import sqlalchemy from sqlalchemy import text import ormar # noqa I100 -from ormar import ForeignKey from ormar.fields import BaseField +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar import Model @@ -20,12 +20,12 @@ class JoinParameters(NamedTuple): class Query: def __init__( - self, - model_cls: Type["Model"], - filter_clauses: List, - select_related: List, - limit_count: int, - offset: int, + self, + model_cls: Type["Model"], + filter_clauses: List, + select_related: List, + limit_count: int, + offset: int, ) -> None: self.query_offset = offset @@ -34,7 +34,7 @@ class Query: self.filter_clauses = filter_clauses self.model_cls = model_cls - self.table = self.model_cls.__table__ + self.table = self.model_cls.Meta.table self.auto_related = [] self.used_aliases = [] @@ -46,16 +46,16 @@ class Query: def build_select_expression(self) -> Tuple[sqlalchemy.sql.select, List[str]]: self.columns = list(self.table.columns) - self.order_bys = [text(f"{self.table.name}.{self.model_cls.__pkname__}")] + self.order_bys = [text(f"{self.table.name}.{self.model_cls.Meta.pkname}")] self.select_from = self.table - for key in self.model_cls.__model_fields__: + for key in self.model_cls.Meta.model_fields: if ( - not self.model_cls.__model_fields__[key].nullable - and isinstance( - self.model_cls.__model_fields__[key], ormar.fields.ForeignKey, - ) - and key not in self._select_related + not self.model_cls.Meta.model_fields[key].nullable + and isinstance( + self.model_cls.Meta.model_fields[key], ForeignKeyField, + ) + and key not in self._select_related ): self._select_related = [key] + self._select_related @@ -79,7 +79,7 @@ class Query: expr = self._apply_expression_modifiers(expr) - # print(expr.compile(compile_kwargs={"literal_binds": True})) + print(expr.compile(compile_kwargs={"literal_binds": True})) self._reset_query_parameters() return expr, self._select_related @@ -97,12 +97,12 @@ class Query: @staticmethod def _field_is_a_foreign_key_and_no_circular_reference( - field: BaseField, field_name: str, rel_part: str + field: BaseField, field_name: str, rel_part: str ) -> bool: - return isinstance(field, ForeignKey) and field_name not in rel_part + return issubclass(field, ForeignKeyField) and field_name not in rel_part def _field_qualifies_to_deeper_search( - self, field: ForeignKey, parent_virtual: bool, nested: bool, rel_part: str + self, field: ForeignKeyField, parent_virtual: bool, nested: bool, rel_part: str ) -> bool: prev_part_of_related = "__".join(rel_part.split("__")[:-1]) partial_match = any( @@ -112,39 +112,39 @@ class Query: [x.startswith(rel_part) for x in (self.auto_related + self.already_checked)] ) return ( - (field.virtual and parent_virtual) - or (partial_match and not already_checked) - ) or not nested + (field.virtual and parent_virtual) + or (partial_match and not already_checked) + ) or not nested def on_clause( - self, previous_alias: str, alias: str, from_clause: str, to_clause: str, + self, previous_alias: str, alias: str, from_clause: str, to_clause: str, ) -> text: left_part = f"{alias}_{to_clause}" right_part = f"{previous_alias + '_' if previous_alias else ''}{from_clause}" return text(f"{left_part}={right_part}") def _build_join_parameters( - self, part: str, join_params: JoinParameters + self, part: str, join_params: JoinParameters ) -> JoinParameters: - model_cls = join_params.model_cls.__model_fields__[part].to - to_table = model_cls.__table__.name + model_cls = join_params.model_cls.Meta.model_fields[part].to + to_table = model_cls.Meta.table.name - alias = model_cls._orm_relationship_manager.resolve_relation_join( + alias = model_cls.Meta._orm_relationship_manager.resolve_relation_join( join_params.from_table, to_table ) if alias not in self.used_aliases: - if join_params.prev_model.__model_fields__[part].virtual: + if join_params.prev_model.Meta.model_fields[part].virtual: to_key = next( ( v - for k, v in model_cls.__model_fields__.items() - if isinstance(v, ForeignKey) and v.to == join_params.prev_model + for k, v in model_cls.Meta.model_fields.items() + if issubclass(v, ForeignKeyField) and v.to == join_params.prev_model ), None, ).name - from_key = model_cls.__pkname__ + from_key = model_cls.Meta.pkname else: - to_key = model_cls.__pkname__ + to_key = model_cls.Meta.pkname from_key = part on_clause = self.on_clause( @@ -157,8 +157,8 @@ class Query: self.select_from = sqlalchemy.sql.outerjoin( self.select_from, target_table, on_clause ) - self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.__pkname__}")) - self.columns.extend(self.prefixed_columns(alias, model_cls.__table__)) + self.order_bys.append(text(f"{alias}_{to_table}.{model_cls.Meta.pkname}")) + self.columns.extend(self.prefixed_columns(alias, model_cls.Meta.table)) self.used_aliases.append(alias) previous_alias = alias @@ -167,24 +167,28 @@ class Query: return JoinParameters(prev_model, previous_alias, from_table, model_cls) def _extract_auto_required_relations( - self, - prev_model: Type["Model"], - rel_part: str = "", - nested: bool = False, - parent_virtual: bool = False, + self, + prev_model: Type["Model"], + rel_part: str = "", + nested: bool = False, + parent_virtual: bool = False, ) -> None: - for field_name, field in prev_model.__model_fields__.items(): + for field_name, field in prev_model.Meta.model_fields.items(): if self._field_is_a_foreign_key_and_no_circular_reference( - field, field_name, rel_part + field, field_name, rel_part ): rel_part = field_name if not rel_part else rel_part + "__" + field_name if not field.nullable: + print('add', rel_part, field) if rel_part not in self._select_related: - self.auto_related.append("__".join(rel_part.split("__")[:-1])) + new_related = "__".join(rel_part.split("__")[:-1]) if len( + rel_part.split("__")) > 1 else rel_part + self.auto_related.append(new_related) rel_part = "" elif self._field_qualifies_to_deeper_search( - field, parent_virtual, nested, rel_part + field, parent_virtual, nested, rel_part ): + print('deeper', rel_part, field, field.to) self._extract_auto_required_relations( prev_model=field.to, rel_part=rel_part, @@ -204,7 +208,7 @@ class Query: self._select_related = new_joins + self.auto_related def _apply_expression_modifiers( - self, expr: sqlalchemy.sql.select + self, expr: sqlalchemy.sql.select ) -> sqlalchemy.sql.select: if self.filter_clauses: if len(self.filter_clauses) == 1: diff --git a/ormar/queryset/queryset.py b/ormar/queryset/queryset.py index 65192b8..4e5d85e 100644 --- a/ormar/queryset/queryset.py +++ b/ormar/queryset/queryset.py @@ -14,12 +14,12 @@ if TYPE_CHECKING: # pragma no cover class QuerySet: def __init__( - self, - model_cls: Type["Model"] = None, - filter_clauses: List = None, - select_related: List = None, - limit_count: int = None, - offset: int = None, + self, + model_cls: Type["Model"] = None, + filter_clauses: List = None, + select_related: List = None, + limit_count: int = None, + offset: int = None, ) -> None: self.model_cls = model_cls self.filter_clauses = [] if filter_clauses is None else filter_clauses @@ -33,11 +33,11 @@ class QuerySet: @property def database(self) -> databases.Database: - return self.model_cls.__database__ + return self.model_cls.Meta.database @property def table(self) -> sqlalchemy.Table: - return self.model_cls.__table__ + return self.model_cls.Meta.table def build_select_expression(self) -> sqlalchemy.sql.select: qry = Query( @@ -148,12 +148,12 @@ class QuerySet: new_kwargs = dict(**kwargs) # Remove primary key when None to prevent not null constraint in postgresql. - pkname = self.model_cls.__pkname__ - pk = self.model_cls.__model_fields__[pkname] + pkname = self.model_cls.Meta.pkname + pk = self.model_cls.Meta.model_fields[pkname] if ( - pkname in new_kwargs - and new_kwargs.get(pkname) is None - and (pk.nullable or pk.autoincrement) + pkname in new_kwargs + and new_kwargs.get(pkname) is None + and (pk.nullable or pk.autoincrement) ): del new_kwargs[pkname] @@ -163,11 +163,11 @@ class QuerySet: if isinstance(new_kwargs.get(field), ormar.Model): new_kwargs[field] = getattr( new_kwargs.get(field), - self.model_cls.__model_fields__[field].to.__pkname__, + self.model_cls.Meta.model_fields[field].to.Meta.pkname, ) else: new_kwargs[field] = new_kwargs.get(field).get( - self.model_cls.__model_fields__[field].to.__pkname__ + self.model_cls.Meta.model_fields[field].to.Meta.pkname ) # Build the insert expression. @@ -176,5 +176,6 @@ class QuerySet: # Execute the insert, and return a new model instance. instance = self.model_cls(**kwargs) - instance.pk = await self.database.execute(expr) + pk = await self.database.execute(expr) + setattr(instance, self.model_cls.Meta.pkname, pk) return instance diff --git a/ormar/relations.py b/ormar/relations.py index 45c3ddd..8a422d9 100644 --- a/ormar/relations.py +++ b/ormar/relations.py @@ -6,6 +6,7 @@ from typing import List, TYPE_CHECKING, Union from weakref import proxy from ormar import ForeignKey +from ormar.fields.foreign_key import ForeignKeyField if TYPE_CHECKING: # pragma no cover from ormar.models import FakePydantic, Model @@ -21,14 +22,14 @@ class RelationshipManager: self._aliases = dict() def add_relation_type( - self, relations_key: str, reverse_key: str, field: ForeignKey, table_name: str + self, relations_key: str, reverse_key: str, field: ForeignKeyField, table_name: str ) -> None: if relations_key not in self._relations: self._relations[relations_key] = {"type": "primary"} - self._aliases[f"{table_name}_{field.to.__tablename__}"] = get_table_alias() + self._aliases[f"{table_name}_{field.to.Meta.tablename}"] = get_table_alias() if reverse_key not in self._relations: self._relations[reverse_key] = {"type": "reverse"} - self._aliases[f"{field.to.__tablename__}_{table_name}"] = get_table_alias() + self._aliases[f"{field.to.Meta.tablename}_{table_name}"] = get_table_alias() def deregister(self, model: "FakePydantic") -> None: for rel_type in self._relations.keys(): diff --git a/tests/test_columns.py b/tests/test_columns.py index e75cb0a..c8c9d3b 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -16,17 +16,19 @@ def time(): class Example(ormar.Model): - __tablename__ = "example" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "example" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - created = ormar.DateTime(default=datetime.datetime.now) - created_day = ormar.Date(default=datetime.date.today) - created_time = ormar.Time(default=time) - description = ormar.Text(nullable=True) - value = ormar.Float(nullable=True) - data = ormar.JSON(default={}) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=200, default='aaa') + created: ormar.DateTime(default=datetime.datetime.now) + created_day: ormar.Date(default=datetime.date.today) + created_time: ormar.Time(default=time) + description: ormar.Text(nullable=True) + value: ormar.Float(nullable=True) + data: ormar.JSON(default={}) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_fastapi_usage.py b/tests/test_fastapi_usage.py index 25b1310..f7f2625 100644 --- a/tests/test_fastapi_usage.py +++ b/tests/test_fastapi_usage.py @@ -13,22 +13,24 @@ metadata = sqlalchemy.MetaData() class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @app.post("/items/", response_model=Item) diff --git a/tests/test_foreign_keys.py b/tests/test_foreign_keys.py index bffa783..17e4a52 100644 --- a/tests/test_foreign_keys.py +++ b/tests/test_foreign_keys.py @@ -1,6 +1,7 @@ import databases import pytest import sqlalchemy +from pydantic import ValidationError import ormar from ormar.exceptions import NoMatch, MultipleMatches, RelationshipInstanceError @@ -11,62 +12,68 @@ metadata = sqlalchemy.MetaData() class Album(ormar.Model): - __tablename__ = "album" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "albums" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Track(ormar.Model): - __tablename__ = "track" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "tracks" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album) - title = ormar.String(length=100) - position = ormar.Integer() + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album) + title: ormar.String(max_length=100) + position: ormar.Integer() class Cover(ormar.Model): - __tablename__ = "covers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "covers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - album = ormar.ForeignKey(Album, related_name="cover_pictures") - title = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + album: ormar.ForeignKey(Album, related_name="cover_pictures") + title: ormar.String(max_length=100) class Organisation(ormar.Model): - __tablename__ = "org" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "org" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - ident = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + ident: ormar.String(max_length=100) class Team(ormar.Model): - __tablename__ = "team" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teams" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - org = ormar.ForeignKey(Organisation) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + org: ormar.ForeignKey(Organisation) + name: ormar.String(max_length=100) class Member(ormar.Model): - __tablename__ = "member" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "members" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - team = ormar.ForeignKey(Team) - email = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + team: ormar.ForeignKey(Team) + email: ormar.String(max_length=100) @pytest.fixture(autouse=True, scope="module") @@ -171,8 +178,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name="Fantasies") - .all() + .filter(album__name="Fantasies") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -180,8 +187,8 @@ async def test_fk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(album__name__icontains="fan") - .all() + .filter(album__name__icontains="fan") + .all() ) assert len(tracks) == 3 for track in tracks: @@ -223,8 +230,8 @@ async def test_multiple_fk(): members = ( await Member.objects.select_related("team__org") - .filter(team__org__ident="ACME Ltd") - .all() + .filter(team__org__ident="ACME Ltd") + .all() ) assert len(members) == 4 for member in members: @@ -243,8 +250,8 @@ async def test_pk_filter(): tracks = ( await Track.objects.select_related("album") - .filter(position=2, album__name="Test") - .all() + .filter(position=2, album__name="Test") + .all() ) assert len(tracks) == 1 diff --git a/tests/test_model_definition.py b/tests/test_model_definition.py index f6dc722..adf20ad 100644 --- a/tests/test_model_definition.py +++ b/tests/test_model_definition.py @@ -1,4 +1,5 @@ import datetime +import decimal import pydantic import pytest @@ -12,19 +13,21 @@ metadata = sqlalchemy.MetaData() class ExampleModel(Model): - __tablename__ = "example" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) - test_text = fields.Text(default="") - test_bool = fields.Boolean(nullable=False) - test_float = fields.Float() - test_datetime = fields.DateTime(default=datetime.datetime.now) - test_date = fields.Date(default=datetime.date.today) - test_time = fields.Time(default=datetime.time) - test_json = fields.JSON(default={}) - test_bigint = fields.BigInteger(default=0) - test_decimal = fields.Decimal(length=10, precision=2) + class Meta: + tablename = "example" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) + test_text: fields.Text(default="") + test_bool: fields.Boolean(nullable=False) + test_float: fields.Float() = None + test_datetime: fields.DateTime(default=datetime.datetime.now) + test_date: fields.Date(default=datetime.date.today) + test_time: fields.Time(default=datetime.time) + test_json: fields.JSON(default={}) + test_bigint: fields.BigInteger(default=0) + test_decimal: fields.Decimal(scale=10, precision=2) fields_to_check = [ @@ -41,15 +44,17 @@ fields_to_check = [ class ExampleModel2(Model): - __tablename__ = "example2" - __metadata__ = metadata - test = fields.Integer(primary_key=True) - test_string = fields.String(length=250) + class Meta: + tablename = "example2" + metadata = metadata + + test: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250) @pytest.fixture() def example(): - return ExampleModel(pk=1, test_string="test", test_bool=True) + return ExampleModel(pk=1, test_string="test", test_bool=True, test_decimal=decimal.Decimal(3.5)) def test_not_nullable_field_is_required(): @@ -83,60 +88,65 @@ def test_primary_key_access_and_setting(example): def test_pydantic_model_is_created(example): - assert issubclass(example.values.__class__, pydantic.BaseModel) - assert all([field in example.values.__fields__ for field in fields_to_check]) - assert example.values.test == 1 + assert issubclass(example.__class__, pydantic.BaseModel) + assert all([field in example.__fields__ for field in fields_to_check]) + assert example.test == 1 def test_sqlalchemy_table_is_created(example): - assert issubclass(example.__table__.__class__, sqlalchemy.Table) - assert all([field in example.__table__.columns for field in fields_to_check]) + assert issubclass(example.Meta.table.__class__, sqlalchemy.Table) + assert all([field in example.Meta.table.columns for field in fields_to_check]) def test_no_pk_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - test_string = fields.String(length=250) + class Meta: + tablename = "example3" + metadata = metadata + + test_string: fields.String(max_length=250) def test_two_pks_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example3" - __metadata__ = metadata - id = fields.Integer(primary_key=True) - test_string = fields.String(length=250, primary_key=True) + class Meta: + tablename = "example3" + metadata = metadata + + id: fields.Integer(primary_key=True) + test_string: fields.String(max_length=250, primary_key=True) def test_setting_pk_column_as_pydantic_only_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Integer(primary_key=True, pydantic_only=True) + class Meta: + tablename = "example4" + metadata = metadata + + test: fields.Integer(primary_key=True, pydantic_only=True) def test_decimal_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.Decimal(primary_key=True) + class Meta: + tablename = "example5" + metadata = metadata + + test: fields.Decimal(primary_key=True) def test_string_error_in_model_definition(): with pytest.raises(ModelDefinitionError): - class ExampleModel2(Model): - __tablename__ = "example4" - __metadata__ = metadata - test = fields.String(primary_key=True) + class Meta: + tablename = "example6" + metadata = metadata + + test: fields.String(primary_key=True) def test_json_conversion_in_model(): diff --git a/tests/test_models.py b/tests/test_models.py index 69d421c..42c6b54 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,4 +1,5 @@ import databases +import pydantic import pytest import sqlalchemy @@ -11,23 +12,25 @@ metadata = sqlalchemy.MetaData() class User(ormar.Model): - __tablename__ = "users" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "users" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100, default='') class Product(ormar.Model): - __tablename__ = "product" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "product" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - rating = ormar.Integer(minimum=1, maximum=5) - in_stock = ormar.Boolean(default=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + rating: ormar.Integer(minimum=1, maximum=5) + in_stock: ormar.Boolean(default=False) @pytest.fixture(autouse=True, scope="module") @@ -39,12 +42,12 @@ def create_test_database(): def test_model_class(): - assert list(User.__model_fields__.keys()) == ["id", "name"] - assert isinstance(User.__model_fields__["id"], ormar.Integer) - assert User.__model_fields__["id"].primary_key is True - assert isinstance(User.__model_fields__["name"], ormar.String) - assert User.__model_fields__["name"].length == 100 - assert isinstance(User.__table__, sqlalchemy.Table) + assert list(User.Meta.model_fields.keys()) == ["id", "name"] + assert issubclass(User.Meta.model_fields["id"], pydantic.ConstrainedInt) + assert User.Meta.model_fields["id"].primary_key is True + assert issubclass(User.Meta.model_fields["name"], pydantic.ConstrainedStr) + assert User.Meta.model_fields["name"].max_length == 100 + assert isinstance(User.Meta.table, sqlalchemy.Table) def test_model_pk(): diff --git a/tests/test_more_reallife_fastapi.py b/tests/test_more_reallife_fastapi.py index d7670cc..53ff9d4 100644 --- a/tests/test_more_reallife_fastapi.py +++ b/tests/test_more_reallife_fastapi.py @@ -30,22 +30,24 @@ async def shutdown() -> None: class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Item(ormar.Model): - __tablename__ = "items" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "items" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(autouse=True, scope="module") diff --git a/tests/test_same_table_joins.py b/tests/test_same_table_joins.py index 60ddddc..166c943 100644 --- a/tests/test_same_table_joins.py +++ b/tests/test_same_table_joins.py @@ -12,53 +12,58 @@ metadata = sqlalchemy.MetaData() class Department(ormar.Model): - __tablename__ = "departments" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "departments" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True, autoincrement=False) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True, autoincrement=False) + name: ormar.String(max_length=100) class SchoolClass(ormar.Model): - __tablename__ = "schoolclasses" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "schoolclasses" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - department = ormar.ForeignKey(Department, nullable=False) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + department: ormar.ForeignKey(Department, nullable=False) class Category(ormar.Model): - __tablename__ = "categories" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "categories" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) class Student(ormar.Model): - __tablename__ = "students" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "students" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) class Teacher(ormar.Model): - __tablename__ = "teachers" - __metadata__ = metadata - __database__ = database + class Meta: + tablename = "teachers" + metadata = metadata + database = database - id = ormar.Integer(primary_key=True) - name = ormar.String(length=100) - schoolclass = ormar.ForeignKey(SchoolClass) - category = ormar.ForeignKey(Category, nullable=True) + id: ormar.Integer(primary_key=True) + name: ormar.String(max_length=100) + schoolclass: ormar.ForeignKey(SchoolClass) + category: ormar.ForeignKey(Category, nullable=True) @pytest.fixture(scope="module")