diff --git a/doc/index.rst b/doc/index.rst index 2591021b..27a511a4 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -109,6 +109,7 @@ This package is published under MIT license. creating-variables creating-expressions creating-constraints + sos-constraints manipulating-models testing-framework transport-tutorial diff --git a/doc/sos-constraints.rst b/doc/sos-constraints.rst new file mode 100644 index 00000000..aa9d1bd2 --- /dev/null +++ b/doc/sos-constraints.rst @@ -0,0 +1,303 @@ +.. _sos-constraints: + +Special Ordered Sets (SOS) Constraints +======================================= + +Special Ordered Sets (SOS) are a constraint type used in mixed-integer programming to model situations where only one or two variables from an ordered set can be non-zero. Linopy supports both SOS Type 1 and SOS Type 2 constraints. + +.. contents:: + :local: + :depth: 2 + +Overview +-------- + +SOS constraints are particularly useful for: + +- **SOS1**: Modeling mutually exclusive choices (e.g., selecting one facility from multiple locations) +- **SOS2**: Piecewise linear approximations of nonlinear functions +- Improving branch-and-bound efficiency in mixed-integer programming + +Types of SOS Constraints +------------------------- + +SOS Type 1 (SOS1) +~~~~~~~~~~~~~~~~~~ + +In an SOS1 constraint, **at most one** variable in the ordered set can be non-zero. + +**Example use cases:** +- Facility location problems (choose one location among many) +- Technology selection (choose one technology option) +- Mutually exclusive investment decisions + +SOS Type 2 (SOS2) +~~~~~~~~~~~~~~~~~~ + +In an SOS2 constraint, **at most two adjacent** variables in the ordered set can be non-zero. The adjacency is determined by the ordering weights (coordinates) of the variables. + +**Example use cases:** +- Piecewise linear approximation of nonlinear functions +- Portfolio optimization with discrete risk levels +- Production planning with discrete capacity levels + +Basic Usage +----------- + +Adding SOS Constraints +~~~~~~~~~~~~~~~~~~~~~~~ + +To add SOS constraints to variables in linopy: + +.. code-block:: python + + import linopy + import pandas as pd + import xarray as xr + + # Create model + m = linopy.Model() + + # Create variables with numeric coordinates + coords = pd.Index([0, 1, 2], name="options") + x = m.add_variables(coords=[coords], name="x", lower=0, upper=1) + + # Add SOS1 constraint + m.add_sos_constraints(x, sos_type=1, sos_dim="options") + + # For SOS2 constraint + breakpoints = pd.Index([0.0, 1.0, 2.0], name="breakpoints") + lambdas = m.add_variables(coords=[breakpoints], name="lambdas", lower=0, upper=1) + m.add_sos_constraints(lambdas, sos_type=2, sos_dim="breakpoints") + +Method Signature +~~~~~~~~~~~~~~~~ + +.. code-block:: python + + Model.add_sos_constraints(variable, sos_type, sos_dim) + +**Parameters:** + +- ``variable`` : Variable + The variable to which the SOS constraint should be applied +- ``sos_type`` : {1, 2} + Type of SOS constraint (1 or 2) +- ``sos_dim`` : str + Name of the dimension along which the SOS constraint applies + +**Requirements:** + +- The specified dimension must exist in the variable +- The coordinates for the SOS dimension must be numeric (used as weights for ordering) +- Only one SOS constraint can be applied per variable + +Examples +-------- + +Example 1: Facility Location (SOS1) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import linopy + import pandas as pd + import xarray as xr + + # Problem data + locations = pd.Index([0, 1, 2, 3], name="locations") + costs = xr.DataArray([100, 150, 120, 80], coords=[locations]) + benefits = xr.DataArray([200, 300, 250, 180], coords=[locations]) + + # Create model + m = linopy.Model() + + # Decision variables: build facility at location i + build = m.add_variables(coords=[locations], name="build", lower=0, upper=1) + + # SOS1 constraint: at most one facility can be built + m.add_sos_constraints(build, sos_type=1, sos_dim="locations") + + # Objective: maximize net benefit + net_benefit = benefits - costs + m.add_objective(-((net_benefit * build).sum())) + + # Solve + m.solve(solver_name="highs") + + if m.status == "ok": + solution = build.solution.to_pandas() + selected_location = solution[solution > 0.5].index[0] + print(f"Build facility at location {selected_location}") + +Example 2: Piecewise Linear Approximation (SOS2) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import numpy as np + + # Approximate f(x) = x² over [0, 3] with breakpoints + breakpoints = pd.Index([0, 1, 2, 3], name="breakpoints") + + x_vals = xr.DataArray(breakpoints.to_series()) + y_vals = x_vals**2 + + # Create model + m = linopy.Model() + + # SOS2 variables (interpolation weights) + lambdas = m.add_variables(lower=0, upper=1, coords=[breakpoints], name="lambdas") + m.add_sos_constraints(lambdas, sos_type=2, sos_dim="breakpoints") + + # Interpolated coordinates + x = m.add_variables(name="x", lower=0, upper=3) + y = m.add_variables(name="y", lower=0, upper=9) + + # Constraints + m.add_constraints(lambdas.sum() == 1, name="convexity") + m.add_constraints(x == lambdas @ x_vals, name="x_interpolation") + m.add_constraints(y == lambdas @ y_vals, name="y_interpolation") + m.add_constraints(x >= 1.5, name="x_minimum") + + # Objective: minimize approximated function value + m.add_objective(y) + + # Solve + m.solve(solver_name="highs") + +Working with Multi-dimensional Variables +----------------------------------------- + +SOS constraints are created for each dimension that is not sos_dim. + +.. code-block:: python + + # Multi-period production planning + periods = pd.Index(range(3), name="periods") + modes = pd.Index([0, 1, 2], name="modes") + + # 2D variables: periods × modes + period_modes = m.add_variables( + lower=0, upper=1, coords=[periods, modes], name="use_mode" + ) + + # Adds SOS1 constraint for each period + m.add_sos_constraints(period_modes, sos_type=1, sos_dim="modes") + +Accessing SOS Variables +----------------------- + +You can easily identify and access variables with SOS constraints: + +.. code-block:: python + + # Get all variables with SOS constraints + sos_variables = m.variables.sos + print(f"SOS variables: {list(sos_variables.keys())}") + + # Check SOS properties of a variable + for var_name in sos_variables: + var = m.variables[var_name] + sos_type = var.attrs["sos_type"] + sos_dim = var.attrs["sos_dim"] + print(f"{var_name}: SOS{sos_type} on dimension '{sos_dim}'") + +Variable Representation +~~~~~~~~~~~~~~~~~~~~~~~ + +Variables with SOS constraints show their SOS information in string representations: + +.. code-block:: python + + print(build) + # Output: Variable (locations: 4) - sos1 on locations + # ----------------------------------------------- + # [0]: build[0] ∈ [0, 1] + # [1]: build[1] ∈ [0, 1] + # [2]: build[2] ∈ [0, 1] + # [3]: build[3] ∈ [0, 1] + +LP File Export +-------------- + +The generated LP file will include a SOS section: + +.. code-block:: text + + sos + + s0: S1 :: x0:0 x1:1 x2:2 + s3: S2 :: x3:0.0 x4:1.0 x5:2.0 + +Solver Compatibility +-------------------- + +SOS constraints are supported by most modern mixed-integer programming solvers through the LP file format: + +**Supported solvers:** +- HiGHS +- Gurobi +- CPLEX +- COIN-OR CBC +- SCIP +- Xpress + +**Note:** Some solvers may have varying levels of SOS support. Check your solver's documentation for specific capabilities. + +Common Patterns +--------------- + +Piecewise Linear Cost Function +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + def add_piecewise_cost(model, variable, breakpoints, costs): + """Add piecewise linear cost function using SOS2.""" + n_segments = len(breakpoints) + lambda_coords = pd.Index(range(n_segments), name="segments") + + lambdas = model.add_variables( + coords=[lambda_coords], name="cost_lambdas", lower=0, upper=1 + ) + model.add_sos_constraints(lambdas, sos_type=2, sos_dim="segments") + + cost_var = model.add_variables(name="cost", lower=0) + + x_vals = xr.DataArray(breakpoints, coords=[lambda_coords]) + c_vals = xr.DataArray(costs, coords=[lambda_coords]) + + model.add_constraints(lambdas.sum() == 1, name="cost_convexity") + model.add_constraints(variable == (x_vals * lambdas).sum(), name="cost_x_def") + model.add_constraints(cost_var == (c_vals * lambdas).sum(), name="cost_def") + + return cost_var + +Mutually Exclusive Investments +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code-block:: python + + def add_exclusive_investments(model, projects, costs, returns): + """Add mutually exclusive investment decisions using SOS1.""" + project_coords = pd.Index(projects, name="projects") + + invest = model.add_variables( + coords=[project_coords], name="invest", binary=True + ) + model.add_sos_constraints(invest, sos_type=1, sos_dim="projects") + + total_cost = (invest * costs).sum() + total_return = (invest * returns).sum() + + return invest, total_cost, total_return + + +See Also +-------- + +- :doc:`creating-variables`: Creating variables with coordinates +- :doc:`creating-constraints`: Adding regular constraints +- :doc:`user-guide`: General linopy usage patterns +- Example notebook: ``examples/sos-constraints-example.ipynb`` diff --git a/linopy/io.py b/linopy/io.py index 7065adbb..d9e0f2a8 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -23,6 +23,7 @@ from tqdm import tqdm from linopy import solvers +from linopy.common import to_polars from linopy.constants import CONCAT_DIM from linopy.objective import Objective @@ -327,6 +328,66 @@ def integers_to_file( formatted.write_csv(f, **kwargs) +def sos_to_file( + m: Model, + f: BufferedWriter, + progress: bool = False, + slice_size: int = 2_000_000, + explicit_coordinate_names: bool = False, +) -> None: + """ + Write out integers of a model to a lp file. + """ + names = m.variables.sos + if not len(list(names)): + return + + print_variable, _ = get_printers( + m, explicit_coordinate_names=explicit_coordinate_names + ) + + f.write(b"\n\nsos\n\n") + if progress: + names = tqdm( + list(names), + desc="Writing sos constraints.", + colour=TQDM_COLOR, + ) + + for name in names: + var = m.variables[name] + sos_type = var.attrs["sos_type"] + sos_dim = var.attrs["sos_dim"] + + other_dims = tuple([dim for dim in var.labels.dims if dim != sos_dim]) + for var_slice in var.iterate_slices(slice_size, other_dims): + ds = var_slice.labels.to_dataset() + ds["sos_labels"] = ds["labels"].isel({sos_dim: 0}) + ds["weights"] = ds.coords[sos_dim] + df = to_polars(ds) + + df = df.group_by("sos_labels").agg( + pl.concat_str( + *print_variable(pl.col("labels")), pl.lit(":"), pl.col("weights") + ) + .str.join(" ") + .alias("var_weights") + ) + + columns = [ + pl.lit("s"), + pl.col("sos_labels"), + pl.lit(f": S{sos_type} :: "), + pl.col("var_weights"), + ] + + kwargs: Any = dict( + separator=" ", null_value="", quote_style="never", include_header=False + ) + formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) + formatted.write_csv(f, **kwargs) + + def constraints_to_file( m: Model, f: BufferedWriter, @@ -464,6 +525,13 @@ def to_lp_file( slice_size=slice_size, explicit_coordinate_names=explicit_coordinate_names, ) + sos_to_file( + m, + f=f, + progress=progress, + slice_size=slice_size, + explicit_coordinate_names=explicit_coordinate_names, + ) f.write(b"end\n") logger.info(f" Writing time: {round(time.time() - start, 2)}s") @@ -683,6 +751,23 @@ def to_gurobipy( c = model.addMConstr(M.A, x, M.sense, M.b) # type: ignore c.setAttr("ConstrName", list(names)) # type: ignore + if m.variables.sos: + for var_name in m.variables.sos: + var = m.variables.sos[var_name] + sos_type = var.attrs["sos_type"] + sos_dim = var.attrs["sos_dim"] + + def add_sos(s): + s = s.squeeze() + model.addSOS(sos_type, x[s].tolist(), s.coords[sos_dim].values) + + others = tuple(dim for dim in var.labels.dims if dim != sos_dim) + if not others: + add_sos(var.labels) + else: + for _, s in var.labels.groupby(*others): + add_sos(s) + model.update() return model diff --git a/linopy/model.py b/linopy/model.py index 149c2cc2..38cb09eb 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -12,7 +12,7 @@ from collections.abc import Callable, Mapping, Sequence from pathlib import Path from tempfile import NamedTemporaryFile, gettempdir -from typing import Any, overload +from typing import Any, Literal, overload import numpy as np import pandas as pd @@ -551,6 +551,46 @@ def add_variables( self.variables.add(variable) return variable + def add_sos_constraints( + self, + variable: Variable, + sos_type: Literal[1, 2], + sos_dim: str, + ): + """ + Add an sos1 or sos2 constraint for one dimension of a variable + + The dimension values are used as SOS. + + Parameters + ---------- + variable : Variable + sos_type : {1, 2} + Type of SOS + sos_dim : str + Which dimension of variable to add SOS constraint to + """ + if sos_type not in (1, 2): + raise ValueError(f"sos_type must be 1 or 2, got {sos_type}") + if sos_dim not in variable.dims: + raise ValueError(f"sos_dim must name a variable dimension, got {sos_dim}") + + if "sos_type" in variable.attrs or "sos_dim" in variable.attrs: + existing_sos_type = variable.attrs.get("sos_type") + existing_sos_dim = variable.attrs.get("sos_dim") + raise ValueError( + f"variable already has an sos{existing_sos_type} constraint on {existing_sos_dim}" + ) + + # Validate that sos_dim coordinates are numeric (needed for weights) + if not pd.api.types.is_numeric_dtype(variable.coords[sos_dim]): + raise ValueError( + f"SOS constraint requires numeric coordinates for dimension '{sos_dim}', " + f"but got {variable.coords[sos_dim].dtype}" + ) + + variable.attrs.update(sos_type=sos_type, sos_dim=sos_dim) + def add_constraints( self, lhs: VariableLike @@ -776,6 +816,29 @@ def remove_constraints(self, name: str | list[str]) -> None: logger.debug(f"Removed constraint: {name}") self.constraints.remove(name) + def remove_sos_constraints(self, variable: Variable) -> None: + """ + Remove all sos constraints from a given variable. + + Parameters + ---------- + variable : Variable + Variable instance from which to remove all sos constraints. + Can be retrieved from `m.variables.sos`. + + Returns + ------- + None. + """ + sos_type = variable.attrs.get("sos_type") + sos_dim = variable.attrs.get("sos_dim") + + del variable.attrs["sos_type"], variable.attrs["sos_dim"] + + logger.debug( + f"Removed sos{sos_type} constraint on {sos_dim} from {variable.name}" + ) + def remove_objective(self) -> None: """ Remove the objective's linear expression from the model. diff --git a/linopy/variables.py b/linopy/variables.py index 2a929515..2ad76012 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -193,6 +193,14 @@ def __init__( if "label_range" not in data.attrs: data.assign_attrs(label_range=(data.labels.min(), data.labels.max())) + if "sos_type" in data.attrs or "sos_dim" in data.attrs: + if (sos_type := data.attrs.get("sos_type")) not in (1, 2): + raise ValueError(f"sos_type must be 1 or 2, got {sos_type}") + if (sos_dim := data.attrs.get("sos_dim")) not in data.dims: + raise ValueError( + f"sos_dim must name a variable dimension, got {sos_dim}" + ) + self._data = data self._model = model @@ -323,6 +331,8 @@ def __repr__(self) -> str: dim_names = self.coord_names dim_sizes = list(self.sizes.values()) masked_entries = (~self.mask).sum().values + sos_type = self.attrs.get("sos_type") + sos_dim = self.attrs.get("sos_dim") lines = [] if dims: @@ -344,9 +354,11 @@ def __repr__(self) -> str: shape_str = ", ".join(f"{d}: {s}" for d, s in zip(dim_names, dim_sizes)) mask_str = f" - {masked_entries} masked entries" if masked_entries else "" + sos_str = f" - sos{sos_type} on {sos_dim}" if sos_type and sos_dim else "" lines.insert( 0, - f"Variable ({shape_str}){mask_str}\n{'-' * (len(shape_str) + len(mask_str) + 11)}", + f"Variable ({shape_str}){mask_str}{sos_str}\n" + f"{'-' * (len(shape_str) + len(mask_str) + len(sos_str) + 11)}", ) else: lines.append( @@ -1232,6 +1244,10 @@ def __repr__(self) -> str: if ds.coords else "" ) + if (sos_type := ds.attrs.get("sos_type")) in (1, 2) and ( + sos_dim := ds.attrs.get("sos_dim") + ): + coords += f" - sos{sos_type} on {sos_dim}" r += f" * {name}{coords}\n" if not len(list(self)): r += "\n" @@ -1362,6 +1378,21 @@ def continuous(self) -> Variables: self.model, ) + @property + def sos(self) -> Variables: + """ + Get all variables involved in an sos constraint. + """ + return self.__class__( + { + name: self.data[name] + for name in self + if self[name].attrs.get("sos_dim") + and self[name].attrs.get("sos_type") in (1, 2) + }, + self.model, + ) + @property def solution(self) -> Dataset: """ diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py new file mode 100644 index 00000000..b4e4dc3f --- /dev/null +++ b/test/test_sos_constraints.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import pandas as pd +import pytest + +from linopy import Model, available_solvers + + +def test_add_sos_constraints_registers_variable() -> None: + m = Model() + locations = pd.Index([0, 1, 2], name="locations") + build = m.add_variables(coords=[locations], name="build") + + m.add_sos_constraints(build, sos_type=1, sos_dim="locations") + + assert build.attrs["sos_type"] == 1 + assert build.attrs["sos_dim"] == "locations" + assert list(m.variables.sos) == ["build"] + + m.remove_sos_constraints(build) + assert "sos_type" not in build.attrs + assert "sos_dim" not in build.attrs + + +def test_add_sos_constraints_validation() -> None: + m = Model() + strings = pd.Index(["a", "b"], name="strings") + with pytest.raises(ValueError, match="sos_type"): + m.add_sos_constraints(m.add_variables(name="x"), sos_type=3, sos_dim="i") + + variable = m.add_variables(coords=[strings], name="string_var") + + with pytest.raises(ValueError, match="dimension"): + m.add_sos_constraints(variable, sos_type=1, sos_dim="missing") + + with pytest.raises(ValueError, match="numeric"): + m.add_sos_constraints(variable, sos_type=1, sos_dim="strings") + + numeric = m.add_variables(coords=[pd.Index([0, 1], name="dim")], name="num") + m.add_sos_constraints(numeric, sos_type=1, sos_dim="dim") + with pytest.raises(ValueError, match="already has"): + m.add_sos_constraints(numeric, sos_type=1, sos_dim="dim") + + +def test_sos_constraints_written_to_lp(tmp_path) -> None: + m = Model() + breakpoints = pd.Index([0.0, 1.5, 3.5], name="bp") + lambdas = m.add_variables(coords=[breakpoints], name="lambda") + m.add_sos_constraints(lambdas, sos_type=2, sos_dim="bp") + + fn = tmp_path / "sos.lp" + m.to_file(fn, io_api="lp") + content = fn.read_text() + + assert "\nsos\n" in content + assert "S2 ::" in content + assert "3.5" in content + + +@pytest.mark.skipif("gurobi" not in available_solvers, reason="Gurobipy not installed") +def test_to_gurobipy_emits_sos_constraints() -> None: + gurobipy = pytest.importorskip("gurobipy") + + m = Model() + segments = pd.Index([0.0, 0.5, 1.0], name="seg") + var = m.add_variables(coords=[segments], name="lambda") + m.add_sos_constraints(var, sos_type=1, sos_dim="seg") + + try: + model = m.to_gurobipy() + except gurobipy.GurobiError as exc: # pragma: no cover - depends on license setup + pytest.skip(f"Gurobi environment unavailable: {exc}") + + assert model.NumSOS == 1