Skip to content

Add a default rule for custom blocks #3570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion pyomo/core/base/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2399,10 +2399,36 @@ def __init__(self, *args, **kwargs):
break


def _default_rule(model_options):
"""
Default rule for custom blocks

Parameters
----------
model_options : dict
Dictionary of options needed to construct the block model
"""

def _rule(blk, *args):
try:
# Attempt to build the model
blk.build(*args, **model_options)

except AttributeError:
# build method is not implemented in the BlockData class
# Returning an empty Pyomo Block
pass

return _rule


class CustomBlock(Block):
"""The base class used by instances of custom block components"""

def __init__(self, *args, **kwargs):
model_options = kwargs.pop("options", {})
kwargs.setdefault("rule", _default_rule(model_options))

if self._default_ctype is not None:
kwargs.setdefault('ctype', self._default_ctype)
Block.__init__(self, *args, **kwargs)
Expand Down Expand Up @@ -2431,7 +2457,20 @@ def declare_custom_block(name, new_ctype=None):
>>> @declare_custom_block(name="FooBlock")
... class FooBlockData(BlockData):
... # custom block data class
... pass
... # CustomBlock returns an empty block if `build` method is not implemented
... def build(self, *args, option_1, option_2):
... # args contains the index (for indexed blocks)
... # option_1, option_2, ... are additional arguments
... self.x = Var()
... self.cost = Param(initialize=option_1)

Usage:
>>> m = ConcreteModel()
>>> m.blk = FooBlock([1, 2], options={"option_1": 1, "option_2": 2})

Specify `rule` argument to ignore the default rule argument.
>>> m = ConcreteModel()
>>> m.blk = FooBlock([1, 2], rule=my_custom_block_rule)
"""

def block_data_decorator(block_data):
Expand Down
71 changes: 71 additions & 0 deletions pyomo/core/tests/unit/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -3057,6 +3057,77 @@ def pprint(self, ostream=None, verbose=False, prefix=""):
b.pprint(ostream=stream)
self.assertEqual(correct_s, stream.getvalue())

def test_custom_block_default_rule(self):
"""Tests the decorator with `build` method, but without options"""
@declare_custom_block("FooBlock")
class FooBlockData(BlockData):
def build(self, *args):
self.x = Var(list(args))
self.y = Var()

m = ConcreteModel()
m.blk_without_index = FooBlock()
m.blk_1 = FooBlock([1, 2, 3])
m.blk_2 = FooBlock([4, 5], [6, 7])

self.assertIn("x", m.blk_without_index.component_map())
self.assertIn("y", m.blk_without_index.component_map())
self.assertIn("x", m.blk_1[3].component_map())
self.assertIn("x", m.blk_2[4, 6].component_map())

self.assertEqual(len(m.blk_1), 3)
self.assertEqual(len(m.blk_2), 4)

self.assertEqual(len(m.blk_1[2].x), 1)
self.assertEqual(len(m.blk_2[4, 6].x), 2)

def test_custom_block_default_rule_options(self):
"""Tests the decorator with `build` method and model options"""
@declare_custom_block("FooBlock")
class FooBlockData(BlockData):
def build(self, *args, capex, opex):
self.x = Var(list(args))
self.y = Var()

self.capex = capex
self.opex = opex

options = {"capex": 42, "opex": 24}
m = ConcreteModel()
m.blk_without_index = FooBlock(options=options)
m.blk_1 = FooBlock([1, 2, 3], options=options)
m.blk_2 = FooBlock([4, 5], [6, 7], options=options)

self.assertEqual(m.blk_without_index.capex, 42)
self.assertEqual(m.blk_without_index.opex, 24)

self.assertEqual(m.blk_1[3].capex, 42)
self.assertEqual(m.blk_2[4, 7].opex, 24)

with self.assertRaises(TypeError):
# missing 2 required keyword arguments
m.blk_3 = FooBlock()

def test_custom_block_user_rule(self):
"""Tests if the default rule can be overwritten"""
@declare_custom_block("FooBlock")
class FooBlockData(BlockData):
def build(self, *args):
self.x = Var(list(args))
self.y = Var()

def _new_rule(blk):
blk.a = Var()
blk.b = Var()

m = ConcreteModel()
m.blk = FooBlock(rule=_new_rule)

self.assertNotIn("x", m.blk.component_map())
self.assertNotIn("y", m.blk.component_map())
self.assertIn("a", m.blk.component_map())
self.assertIn("b", m.blk.component_map())

def test_block_rules(self):
m = ConcreteModel()
m.I = Set()
Expand Down