import datetime
from copy import deepcopy
from asyncpg.exceptions import InsufficientPrivilegeError
from asyncorm.exceptions import AsyncOrmModelError, AsyncOrmMultipleObjectsReturned, AsyncOrmQuerysetError
from asyncorm.manager.constants import LOOKUP_OPERATOR
from asyncorm.models.fields import CharField, ForeignKey, ManyToManyField, NumberField
class Queryset(object):
    """Chainable description of a database query bound to one model.

    The query is kept as a list of dicts (``self.query``): the first dict holds
    the main action (select, ordering, table name, ...) and ``filter`` appends
    ``_db__where`` dicts after it. The chain is turned into a statement by
    ``db_backend._construct_query`` when ``db_request`` is awaited.

    NOTE(review): ``_copy_me`` is called by ``queryset``/``all``/``none`` but is
    not defined in this class; presumably a subclass provides it -- confirm.
    """

    # Both are filled in by ``set_orm``.
    db_backend = None
    orm = None

    def __init__(self, model):
        """Bind the queryset to ``model`` and reset query, cursor and slice state."""
        self.model = model
        self.table_name = self.model.cls_tablename()
        self.select = "*"

        self.query = None

        self._cursor = None
        self._results = []

        # Window handed to the backend cursor when the queryset is iterated.
        self.forward = 0
        self.stop = None
        self.step = None

    def query_copy(self):
        """Return a deep copy of the current query chain (or of the basic one if unset)."""
        return deepcopy(self.query or self.basic_query)

    @property
    def basic_query(self):
        """Default one-element query chain: select everything from the model table."""
        return [
            {
                "action": "_db__select_all",
                "select": "*",
                "table_name": self.model.cls_tablename(),
                "ordering": self.model.ordering,
                "join": "",
            }
        ]

    @classmethod
    def set_orm(cls, orm):
        """Store ``orm`` and its ``db_backend`` on the class."""
        cls.orm = orm
        cls.db_backend = orm.db_backend

    def get_field_queries(self):
        """Build the creation query for each of the non fk or m2m fields."""
        return ", ".join(
            field.creation_query()
            for field in self.model.fields.values()
            if not isinstance(field, (ManyToManyField, ForeignKey))
        )

    def create_table_builder(self):
        """Return the request chain that creates the model table."""
        return [
            {
                "table_name": self.model.cls_tablename(),
                "action": "_db__create_table",
                "field_queries": self.get_field_queries(),
            }
        ]

    async def create_table(self):
        """Build the table without the m2m_fields and fks."""
        await self.db_request(self.create_table_builder())

    async def set_requirements(self):
        """Add to the database the table requirements if needed.

        :raises AsyncOrmModelError: when the backend reports insufficient privileges
        """
        try:
            for query in self.model.field_requirements:
                await self.db_backend.request(query)
        except InsufficientPrivilegeError as err:
            raise AsyncOrmModelError(
                "Not enough privileges to add the needed requirement in the database"
            ) from err

    def unique_together_builder(self):
        """Return the constraint request chain, or None when there is nothing to constrain."""
        unique_together = self.get_unique_together()
        if not unique_together:
            return None
        return [
            {
                "table_name": self.model.cls_tablename(),
                "action": "_db__constrain_table",
                "constrain": unique_together,
            }
        ]

    async def unique_together(self):
        """Build the unique together constraint."""
        db_request = self.unique_together_builder()
        if db_request:
            await self.db_request(db_request)

    def add_fk_field_builder(self, field):
        """Return the request chain that adds the column for ``field``."""
        return [
            {
                "table_name": self.model.cls_tablename(),
                # NOTE(review): unlike the sibling actions this one has no leading
                # underscore; kept verbatim because it has to match whatever the
                # backend expects -- confirm against the backend.
                "action": "db__table_add_column",
                "field_creation_string": field.creation_query(),
            }
        ]

    async def add_fk_columns(self):
        """Build the fk fields."""
        for field in self.model.fields.values():
            if isinstance(field, ForeignKey):
                await self.db_request(self.add_fk_field_builder(field))

    @staticmethod
    def _add_m2m_columns_builder(field):
        """Return the request chain that creates the table of an m2m ``field``."""
        return [
            {"table_name": field.table_name, "action": "_db__create_table", "field_queries": field.creation_query()}
        ]

    @staticmethod
    def _add_table_indices_builder(field):
        """Return the request chain that creates an index for ``field``."""
        return [
            {
                # The generated index name is truncated to 30 characters.
                "index_name": "idx_{}_{}".format(field.table_name, field.orm_field_name)[:30],
                "table_name": field.table_name,
                "action": "_db__create_field_index",
                # NOTE(review): "colum_name" looks like a typo but is a runtime key
                # that must match the backend template -- confirm before renaming.
                "colum_name": field.orm_field_name,
            }
        ]

    async def add_m2m_columns(self):
        """Build the m2m_fields."""
        for field in self.model.fields.values():
            if isinstance(field, ManyToManyField):
                await self.db_request(self._add_m2m_columns_builder(field))

    async def add_table_indices(self):
        """Create an index for every model field whose ``db_index`` is truthy."""
        for field in self.model.fields.values():
            if field.db_index:
                await self.db_request(self._add_table_indices_builder(field))

    def get_unique_together(self):
        """Return the ``UNIQUE (...)`` clause for the model, or "" when none is set."""
        unique_together = self.model.unique_together
        # Check before joining so an unset (None/empty) value cannot raise.
        if not unique_together:
            return ""
        return " UNIQUE ({}) ".format(",".join(unique_together))

    def modelconstructor(self, record, instance=None):
        """Populate a model instance from a database record.

        Keys containing ``__`` are not copied into the instance data.

        :param record: mapping of column name to value
        :param instance: instance to populate; a new ``self.model()`` when falsy
        :return: the populated instance
        """
        if not instance:
            instance = self.model()

        data = {key: value for key, value in record.items() if "__" not in key}

        instance.construct(data, subitems=self.query)
        return instance

    @staticmethod
    def _first_value(response):
        """Return the first value of a mapping-like response, or None if it has none."""
        return next(iter(response.values()), None)

    async def count(self):
        """Return the ``COUNT(*)`` of the current query."""
        query = self.query_copy()
        query[0]["select"] = "COUNT(*)"
        return self._first_value(await self.db_request(query))

    async def exists(self):
        """Return the result of running the current query with the ``_db__exists`` action."""
        query = self.query_copy()
        query[0]["action"] = "_db__exists"
        return self._first_value(await self.db_request(query))

    async def calculate(self, field_name, operation):
        """Apply the aggregate ``operation`` to a numeric field of the current query.

        :param field_name: name of a ``NumberField`` attribute of the model
        :param operation: aggregate function name placed in the select clause
        :raises AsyncOrmQuerysetError: unknown field name or non numeric field
        """
        if not hasattr(self.model, field_name):
            raise AsyncOrmQuerysetError("{} wrong field name for model {}".format(field_name, self.model.__name__))
        field = getattr(self.model, field_name)

        if not isinstance(field, NumberField):
            raise AsyncOrmQuerysetError("{} is not a numeric field".format(field_name))

        query = self.query_copy()
        query[0]["select"] = "{}({})".format(operation, field_name)
        return self._first_value(await self.db_request(query))

    async def Max(self, field_name):
        """Return ``MAX`` of ``field_name``."""
        return await self.calculate(field_name, "MAX")

    async def Min(self, field_name):
        """Return ``MIN`` of ``field_name``."""
        return await self.calculate(field_name, "MIN")

    async def Sum(self, field_name):
        """Return ``SUM`` of ``field_name``."""
        return await self.calculate(field_name, "SUM")

    async def Avg(self, field_name):
        """Return ``AVG`` of ``field_name``."""
        return await self.calculate(field_name, "AVG")

    async def StdDev(self, field_name):
        """Return ``STDDEV`` of ``field_name``."""
        return await self.calculate(field_name, "STDDEV")

    async def get(self, **kwargs):
        """Return the single object matching ``kwargs``.

        :raises AsyncOrmMultipleObjectsReturned: more than one object matched
        :raises DoesNotExist: (the model's own exception) nothing matched
        """
        count = 0
        itm = None
        # Every match is consumed because the error message reports the total.
        async for itm in self.queryset().filter(**kwargs):
            count += 1

        if count > 1:
            raise AsyncOrmMultipleObjectsReturned(
                'More than one "{}" were returned, there are {}!'.format(self.model.__name__, count)
            )
        if count == 0:
            raise self.model.DoesNotExist("That {} does not exist".format(self.model.__name__))
        return itm

    # CHAINABLE QUERYSET METHODS
    def queryset(self):
        """Return a copy of this queryset."""
        return self._copy_me()

    def all(self):
        """Return a copy of this queryset."""
        return self._copy_me()

    def none(self):
        """Return a copy filtered on primary key ``-1``."""
        queryset = self._copy_me()
        kwargs = {self.model.db_pk: -1}
        return queryset.filter(**kwargs)

    def calc_filters(self, kwargs, exclude):
        """Translate keyword filters into SQL condition strings.

        Each key is either ``field`` or ``field__lookup``; the lookup selects a
        template from ``LOOKUP_OPERATOR``.

        :param kwargs: mapping of filter key to value
        :param exclude: when truthy every condition is prefixed with ``NOT``
        :return: list of condition strings, one per entry of ``kwargs``
        :raises AsyncOrmQuerysetError: wrong value shape for the range template,
            or a string lookup used on a field that is not a ``CharField``
        """
        bool_string = "NOT " if exclude else ""
        string_lookups = (
            "exact",
            "iexact",
            "contains",
            "icontains",
            "startswith",
            "istartswith",
            "endswith",
            "iendswith",
        )
        filters = []

        for k, v in kwargs.items():
            # we format the key, the conditional and the value
            operator = "{t_n}.{k} = {v}"
            lookup = None
            if len(k.split("__")) > 1:
                k, lookup = k.split("__")
                operator = LOOKUP_OPERATOR[lookup]

            field = getattr(self.model, k)

            operator_formater = {
                "t_n": self.model.table_name or self.model.__name__.lower(),
                "k": field.db_column,
                "v": v,
            }
            # The min/max template is recognised by its literal text.
            if operator == "({t_n}.{k}>={min} AND {t_n}.{k}<={max})":
                if not isinstance(v, (tuple, list)):
                    raise AsyncOrmQuerysetError("{} should be list or a tuple".format(lookup))
                if len(v) != 2:
                    raise AsyncOrmQuerysetError("Not a correct tuple/list definition, should be of size 2")
                operator_formater.update({"min": field.sanitize_data(v[0]), "max": field.sanitize_data(v[1])})
            elif lookup in string_lookups:
                if not isinstance(field, CharField):
                    raise AsyncOrmQuerysetError("{} not allowed in non CharField fields".format(lookup))
                operator_formater["v"] = field.sanitize_data(v)
            else:
                if isinstance(v, (list, tuple)):
                    # check they are correct items and serialize
                    v = ",".join(
                        ["'{}'".format(field.sanitize_data(si)) if isinstance(si, str) else str(si) for si in v]
                    )
                elif v is None:
                    # Drop the first and last character of the sanitized value
                    # and compare with IS instead of "=".
                    v = field.sanitize_data(v)[1:-1]
                    operator = operator.replace("=", "IS")
                elif isinstance(v, (datetime.datetime, datetime.date)) or isinstance(field, (CharField)):
                    v = "'{}'".format(v)
                else:
                    v = field.sanitize_data(v)
                operator_formater["v"] = v

            filters.append(bool_string + operator.format(**operator_formater))
        return filters

    def filter(self, exclude=False, **kwargs):
        """Return a copy with a ``_db__where`` step that ANDs the given filters."""
        filters = self.calc_filters(kwargs, exclude)
        condition = " AND ".join(filters)

        queryset = self.queryset()
        queryset.query.append({"action": "_db__where", "condition": condition})
        return queryset

    def exclude(self, **kwargs):
        """Return a copy filtered by the negation of every given filter."""
        return self.filter(exclude=True, **kwargs)

    def only(self, *args):
        """Return a copy that selects only the given attributes.

        :raises AsyncOrmQuerysetError: an argument is not an attribute of the model
        """
        for arg in args:
            if not hasattr(self.model, arg):
                raise AsyncOrmQuerysetError("{} is not a correct field for {}".format(arg, self.model.__name__))

        queryset = self.queryset()
        queryset.query = self.query_copy()
        queryset.query[0]["select"] = ",".join(args)
        return queryset

    def order_by(self, *args):
        """Return a copy ordered by the given fields; a ``-`` prefix is accepted.

        :raises AsyncOrmQuerysetError: a name (without its ``-`` prefix) is not
            an attribute of the model
        """
        final_args = []
        for arg in args:
            # startswith() copes with an empty string, which then fails the
            # attribute check below instead of raising IndexError.
            field_name = arg[1:] if arg.startswith("-") else arg
            if not hasattr(self.model, field_name):
                raise AsyncOrmQuerysetError(
                    "{} is not a correct field for {}".format(field_name, self.model.__name__)
                )
            final_args.append(arg)

        queryset = self.queryset()
        queryset.query = self.query_copy()
        queryset.query[0]["ordering"] = final_args
        return queryset

    # DB RELATED METHODS
    async def db_request(self, db_request):
        """Send a request chain to the backend and return its response.

        The chain is deep-copied first; ``select`` and ``table_name`` of its
        first element default to this queryset's values when missing.
        """
        db_request = deepcopy(db_request)
        head = db_request[0]
        head.update(
            {
                "select": head.get("select", self.select),
                "table_name": head.get("table_name", self.model.cls_tablename()),
            }
        )
        query = self.db_backend._construct_query(db_request)
        return await self.db_backend.request(query)

    def _get_queryset_slice(self, queryset_slice):
        """Private method to get a slice given the original queryset.

        The window is stored on this very queryset, which is then returned.

        :param queryset_slice: Slice to be retrieved
        :type queryset_slice: slice
        :return: The slice of the queryset
        :rtype: Queryset
        """
        self.forward = 0 if queryset_slice.start is None else queryset_slice.start
        self.stop = queryset_slice.stop
        return self

    async def _get_item(self, key):
        """Return the item at position ``key``.

        :param key: The position passed to the backend cursor as ``forward``
        :type key: int
        :raises IndexError: When the cursor yields no item for that position
        :return: Model object from the Queryset
        :rtype: Model
        """
        # A dedicated cursor is opened per lookup: reusing a cached cursor would
        # ignore ``key`` and hand back whatever row that cursor reaches next.
        cursor = await self.db_backend.get_cursor(deepcopy(self.query), forward=key, stop=None)

        async for res in cursor:
            return self.modelconstructor(res)
        raise IndexError("That {} index does not exist".format(self.model.__name__))

    async def __getitem__(self, key):
        """Support ``await queryset[i]`` and ``await queryset[a:b]``.

        :raises AsyncOrmQuerysetError: negative index/bound or a slice step
        :raises TypeError: ``key`` is neither a slice nor an int
        """
        if isinstance(key, slice):
            wrong_start_key = key.start is not None and key.start < 0
            wrong_stop_key = key.stop is not None and key.stop < 0
            if wrong_start_key or wrong_stop_key:
                raise AsyncOrmQuerysetError("Negative indices are not allowed")
            if key.step is not None:
                raise AsyncOrmQuerysetError("Step on Queryset is not allowed")
            return self._get_queryset_slice(key)

        if isinstance(key, int):
            if key < 0:
                raise AsyncOrmQuerysetError("Negative indices are not allowed")
            return await self._get_item(key)

        raise TypeError("Invalid argument type.")

    def __aiter__(self):
        """Return self; the queryset is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next model instance, opening the backend cursor on first use."""
        if not self._cursor:
            self._cursor = await self.db_backend.get_cursor(self.query, forward=self.forward, stop=self.stop)

        async for rec in self._cursor:
            return self.modelconstructor(rec)
        raise StopAsyncIteration()