Skip to content

[DO NOT MERGE] Add debug code #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: nonlocal-gws
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/physics/cam/gw_drag.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1549,7 +1549,7 @@ subroutine gw_tend(state, pbuf, dt, ptend, cam_in, flx_heat)
flx_heat = 0._r8

if ( use_gw_nlgw ) then
call gw_nlgw_dp_ml(state1,ptend)
call gw_nlgw_dp_ml(state1,ptend,lchnk)
end if

if (use_gw_convect_dp) then
Expand Down
86 changes: 70 additions & 16 deletions src/physics/cam/gw_nlgw.F90
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@ module gw_nlgw
!

use gw_utils, only: r8, r4
use ppgrid, only: pver !vertical levels
use ppgrid, only: pcols, pver !vertical levels
use physics_types, only: physics_state, physics_ptend
use spmd_utils, only: mpicom, mstrid=>masterprocid, masterproc, mpi_real8, iam
use cam_abortutils, only: endrun
use cam_logfile, only: iulog
use physconst, only: cappa, pi
use interpolate_data, only: lininterp
use cam_history, only: outfld
use gw_nlgw_debug, only: dump_column_data, check_flux_derivatives, dump_flux_profile

use ftorch

use, intrinsic :: ieee_arithmetic

implicit none

public :: gw_nlgw_dp_ml, gw_nlgw_dp_init, gw_nlgw_dp_finalize
Expand Down Expand Up @@ -98,23 +102,25 @@ module gw_nlgw
9.6630859375e-01_r8, 9.7216796875e-01_r8, 9.7753906250e-01_r8, 9.8242187500e-01_r8, 9.8632812500e-01_r8, 9.9023437500e-01_r8, 9.9365234375e-01_r8, 9.9609375000e-01_r8, &
9.9902343750e-01_r8]



contains

!==========================================================================

subroutine gw_nlgw_dp_ml(state_in, ptend)
subroutine gw_nlgw_dp_ml(state_in, ptend, lchnk)

! inputs
type(physics_state), intent(in) :: state_in
integer, intent(in) :: lchnk

! outputs
type(physics_ptend), intent(inout) :: ptend

!---------------------------Local storage-------------------------------
type(torch_tensor) :: tensor_in(1), tensor_out(1)
integer :: ninputs = 1, noutputs = 1
integer, dimension(2) :: layout = [1 , 2]
integer :: ninputs = 1, noutputs = 1, i
integer, dimension(2) :: layout = [1, 2]
real(r8), parameter :: tendency_threshold= 0.5_r8
real(r8), dimension(:, :), allocatable ::pmid_interp

integer :: device_id

Expand Down Expand Up @@ -159,19 +165,45 @@ subroutine gw_nlgw_dp_ml(state_in, ptend)
call normalise_data()
call construct_input()

if (maxval(abs(net_inputs)) > 20._r8) then
write (*, *) 'Net input values unusually large! maxval=', maxval(abs(net_inputs))
call endrun("Abnormal inputs to NN — possibly out of distribution")
end if

! send all columns from this process
call torch_tensor_from_array(tensor_in(1), net_inputs, layout, torch_kCUDA, device_id)
call torch_tensor_from_array(tensor_out(1), net_outputs, layout, torch_kCPU)

! Run net forward on data
call torch_model_forward(nlgw_model, tensor_in, tensor_out)

if (any(.not. ieee_is_finite(net_outputs))) then
call endrun("Net output contains NaN or Inf")
end if

! Extract and denormalise outputs
call extract_output()
call denormalise_data()

call flux_to_forcing(uflux, utgw)
call flux_to_forcing(vflux, vtgw)
call flux_to_forcing(uflux, utgw, lchnk)
call flux_to_forcing(vflux, vtgw, lchnk)

do i = 1, ncol
pmid_interp(i,:) = era5_ak(:) + ps(i)*era5_bk(:)
if (maxval(abs(utgw(i,:))) > tendency_threshold .or. &
maxval(abs(vtgw(i,:))) > tendency_threshold) then
call dump_column_data(i, net_inputs, net_outputs, utgw, vtgw, pmid, phis, tendency_threshold, lchnk, pver)
call dump_flux_profile(i, net_outputs, pmid_interp, lchnk, pver_interp) ! This is in model (era5) space
end if
end do

! Write UTGW and VTGW to file
call outfld('UTGW_NL',utgw,ncol,lchnk)
call outfld('VTGW_NL',vtgw,ncol,lchnk)

! ! NOTE - clamping the tendencies will fix the bug
! utgw = max(min(utgw, 0.2_r8), -0.2_r8)
! vtgw = max(min(vtgw, 0.2_r8), -0.2_r8)

! update the tendencies
ptend%u(:ncol,:pver) = ptend%u(:ncol,:pver) + utgw(:ncol,:pver)
Expand Down Expand Up @@ -202,9 +234,10 @@ subroutine gw_nlgw_dp_ml(state_in, ptend)

