Skip to content

Commit f09216e

Browse files
committed
add non-retryable error, add grpc options
Signed-off-by: Filinto Duran <[email protected]>
1 parent 92635d3 commit f09216e

File tree

5 files changed

+387
-19
lines changed

5 files changed

+387
-19
lines changed

durabletask/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ def __init__(self, *,
9898
log_handler: Optional[logging.Handler] = None,
9999
log_formatter: Optional[logging.Formatter] = None,
100100
secure_channel: bool = False,
101-
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):
101+
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
102+
options: Optional[Sequence[tuple[str, Any]]] = None):
102103

103104
# If the caller provided metadata, we need to create a new interceptor for it and
104105
# add it to the list of interceptors.
@@ -114,7 +115,8 @@ def __init__(self, *,
114115
channel = shared.get_grpc_channel(
115116
host_address=host_address,
116117
secure_channel=secure_channel,
117-
interceptors=interceptors
118+
interceptors=interceptors,
119+
options=options,
118120
)
119121
self._stub = stubs.TaskHubSidecarServiceStub(channel)
120122
self._logger = shared.get_logger("client", log_handler, log_formatter)

durabletask/internal/shared.py

Lines changed: 127 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import os
88
from types import SimpleNamespace
9-
from typing import Any, Optional, Sequence, Union
9+
from typing import Any, Iterable, Optional, Sequence, Union
1010

1111
import grpc
1212

@@ -53,7 +53,8 @@ def get_default_host_address() -> str:
5353
def get_grpc_channel(
5454
host_address: Optional[str],
5555
secure_channel: bool = False,
56-
interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel:
56+
interceptors: Optional[Sequence[ClientInterceptor]] = None,
57+
options: Optional[Sequence[tuple[str, Any]]] = None) -> grpc.Channel:
5758
if host_address is None:
5859
host_address = get_default_host_address()
5960

@@ -71,18 +72,139 @@ def get_grpc_channel(
7172
host_address = host_address[len(protocol):]
7273
break
7374

74-
# Create the base channel
75+
# Build channel options (merge provided options with env-driven keepalive/retry)
76+
channel_options = build_grpc_channel_options(options)
77+
78+
# Create the base channel (preserve original call signature when no options)
7579
if secure_channel:
76-
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
80+
if channel_options is not None:
81+
channel = grpc.secure_channel(
82+
host_address, grpc.ssl_channel_credentials(), options=channel_options
83+
)
84+
else:
85+
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
7786
else:
78-
channel = grpc.insecure_channel(host_address)
87+
if channel_options is not None:
88+
channel = grpc.insecure_channel(host_address, options=channel_options)
89+
else:
90+
channel = grpc.insecure_channel(host_address)
7991

8092
# Apply interceptors ONLY if they exist
8193
if interceptors:
8294
channel = grpc.intercept_channel(channel, *interceptors)
8395
return channel
8496

8597

98+
def _get_env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Returns ``default`` when the variable is not set. Otherwise the value is
    trimmed and lower-cased, and the result is True only when it is one of
    "1", "true", "t", "yes" or "y"; any other text (including an empty
    string) yields False.
    """
    truthy = {"1", "true", "t", "yes", "y"}
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in truthy
    return default
103+
104+
105+
def _get_env_int(name: str, default: int) -> int:
    """Read an integer from the environment variable ``name``.

    Returns ``default`` when the variable is not set, or when its value is
    not accepted by ``int()`` (for example "abc", "1.5" or an empty string).
    """
    val = os.environ.get(name)
    if val is None:
        return default
    try:
        return int(val)
    except ValueError:
        # int() on a str only fails with ValueError; catching just that keeps
        # unrelated errors visible instead of silently masking them.
        return default
113+
114+
115+
def _get_env_float(name: str, default: float) -> float:
    """Read a float from the environment variable ``name``.

    Returns ``default`` when the variable is not set, or when its value is
    not accepted by ``float()`` (for example "abc" or an empty string).
    """
    val = os.environ.get(name)
    if val is None:
        return default
    try:
        return float(val)
    except ValueError:
        # float() on a str only fails with ValueError; catching just that
        # keeps unrelated errors visible instead of silently masking them.
        return default
123+
124+
125+
def _get_env_csv(name: str, default_csv: str) -> list[str]:
    """Read a comma-separated list from the environment variable ``name``.

    ``default_csv`` is parsed instead when the variable is not set. Each
    item is stripped of surrounding whitespace and upper-cased; items that
    are empty after stripping are dropped.
    """
    raw = os.environ.get(name, default_csv)
    items: list[str] = []
    for piece in raw.split(","):
        token = piece.strip()
        if token:
            items.append(token.upper())
    return items
128+
129+
130+
def get_grpc_keepalive_options() -> list[tuple[str, Any]]:
    """Translate keepalive environment settings into gRPC channel options.

    Environment variables (defaults in parentheses):
    - DAPR_GRPC_KEEPALIVE_ENABLED (false)
    - DAPR_GRPC_KEEPALIVE_TIME_MS (120000)
    - DAPR_GRPC_KEEPALIVE_TIMEOUT_MS (20000)
    - DAPR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS (false)

    Returns an empty list when keepalive is not enabled. Otherwise returns
    three ``(option_name, value)`` pairs, with the permit-without-calls flag
    encoded as the integer 1 or 0.
    """
    if not _get_env_bool("DAPR_GRPC_KEEPALIVE_ENABLED", False):
        return []

    interval_ms = _get_env_int("DAPR_GRPC_KEEPALIVE_TIME_MS", 120000)
    ack_timeout_ms = _get_env_int("DAPR_GRPC_KEEPALIVE_TIMEOUT_MS", 20000)
    allow_idle_pings = _get_env_bool("DAPR_GRPC_KEEPALIVE_PERMIT_WITHOUT_CALLS", False)

    options: list[tuple[str, Any]] = []
    options.append(("grpc.keepalive_time_ms", interval_ms))
    options.append(("grpc.keepalive_timeout_ms", ack_timeout_ms))
    # The flag is passed as an int (1/0) rather than a bool.
    options.append(("grpc.keepalive_permit_without_calls", int(allow_idle_pings)))
    return options
150+
151+
152+
def get_grpc_retry_service_config_option() -> Optional[tuple[str, str]]:
    """Build the ``grpc.service_config`` channel option for env-driven retries.

    Environment variables (defaults in parentheses):
    - DAPR_GRPC_RETRY_ENABLED (false)
    - DAPR_GRPC_RETRY_MAX_ATTEMPTS (4)
    - DAPR_GRPC_RETRY_INITIAL_BACKOFF_MS (100)
    - DAPR_GRPC_RETRY_MAX_BACKOFF_MS (1000)
    - DAPR_GRPC_RETRY_BACKOFF_MULTIPLIER (2.0)
    - DAPR_GRPC_RETRY_CODES (UNAVAILABLE,DEADLINE_EXCEEDED)

    Returns None when retry is not enabled. Otherwise returns the pair
    ``("grpc.service_config", <json string>)`` whose JSON holds a single
    ``methodConfig`` entry carrying the retry policy. Backoff values are
    read in milliseconds and written as seconds with an "s" suffix.
    """
    if not _get_env_bool("DAPR_GRPC_RETRY_ENABLED", False):
        return None

    max_attempts = _get_env_int("DAPR_GRPC_RETRY_MAX_ATTEMPTS", 4)
    initial_backoff_ms = _get_env_int("DAPR_GRPC_RETRY_INITIAL_BACKOFF_MS", 100)
    max_backoff_ms = _get_env_int("DAPR_GRPC_RETRY_MAX_BACKOFF_MS", 1000)
    backoff_multiplier = _get_env_float("DAPR_GRPC_RETRY_BACKOFF_MULTIPLIER", 2.0)
    codes = _get_env_csv("DAPR_GRPC_RETRY_CODES", "UNAVAILABLE,DEADLINE_EXCEEDED")

    retry_policy = {
        "maxAttempts": max_attempts,
        "initialBackoff": f"{initial_backoff_ms/1000.0}s",
        "maxBackoff": f"{max_backoff_ms/1000.0}s",
        "backoffMultiplier": backoff_multiplier,
        "retryableStatusCodes": codes,
    }
    # NOTE(review): the empty service name presumably makes this entry the
    # default for every method -- confirm against the gRPC service config spec.
    method_config = {
        "name": [{"service": ""}],
        "retryPolicy": retry_policy,
    }
    payload = json.dumps({"methodConfig": [method_config]})
    return ("grpc.service_config", payload)
188+
189+
190+
def build_grpc_channel_options(
    base_options: Optional[Iterable[tuple[str, Any]]] = None,
) -> Optional[list[tuple[str, Any]]]:
    """Merge caller-supplied channel options with the env-driven ones.

    The result lists ``base_options`` first (when given), then the keepalive
    options, then the retry service-config option, each only when present.
    Returns None instead of an empty list so callers can tell that no
    options apply at all.
    """
    merged: list[tuple[str, Any]] = list(base_options) if base_options else []

    # Empty when keepalive is disabled, so this is a no-op in that case.
    merged += get_grpc_keepalive_options()

    retry_option = get_grpc_retry_service_config_option()
    if retry_option is not None:
        merged.append(retry_option)

    return merged or None
206+
207+
86208
def get_logger(
87209
name_suffix: str,
88210
log_handler: Optional[logging.Handler] = None,

durabletask/task.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,15 @@ class OrchestrationStateError(Exception):
219219
pass
220220

221221

222+
class NonRetryableError(Exception):
    """Signals a failure that must not be retried.

    When an activity or sub-orchestration fails with this exception, the
    retry policy is skipped and the failure is surfaced to the orchestrator
    right away.
    """
229+
230+
222231
class Task(ABC, Generic[T]):
223232
"""Abstract base class for asynchronous tasks in a durable orchestration."""
224233
_result: T
@@ -447,7 +456,8 @@ def __init__(self, *,
447456
max_number_of_attempts: int,
448457
backoff_coefficient: Optional[float] = 1.0,
449458
max_retry_interval: Optional[timedelta] = None,
450-
retry_timeout: Optional[timedelta] = None):
459+
retry_timeout: Optional[timedelta] = None,
460+
non_retryable_error_types: Optional[list[Union[str, type]]] = None):
451461
"""Creates a new RetryPolicy instance.
452462
453463
Parameters
@@ -462,6 +472,11 @@ def __init__(self, *,
462472
The maximum retry interval to use for any retry attempt.
463473
retry_timeout : Optional[timedelta]
464474
The maximum amount of time to spend retrying the operation.
475+
non_retryable_error_types : Optional[list[Union[str, type]]]
476+
A list of exception type names or classes that should not be retried.
477+
If a failure's error type matches any of these, the task fails immediately.
478+
The built-in NonRetryableError is always treated as non-retryable regardless
479+
of this setting.
465480
"""
466481
# validate inputs
467482
if first_retry_interval < timedelta(seconds=0):
@@ -480,6 +495,17 @@ def __init__(self, *,
480495
self._backoff_coefficient = backoff_coefficient
481496
self._max_retry_interval = max_retry_interval
482497
self._retry_timeout = retry_timeout
498+
# Normalize non-retryable error type names to a set of strings
499+
names: Optional[set[str]] = None
500+
if non_retryable_error_types:
501+
names = set()
502+
for t in non_retryable_error_types:
503+
if isinstance(t, str):
504+
if t:
505+
names.add(t)
506+
elif isinstance(t, type):
507+
names.add(t.__name__)
508+
self._non_retryable_error_types = names
483509

484510
@property
485511
def first_retry_interval(self) -> timedelta:
@@ -506,6 +532,15 @@ def retry_timeout(self) -> Optional[timedelta]:
506532
"""The maximum amount of time to spend retrying the operation."""
507533
return self._retry_timeout
508534

535+
@property
def non_retryable_error_types(self) -> Optional[set[str]]:
    """Error type names for which this policy disables retries.

    Holds the names normalized by the constructor, or None when no
    non-retryable types were configured. The names are compared with the
    errorType string reported by the backend (typically the exception
    class name).
    """
    configured_names = self._non_retryable_error_types
    return configured_names
543+
509544

510545
def get_name(fn: Callable) -> str:
511546
"""Returns the name of the provided function"""

durabletask/worker.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
2121
import durabletask.internal.shared as shared
2222
from durabletask import task
23-
from durabletask.asyncio_compat import (AsyncWorkflowContext,
24-
CoroutineOrchestratorRunner)
23+
from durabletask.asyncio_compat import AsyncWorkflowContext, CoroutineOrchestratorRunner
2524
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
2625

2726
TInput = TypeVar("TInput")
@@ -763,7 +762,17 @@ def resume(self):
763762
else:
764763
# Resume the generator with the previous result.
765764
# This will either return a Task or raise StopIteration if it's done.
766-
next_task = self._generator.send(self._previous_task.get_result())
765+
try:
766+
_val = self._previous_task.get_result()
767+
import os as _os
768+
if _os.getenv('DAPR_WF_DEBUG') or _os.getenv('DT_DEBUG'):
769+
print(f"[DT] resume send instance={self._instance_id} type={type(_val)} is_none={_val is None}")
770+
except Exception as _e:
771+
import os as _os
772+
if _os.getenv('DAPR_WF_DEBUG') or _os.getenv('DT_DEBUG'):
773+
print(f"[DT] resume send error instance={self._instance_id} err={_e}")
774+
raise
775+
next_task = self._generator.send(_val)
767776

768777
if not isinstance(next_task, task.Task):
769778
raise TypeError("The orchestrator generator yielded a non-Task object")
@@ -1047,6 +1056,7 @@ def execute(
10471056
)
10481057
ctx._is_replaying = True
10491058
for old_event in old_events:
1059+
self._logger.debug(f"OLD-EVENT: {instance_id}: {old_event}")
10501060
self.process_event(ctx, old_event)
10511061

10521062
# Get new actions by executing newly received events into the orchestrator function
@@ -1119,6 +1129,13 @@ def process_event(
11191129
result = fn(
11201130
ctx, input
11211131
) # this does not execute the generator, only creates it
1132+
try:
1133+
from types import GeneratorType as _GenT
1134+
import os as _os
1135+
if _os.getenv('DAPR_WF_DEBUG') or _os.getenv('DT_DEBUG'):
1136+
print(f"[DT] executionStarted orchestrator returned type={type(result)} is_gen={isinstance(result, _GenT)} id={id(result) if isinstance(result, _GenT) else 'n/a'}")
1137+
except Exception:
1138+
pass
11221139
if isinstance(result, GeneratorType):
11231140
# Start the orchestrator's generator function
11241141
ctx.run(result)
@@ -1206,6 +1223,13 @@ def process_event(
12061223
result = None
12071224
if not ph.is_empty(event.taskCompleted.result):
12081225
result = shared.from_json(event.taskCompleted.result.value)
1226+
try:
1227+
import os as _os
1228+
if _os.getenv('DAPR_WF_DEBUG') or _os.getenv('DT_DEBUG'):
1229+
print(f"[DT] taskCompleted decode instance={ctx.instance_id} task_id={task_id} type={type(result)} is_none={result is None}")
1230+
print(f"[DT] pending_task_present={activity_task is not None}")
1231+
except Exception:
1232+
pass
12091233
activity_task.complete(result)
12101234
ctx.resume()
12111235
elif event.HasField("taskFailed"):
@@ -1221,16 +1245,32 @@ def process_event(
12211245

12221246
if isinstance(activity_task, task.RetryableTask):
12231247
if activity_task._retry_policy is not None:
1224-
next_delay = activity_task.compute_next_delay()
1225-
if next_delay is None:
1248+
# Check for non-retryable errors by type name
1249+
error_type = event.taskFailed.failureDetails.errorType
1250+
policy = activity_task._retry_policy
1251+
is_non_retryable = False
1252+
if error_type == getattr(task.NonRetryableError, "__name__", "NonRetryableError"):
1253+
is_non_retryable = True
1254+
elif policy.non_retryable_error_types is not None and error_type in policy.non_retryable_error_types:
1255+
is_non_retryable = True
1256+
1257+
if is_non_retryable:
12261258
activity_task.fail(
12271259
f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
12281260
event.taskFailed.failureDetails,
12291261
)
12301262
ctx.resume()
12311263
else:
1232-
activity_task.increment_attempt_count()
1233-
ctx.create_timer_internal(next_delay, activity_task)
1264+
next_delay = activity_task.compute_next_delay()
1265+
if next_delay is None:
1266+
activity_task.fail(
1267+
f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
1268+
event.taskFailed.failureDetails,
1269+
)
1270+
ctx.resume()
1271+
else:
1272+
activity_task.increment_attempt_count()
1273+
ctx.create_timer_internal(next_delay, activity_task)
12341274
elif isinstance(activity_task, task.CompletableTask):
12351275
activity_task.fail(
12361276
f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
@@ -1292,16 +1332,32 @@ def process_event(
12921332
return
12931333
if isinstance(sub_orch_task, task.RetryableTask):
12941334
if sub_orch_task._retry_policy is not None:
1295-
next_delay = sub_orch_task.compute_next_delay()
1296-
if next_delay is None:
1335+
# Check for non-retryable errors by type name
1336+
error_type = failedEvent.failureDetails.errorType
1337+
policy = sub_orch_task._retry_policy
1338+
is_non_retryable = False
1339+
if error_type == getattr(task.NonRetryableError, "__name__", "NonRetryableError"):
1340+
is_non_retryable = True
1341+
elif policy.non_retryable_error_types is not None and error_type in policy.non_retryable_error_types:
1342+
is_non_retryable = True
1343+
1344+
if is_non_retryable:
12971345
sub_orch_task.fail(
12981346
f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
12991347
failedEvent.failureDetails,
13001348
)
13011349
ctx.resume()
13021350
else:
1303-
sub_orch_task.increment_attempt_count()
1304-
ctx.create_timer_internal(next_delay, sub_orch_task)
1351+
next_delay = sub_orch_task.compute_next_delay()
1352+
if next_delay is None:
1353+
sub_orch_task.fail(
1354+
f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
1355+
failedEvent.failureDetails,
1356+
)
1357+
ctx.resume()
1358+
else:
1359+
sub_orch_task.increment_attempt_count()
1360+
ctx.create_timer_internal(next_delay, sub_orch_task)
13051361
elif isinstance(sub_orch_task, task.CompletableTask):
13061362
sub_orch_task.fail(
13071363
f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",

0 commit comments

Comments
 (0)