Skip to content

Commit 64205cb

Browse files
Fixing dtype issue (#2372)
1 parent 0c04f88 commit 64205cb

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

keras_hub/src/models/deberta_v3/disentangled_self_attention.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -217,9 +217,14 @@ def _make_log_bucket_position(self, rel_pos):
217217
)
218218

219219
def _get_log_pos(abs_pos, mid):
220-
numerator = ops.log(abs_pos / mid)
220+
numerator = ops.log(
221+
ops.cast(abs_pos, "float32") / ops.cast(mid, "float32")
222+
)
221223
numerator = numerator * ops.cast(mid - 1, dtype=numerator.dtype)
222-
denominator = ops.log((self.max_position_embeddings - 1) / mid)
224+
denominator = ops.log(
225+
ops.cast(self.max_position_embeddings - 1, "float32")
226+
/ ops.cast(mid, "float32")
227+
)
223228
val = ops.ceil(numerator / denominator)
224229
val = ops.cast(val, dtype=mid.dtype)
225230
val = val + mid

0 commit comments

Comments (0)