end subroutine gw_nlgw_dp_ml


subroutine gw_nlgw_dp_init(model_path)

use cam_history, only: addfld

character(len=*), intent(in) :: model_path ! Filepath to PyTorch Torchscript net
integer :: device_id

Expand All @@ -219,8 +252,10 @@ subroutine gw_nlgw_dp_init(model_path)
write(iulog,*)'nlgw model loaded from: ', model_path
endif

end subroutine gw_nlgw_dp_init
call addfld('UTGW_NL', (/'lev'/), 'A', 'm/s2', 'Nonlinear GW zonal wind tendency')
call addfld('VTGW_NL', (/'lev'/), 'A', 'm/s2', 'Nonlinear GW meridional wind tendency')

end subroutine gw_nlgw_dp_init

subroutine gw_nlgw_dp_finalize()

Expand All @@ -231,11 +266,10 @@ subroutine gw_nlgw_dp_finalize()

end subroutine gw_nlgw_dp_finalize


subroutine read_norms()

! TODO
! - replace hardcoded means/std devs with netcdf file?
! - replace hardcoded means/std devs with netcdf file?

lat_mean = 0._r8
lon_mean = 0._r8
Expand Down Expand Up @@ -295,6 +329,17 @@ subroutine construct_input()

do i = 1, ncol
pmid_interp(i,:) = era5_ak(:) + ps(i) * era5_bk(:)

! if (i == 1 .and. masterproc) then
! write (*, *) "ps = ", ps(i)
! write (*, *) "pmid_interp top level (lev=0) = ", pmid_interp(i, 1)
! write (*, *) "pmid_interp bottom level (lev=136) = ", pmid_interp(i, pver_interp)
! end if

if (any(pmid_interp(i,:) < 0._r8)) then
write (*,*) 'Negative pressure in pmid_interp! col=', i, ' min p=', minval(pmid_interp(i, :))
call endrun("Invalid interpolatedpressure")
end if
call lininterp(u(i,:), pmid(i,:), pver, u_interp(i,:), pmid_interp(i,:), pver_interp)
call lininterp(v(i,:), pmid(i,:), pver, v_interp(i,:), pmid_interp(i,:), pver_interp)
call lininterp(theta(i,:), pmid(i,:), pver, theta_interp(i,:), pmid_interp(i,:), pver_interp)
Expand Down Expand Up @@ -341,8 +386,8 @@ subroutine extract_output()
allocate(uflux_interp(ncol,pver_interp))
allocate(vflux_interp(ncol,pver_interp))

uflux_interp(:, :) = net_outputs(:,:pver_interp)
vflux_interp(:, :) = net_outputs(:,pver_interp+1:)
uflux_interp(:, :) = net_outputs(:, 1:pver_interp)
vflux_interp(:, :) = net_outputs(:, pver_interp + 1:2*pver_interp)

do i = 1, ncol
pmid_interp(i,:) = era5_ak(:) + ps(i) * era5_bk(:)
Expand All @@ -361,6 +406,11 @@ subroutine denormalise_data()
uflux = uflux**3._r8 * uflux_std + uflux_mean
vflux = vflux**3._r8 * vflux_std + vflux_mean

if (any(abs(uflux(:, 1)) > 100.0_r8)) then
write (*, *) 'uflux exploding at surface! Example col=', maxloc(abs(uflux(:, 1))), ' value=', maxval(abs(uflux(:, 1)))
call endrun('uflux too large at surface')
end if

end subroutine denormalise_data

elemental function cbrt(a) result(root)
Expand All @@ -370,20 +420,24 @@ elemental function cbrt(a) result(root)
root = sign(abs(a)**one_third, a)
end function cbrt

subroutine flux_to_forcing(flux, forcing)
subroutine flux_to_forcing(flux, forcing, lchnk)

real(r8), intent(in), dimension(:,:) :: flux
integer, intent(in) :: lchnk
real(r8), intent(out), dimension(:,:) :: forcing ! forcing = -d(u'\omega')/d(p), units = m/s^2
real(r8) :: dp

integer :: level, col

! convert fluxes to tendencies
! pressure profile must be in Pascals

call check_flux_derivatives(flux, pmid, lchnk, 0.01_r8)

do col = 1, ncol
forcing(col,1) = -1*(flux(col,2) - flux(col,1))/(pmid(col,2) - pmid(col,1))
do level = 2, pver-1
forcing(col,level) = (flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1))))
forcing(col,level) = -1*(flux(col,level+1) - flux(col,level-1)) / (pmid(col,level)*(log(pmid(col,level+1)) - log(pmid(col,level-1))))
end do
forcing(col,pver) = -1*(flux(col,pver) - flux(col,pver-1)) / (pmid(col,pver) - pmid(col,pver-1))
end do
Expand Down
Loading