246 changes: 223 additions & 23 deletions src/util/orm.py
@@ -9,13 +9,23 @@
on this module work along the ORM of *all* supported versions.
"""

import collections
import logging
import multiprocessing
import os
import re
import sys
import uuid
from contextlib import contextmanager
from functools import wraps
from itertools import chain, repeat
from textwrap import dedent

try:
    from concurrent.futures import ProcessPoolExecutor
except ImportError:
    ProcessPoolExecutor = None

try:
    from unittest.mock import patch
except ImportError:
@@ -27,9 +37,9 @@
    except ImportError:
        from odoo import SUPERUSER_ID
    from odoo import fields as ofields
    from odoo import modules, release, sql_db
except ImportError:
    from openerp import SUPERUSER_ID, modules, release, sql_db

try:
    from openerp import fields as ofields
@@ -41,8 +51,8 @@
from .const import BIG_TABLE_THRESHOLD
from .exceptions import MigrationError
from .helpers import table_of_model
from .misc import chunks, log_progress, str2bool, version_between, version_gte
from .pg import SQLStr, column_exists, format_query, get_columns, get_max_workers, named_cursor

# python3 shims
try:
@@ -52,6 +62,10 @@

_logger = logging.getLogger(__name__)

UPG_PARALLEL_ITER_BROWSE = str2bool(os.environ.get("UPG_PARALLEL_ITER_BROWSE", "0"))
# FIXME: for CI! Remove before merge
UPG_PARALLEL_ITER_BROWSE = True


def env(cr):
"""
@@ -341,6 +355,26 @@ def get_ids():
cr.execute("DROP TABLE IF EXISTS _upgrade_rf")


def _mp_iter_browse_cb(ids_or_values, params):
    me = _mp_iter_browse_cb
    # init upon first call. Done here instead of in an initializer callback, because py3.6 doesn't have one
    if not hasattr(me, "env"):
        sql_db._Pool = None  # children cannot borrow from copies of the same pool, it would cause a protocol error
        me.env = env(sql_db.db_connect(params["dbname"]).cursor())
        me.env.clear()
    # process
    new_ids = None
    if params["mode"] == "browse":
        getattr(
            me.env[params["model_name"]].with_context(params["context"]).browse(ids_or_values), params["attr_name"]
        )(*params["args"], **params["kwargs"])
    elif params["mode"] == "create":
        new_ids = me.env[params["model_name"]].with_context(params["context"]).create(ids_or_values).ids
    me.env.cr.commit()
    return new_ids
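
# Illustrative sketch, with assumed example values: the `params` payload that
# `iter_browse.__getattr__` builds for the worker callback above when a method
# such as `.write()` is dispatched in "browse" mode:
#
#     params = {
#         "dbname": "mydb",                   # assumed database name
#         "model_name": "res.partner",        # assumed model
#         "context": {"active_test": False},  # must be pickleable
#         "attr_name": "write",
#         "args": ({"active": False},),       # positional arguments forwarded to `write`
#         "kwargs": {},
#         "mode": "browse",
#     }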


class iter_browse(object):
    """
    Iterate over recordsets.
@@ -374,7 +408,8 @@ class iter_browse(object):

    :param model: the model to iterate
    :type model: :class:`odoo.model.Model`
    :param iterable(int) ids: iterable of IDs of the records to iterate
    :param str query: alternative to `ids`; an SQL query that produces the IDs to iterate
    :param int chunk_size: number of records to load in each iteration chunk, `200` by
                           default
    :param logger: logger used to report the progress, by default
@@ -387,23 +422,106 @@ class iter_browse(object):
    See also :func:`~odoo.upgrade.util.orm.env`
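
    Typical usage, as an illustrative sketch (the model, IDs, and query are placeholders):

    .. code-block:: python

        # call `write` on res.partner records, chunk by chunk, committing in between
        iter_browse(env["res.partner"], ids, strategy="commit").write({"active": False})

        # lazily iterate over the records produced by a query
        for record in iter_browse(env["res.partner"], None, query="SELECT id FROM res_partner"):
            ...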
"""

    __slots__ = (
        "_chunk_size",
        "_cr_uid",
        "_ids",
        "_it",
        "_logger",
        "_model",
        "_patch",
        "_query",
        "_size",
        "_strategy",
        "_superchunk_size",
    )

    def __init__(self, model, *args, **kw):
        assert len(args) in [1, 3]  # either (cr, uid, ids) or (ids,)
        self._model = model
        self._cr_uid = args[:-1]
        self._ids = args[-1]
        self._size = kw.pop("size", None)
        self._query = kw.pop("query", None)
        self._chunk_size = kw.pop("chunk_size", 200)  # keyword-only argument
        self._superchunk_size = self._chunk_size
        self._logger = kw.pop("logger", _logger)
        self._strategy = kw.pop("strategy", "flush")
        assert self._strategy in {"flush", "commit", "multiprocessing"}
        if self._strategy == "multiprocessing":
            if not ProcessPoolExecutor:
                raise ValueError("The multiprocessing strategy cannot be used in scripts run by Python 2")
            if UPG_PARALLEL_ITER_BROWSE:
                self._superchunk_size = min(get_max_workers() * 10 * self._chunk_size, 1000000)
            else:
                self._strategy = "commit"  # downgrade
                if self._size and self._size > 100000:
                    _logger.warning(
                        "Browsing %d %s, which may take a long time. "
                        "This can be sped up by setting the env variable UPG_PARALLEL_ITER_BROWSE to 1. "
                        "If you do, be sure to examine the results carefully.",
                        self._size,
                        self._model._name,
                    )
                else:
                    _logger.info(
                        "Caller requested the multiprocessing strategy, but the UPG_PARALLEL_ITER_BROWSE env var is not set. "
                        "Downgrading strategy to commit.",
                    )
        if kw:
            raise TypeError("Unknown arguments: %s" % ", ".join(kw))

        if not (self._ids is None) ^ (self._query is None):
            raise TypeError("Must be initialized using exactly one of `ids` or `query`")

        if self._query:
            self._ids_query()

        if not self._size:
            try:
                self._size = len(self._ids)
            except TypeError:
                raise ValueError("When passing ids as a generator, the size kwarg is mandatory")
        self._patch = None
        self._it = chunks(self._ids, self._chunk_size, fmt=self._browse)

    def _ids_query(self):
        cr = self._model.env.cr
        tmp_tbl = "_upgrade_ib_{}".format(uuid.uuid4().hex)
        cr.execute(
            format_query(
                cr,
                "CREATE UNLOGGED TABLE {}(id) AS (WITH query AS ({}) SELECT * FROM query)",
                tmp_tbl,
                SQLStr(self._query),
            )
        )
        self._size = cr.rowcount
        cr.execute(
            format_query(cr, "ALTER TABLE {} ADD CONSTRAINT {} PRIMARY KEY (id)", tmp_tbl, "pk_{}_id".format(tmp_tbl))
        )

        def get_ids():
            with named_cursor(cr, itersize=self._superchunk_size) as ncr:
                ncr.execute(format_query(cr, "SELECT id FROM {} ORDER BY id", tmp_tbl))
                for (id_,) in ncr:
                    yield id_
            cr.execute(format_query(cr, "DROP TABLE IF EXISTS {}", tmp_tbl))

        self._ids = get_ids()

    def _values_query(self, query):
        cr = self._model.env.cr
        cr.execute(format_query(cr, "WITH query AS ({}) SELECT count(*) FROM query", SQLStr(query)))
        size = cr.fetchone()[0]

        def get_values():
            with named_cursor(cr, itersize=self._chunk_size) as ncr:
                ncr.execute(SQLStr(query))
                for row in ncr.iterdict():
                    yield row

        return size, get_values()

    def _browse(self, ids):
        next(self._end(), None)
@@ -415,7 +533,7 @@ def _browse(self, ids):
        return self._model.browse(*args)

    def _end(self):
        if self._strategy in ["commit", "multiprocessing"]:
            self._model.env.cr.commit()
        else:
            flush(self._model)
@@ -430,8 +548,12 @@ def __iter__(self):
            raise RuntimeError("%r ran twice" % (self,))

        it = chain.from_iterable(self._it)
        sz = self._size
        if self._strategy == "multiprocessing":
            it = self._it
            sz = (self._size + self._chunk_size - 1) // self._chunk_size
        if self._logger:
            it = log_progress(it, self._logger, qualifier=self._model._name, size=sz)
        self._it = None
        return chain(it, self._end())

@@ -442,48 +564,90 @@ def __getattr__(self, attr):
        if not callable(getattr(self._model, attr)):
            raise TypeError("The attribute %r is not callable" % attr)

        it = chunks(self._ids, self._superchunk_size, fmt=self._browse)
        if self._logger:
            sz = (self._size + self._superchunk_size - 1) // self._superchunk_size
            qualifier = "%s[:%d]" % (self._model._name, self._superchunk_size)
            it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

        def caller(*args, **kwargs):
            args = self._cr_uid + args
            if self._strategy != "multiprocessing":
                return [getattr(chnk, attr)(*args, **kwargs) for chnk in chain(it, self._end())]
            params = {
                "dbname": self._model.env.cr.dbname,
                "model_name": self._model._name,
                # convert to dict for pickle. Will still break if any value in the context is not pickleable
                "context": dict(self._model.env.context),
                "attr_name": attr,
                "args": args,
                "kwargs": kwargs,
                "mode": "browse",
            }
            self._model.env.cr.commit()
            extrakwargs = {"mp_context": multiprocessing.get_context("fork")} if sys.version_info >= (3, 7) else {}
            with ProcessPoolExecutor(max_workers=get_max_workers(), **extrakwargs) as executor:
                for chunk in it:
                    collections.deque(
                        executor.map(
                            _mp_iter_browse_cb, chunks(chunk._ids, self._chunk_size, fmt=tuple), repeat(params)
                        ),
                        maxlen=0,
                    )
            next(self._end(), None)
            # do not return results in parallel mode: we expect it to be used for huge numbers of
            # records, so returning them would risk a MemoryError; besides, we cannot know whether
            # what `attr` returns is pickleable
            return None

        self._it = None
        return caller

    def create(self, values=None, query=None, **kw):
        """
        Create records.

        An alternative to the default `create` method of the ORM that is safe to use to
        create millions of records.

        :param iterable(dict) values: iterable of values of the records to create
        :param int size: the number of elements produced by `values`; required if `values`
                         is a generator
        :param str query: alternative to `values`; an SQL query that produces the values of
                          the records to create
        :param bool multi: whether to use the multi version of `create`; `True` by default
                           from Odoo 12 and above
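
        Example, as an illustrative sketch (the model and values are placeholders):

        .. code-block:: python

            # mass-create records from a generator of values; `size` is then mandatory
            new_records = iter_browse(env["res.partner"], []).create(
                ({"name": "Partner %d" % i} for i in range(500000)), size=500000
            )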
"""
        multi = kw.pop("multi", version_gte("saas~11.5"))
        size = kw.pop("size", None)
        if kw:
            raise TypeError("Unknown arguments: %s" % ", ".join(kw))

        if not (values is None) ^ (query is None):
            raise ValueError("`create` needs to be called using exactly one of `values` or `query` arguments")
        if values is not None and not values:
            raise ValueError("`create` cannot be called with an empty `values` argument")

        if self._size:
            raise ValueError("`create` can only be called on empty `browse_record` objects.")

        if query:
            size, values = self._values_query(query)

        if size is None:
            try:
                size = len(values)
            except TypeError:
                raise ValueError("When passing a generator of values, the size kwarg is mandatory")

        if self._strategy == "multiprocessing":
            return self._create_multiprocess(values, size, multi)

        return self._create(values, size, multi)

    def _create(self, values, size, multi):
        it = chunks(values, self._chunk_size, fmt=list)
        if self._logger:
            sz = (size + self._chunk_size - 1) // self._chunk_size
            qualifier = "env[%r].create([:%d])" % (self._model._name, self._chunk_size)
            it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

        ids = []
        self._patch = no_selection_cache_validation()
        for sub_values in it:
            self._patch.start()
@@ -502,6 +666,42 @@ def create(self, values, **kw):
            self._model, *args, chunk_size=self._chunk_size, logger=self._logger, strategy=self._strategy
        )

    def _create_multiprocess(self, values, size, multi):
        if not multi:
            raise ValueError("The multiprocessing strategy only supports the multi version of `create`")

        it = chunks(values, self._superchunk_size, fmt=list)
        if self._logger:
            sz = (size + self._superchunk_size - 1) // self._superchunk_size
            qualifier = "env[%r].create([:%d])" % (self._model._name, self._superchunk_size)
            it = log_progress(it, self._logger, qualifier=qualifier, size=sz)

        def iter_proc():
            params = {
                "dbname": self._model.env.cr.dbname,
                "model_name": self._model._name,
                # convert to dict for pickle. Will still break if any value in the context is not pickleable
                "context": dict(self._model.env.context),
                "mode": "create",
            }
            self._model.env.cr.commit()
            self._patch.start()
            extrakwargs = {"mp_context": multiprocessing.get_context("fork")} if sys.version_info >= (3, 7) else {}
            with ProcessPoolExecutor(max_workers=get_max_workers(), **extrakwargs) as executor:
                for sub_values in it:
                    for task_result in executor.map(
                        _mp_iter_browse_cb, chunks(sub_values, self._chunk_size, fmt=tuple), repeat(params)
                    ):
                        self._model.env.cr.commit()  # make task_result visible on the main cursor before yielding ids
                        for new_id in task_result:
                            yield new_id
            next(self._end(), None)

        self._patch = no_selection_cache_validation()
        args = self._cr_uid + (iter_proc(),)
        kwargs = {"size": size, "chunk_size": self._chunk_size, "logger": None, "strategy": self._strategy}
        return iter_browse(self._model, *args, **kwargs)


@contextmanager
def custom_module_field_as_manual(env, rollback=True, do_flush=False):