In [3]:
!pip install geedim
!pip install geemap
Collecting geedim Downloading geedim-2.0.0-py3-none-any.whl.metadata (6.0 kB) Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.12/dist-packages (from geedim) (2.0.2) Requirement already satisfied: rasterio>=1.3.8 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.0) Requirement already satisfied: click>=8 in /usr/local/lib/python3.12/dist-packages (from geedim) (8.3.1) Requirement already satisfied: tqdm>=4.6 in /usr/local/lib/python3.12/dist-packages (from geedim) (4.67.1) Requirement already satisfied: earthengine-api>=0.1.379 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.24) Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.12/dist-packages (from geedim) (0.9.0) Requirement already satisfied: fsspec>=2025.2 in /usr/local/lib/python3.12/dist-packages (from geedim) (2025.3.0) Requirement already satisfied: aiohttp>=3.11 in /usr/local/lib/python3.12/dist-packages (from geedim) (3.13.3) Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (2.6.1) Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.4.0) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (25.4.0) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.8.0) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (6.7.0) Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (0.4.1) Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.22.0) Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from 
earthengine-api>=0.1.379->geedim) (3.7.0) Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.187.0) Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.43.0) Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.3.0) Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.31.0) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.32.4) Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2.4.0) Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2026.1.4) Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (0.7.2) Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (3.3.1) Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp>=3.11->geedim) (4.15.0) Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (2.29.0) Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (4.2.0) Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (6.2.4) Requirement already 
satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (4.9.1) Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.12/dist-packages (from yarl<2.0,>=1.17.0->aiohttp>=3.11->geedim) (3.11) Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.5.0) Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.8.0) Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (1.8.0) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (3.4.4) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (2.5.0) Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.72.0) Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (5.29.5) Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from 
google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.27.0) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.6.1) Downloading geedim-2.0.0-py3-none-any.whl (73 kB) āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā 73.1/73.1 kB 6.9 MB/s eta 0:00:00 Installing collected packages: geedim Successfully installed geedim-2.0.0 Requirement already satisfied: geemap in /usr/local/lib/python3.12/dist-packages (0.35.3) Requirement already satisfied: bqplot in /usr/local/lib/python3.12/dist-packages (from geemap) (0.12.45) Requirement already satisfied: colour in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.5) Requirement already satisfied: earthengine-api>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (1.5.24) Requirement already satisfied: eerepr>=0.1.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.2) Requirement already satisfied: folium>=0.17.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0) Requirement already satisfied: geocoder in /usr/local/lib/python3.12/dist-packages (from geemap) (1.38.1) Requirement already satisfied: ipyevents in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.4) Requirement already satisfied: ipyfilechooser>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.6.0) Requirement already satisfied: ipyleaflet>=0.19.2 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0) Requirement already satisfied: ipytree in /usr/local/lib/python3.12/dist-packages (from geemap) (0.2.2) Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from geemap) (3.10.0) Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.2) Requirement already satisfied: pandas in 
/usr/local/lib/python3.12/dist-packages (from geemap) (2.2.2) Requirement already satisfied: plotly in /usr/local/lib/python3.12/dist-packages (from geemap) (5.24.1) Requirement already satisfied: pyperclip in /usr/local/lib/python3.12/dist-packages (from geemap) (1.11.0) Requirement already satisfied: pyshp>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from geemap) (3.0.3) Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from geemap) (7.3.2) Requirement already satisfied: scooby in /usr/local/lib/python3.12/dist-packages (from geemap) (0.11.0) Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (3.7.0) Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.187.0) Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.43.0) Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.3.0) Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.31.0) Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.32.4) Requirement already satisfied: branca>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (0.8.2) Requirement already satisfied: jinja2>=2.9 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (3.1.6) Requirement already satisfied: xyzservices in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (2025.11.0) Requirement already satisfied: ipywidgets in /usr/local/lib/python3.12/dist-packages (from ipyfilechooser>=0.6.0->geemap) (7.7.1) Requirement already satisfied: 
jupyter-leaflet<0.21,>=0.20 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.20.0) Requirement already satisfied: traittypes<3,>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.2.3) Requirement already satisfied: traitlets>=4.3.0 in /usr/local/lib/python3.12/dist-packages (from bqplot->geemap) (5.7.1) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2.9.0.post0) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.2) Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.3) Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (8.3.1) Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.0.0) Requirement already satisfied: ratelim in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (0.1.6) Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.17.0) Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.3.3) Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (0.12.1) Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (4.61.1) Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.4.9) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (25.0) Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (11.3.0) Requirement already satisfied: pyparsing>=2.3.1 in 
/usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (3.3.1) Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.12/dist-packages (from plotly->geemap) (9.1.2) Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (2.29.0) Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (4.2.0) Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (6.2.4) Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.4.2) Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (4.9.1) Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.17.1) Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0) Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.6.10) Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.34.0) Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.16) Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2>=2.9->folium>=0.17.0->geemap) (3.0.3) Requirement already 
satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.5.0) Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.8.0) Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (1.8.0) Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.4.4) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.11) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2.5.0) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2026.1.4) Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from ratelim->geocoder->geemap) (4.4.2) Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.72.0) Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (5.29.5) Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from 
google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.27.0) Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.8.15) Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.4.9) Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.1) Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.6.0) Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.5) Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (26.2.1) Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.1) Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (75.2.0) Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB) Requirement already satisfied: pickleshare in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.5) Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.52) Requirement already satisfied: pygments in 
/usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.19.2) Requirement already satisfied: backcall in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0) Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.9.0) Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.6.1) Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.12/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.7) Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.12/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.8.5) Requirement already satisfied: entrypoints in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.4) Requirement already satisfied: jupyter-core>=4.9.2 in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.1) Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0) Requirement already satisfied: nbformat in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.10.4) Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.16.6) Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.12/dist-packages 
(from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0) Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.18.1) Requirement already satisfied: prometheus-client in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.23.1) Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.3) Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.12/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.0) Requirement already satisfied: wcwidth in /usr/local/lib/python3.12/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.14) Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.12/dist-packages (from jupyter-core>=4.9.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.5.1) Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.12/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.4) Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.13.5) Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.3.0) Requirement already satisfied: defusedxml in /usr/local/lib/python3.12/dist-packages 
(from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.1) Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.3.0) Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.2.0) Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.10.4) Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1) Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.21.2) Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.26.0) Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.12/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0) Requirement already satisfied: webencodings in /usr/local/lib/python3.12/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.1) Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from 
bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0) Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.4.0) Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2025.9.1) Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.37.0) Requirement already satisfied: rpds-py>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.30.0) Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.12/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.14.0) Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0) Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.8.1) Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.15.0) Requirement already satisfied: pycparser in 
/usr/local/lib/python3.12/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.23) Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.12.1) Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.12.0) Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.3) Requirement already satisfied: overrides>=5.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.7.0) Requirement already satisfied: websocket-client>=1.7 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.9.0) Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.0.0) Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.12/dist-packages (from 
jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.0.3) Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.4) Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.1) Requirement already satisfied: fqdn in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1) Requirement already satisfied: isoduration in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (20.11.0) Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.0) Requirement already satisfied: rfc3987-syntax>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.1.0) 
Requirement already satisfied: uri-template in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.0) Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.10.0) Requirement already satisfied: lark>=1.2.2 in /usr/local/lib/python3.12/dist-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.1) Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0) Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB) āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā 1.6/1.6 MB 74.6 MB/s eta 0:00:00 Installing collected packages: jedi Successfully installed jedi-0.19.2
In [4]:
# CELL 1: HYPER-STACK DATA LOADING
import os
import numpy as np
import rasterio
from rasterio.windows import from_bounds
import ee
import geemap
import torch
import cv2
import shutil
import time
from google.colab import drive
# 1. SETUP & AUTHENTICATION
# Mount Google Drive at /content/drive. This runs at cell execution time and
# only works inside Google Colab (it relies on google.colab.drive).
# NOTE(review): force_remount=True presumably re-mounts even when Drive is
# already mounted, so re-running the cell is safe -- confirm against Colab docs.
drive.mount('/content/drive', force_remount=True)
def initialize_ee(project='[REDACTED_FOR_SECURITY]'):
    """Initialize the Earth Engine client, authenticating first if required.

    A plain ``ee.Initialize`` is attempted first. If it raises, the
    authentication flow is triggered via ``ee.Authenticate`` and the
    initialization is retried once with the same project.

    Args:
        project: Cloud project id passed to ``ee.Initialize``. Defaults to
            the project this notebook was written for, so existing
            zero-argument calls keep working.

    Raises:
        Exception: Whatever ``ee.Authenticate`` or the second
            ``ee.Initialize`` raises; the retry is not guarded.
    """
    try:
        ee.Initialize(project=project)
    except Exception as e:
        # Deliberately broad: any failure of the first attempt is treated as
        # "credentials missing or stale" and answered with a fresh
        # authentication followed by a single retry.
        print(f" Initialization failed. Triggering Authentication... ({e})")
        ee.Authenticate()
        ee.Initialize(project=project)
        print(" Earth Engine Authenticated & Initialized.")
    else:
        print(" Earth Engine Initialized Successfully.")
# Connect to Earth Engine as soon as the cell runs, before any ee.* object
# is built further down.
initialize_ee()
# CONFIG
# Six (start, end) date-string pairs covering 2024-10-15 .. 2025-04-15, one
# per growth stage. get_hyper_satmae_data iterates over them and passes each
# pair to filterDate(start, end), producing one download per window.
# NOTE(review): Earth Engine's filterDate treats the end date as exclusive,
# so the last day of each window may be dropped -- confirm this is intended.
TIME_WINDOWS = [
('2024-10-15', '2024-11-15'), # 1. Pre-Sowing
('2024-11-16', '2024-12-15'), # 2. Early Growth
('2024-12-16', '2025-01-15'), # 3. Late Growth
('2025-01-16', '2025-02-15'), # 4. Peak Greenness
('2025-02-16', '2025-03-15'), # 5. Flowering
('2025-03-16', '2025-04-15') # 6. Harvest
]
# Earth Engine asset id; loaded with ee.Image(ASSET_ID) in
# get_hyper_satmae_data, where it serves as the mask image and its geometry
# as the region of interest.
ASSET_ID = 'projects/[REDACTED_FOR_SECURITY]/assets/Punjab_Mask_2024_NEW'
# NOTE(review): not referenced in the visible part of this file; presumably
# the patch edge length in pixels used downstream -- confirm.
PATCH_SIZE = 224
def get_hyper_satmae_data():
    """Download, normalise and patch the multi-temporal training cube.

    Steps:
      1. Download the label mask (``ASSET_ID``) once and crop a
         0.12 x 0.12 degree window around its centre.
      2. For every window in ``TIME_WINDOWS`` download a fused composite
         (Sentinel-2 median of B2/B3/B4/B8/B11/B12 + Sentinel-1 mean of
         VV/VH), retrying up to three times.  If a later time step still
         cannot be downloaded, the previous step's file is copied in its place.
      3. Scale the bands to [0, 1], append NDVI and EVI (10 channels total)
         and stack the time steps.
      4. Cut non-overlapping ``PATCH_SIZE`` x ``PATCH_SIZE`` patches, dropping
         any patch that contains NaN.

    Downloaded GeoTIFFs are cached in the working directory and reused on
    later calls.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``X`` with shape (B, C, T, H, W)
        and ``y`` with shape (B, 1, H, W), both float32.

    Raises:
        RuntimeError: if the first time step cannot be downloaded, or if no
            complete NaN-free patch could be extracted.
    """
    # `_cloud_api_resource` is a private earthengine-api attribute; use getattr
    # so a client version without it triggers a re-init instead of a crash.
    if not getattr(ee.data, '_cloud_api_resource', None):
        print(" EE lost connection. Re-initializing...")
        initialize_ee()
    print(" Starting Hyper-Stack Ingestion...")
    mask_img = ee.Image(ASSET_ID)
    roi_geom = mask_img.geometry()

    # Download Mask
    mask_file = 'local_mask_hyper.tif'
    if not os.path.exists(mask_file):
        geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)

    with rasterio.open(mask_file) as src:
        b = src.bounds
        cx, cy = (b.left + b.right) / 2, (b.bottom + b.top) / 2
        offset = 0.06
        window = from_bounds(cx - offset, cy - offset, cx + offset, cy + offset, src.transform)
        mask = src.read(1, window=window)
        # Read the CRS while the dataset is still open.
        mask_crs = str(src.crs)
    mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)
    target_h, target_w = mask.shape
    small_roi = ee.Geometry.Rectangle(
        [cx - offset, cy - offset, cx + offset, cy + offset], proj=mask_crs, geodesic=False
    )

    max_attempts = 3
    stack = []
    for i, (start, end) in enumerate(TIME_WINDOWS):
        fname = f'hyper_time_{i}.tif'
        # RETRY LOGIC for robustness.  The attempt counter advances on every
        # pass, so a download that fails without raising cannot loop forever.
        attempts = 0
        while not os.path.exists(fname) and attempts < max_attempts:
            attempts += 1
            try:
                print(f" Downloading Step {i+1}/{len(TIME_WINDOWS)} (Attempt {attempts})...")
                s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED').filterBounds(small_roi).filterDate(start, end).median().select(['B2', 'B3', 'B4', 'B8', 'B11', 'B12'])
                s1 = ee.ImageCollection('COPERNICUS/S1_GRD').filterBounds(small_roi).filterDate(start, end).mean().select(['VV', 'VH'])
                fused = ee.Image.cat([s2, s1]).clip(small_roi)
                geemap.download_ee_image(fused, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
            except Exception as e:
                print(f" Error: {e}")
                # A half-written file would otherwise pass the exists() check
                # above and be loaded as if it were a complete download.
                if os.path.exists(fname):
                    os.remove(fname)
            if not os.path.exists(fname) and attempts < max_attempts:
                time.sleep(2)

        if not os.path.exists(fname):
            print(f" Failed to download {start}. Checking fallback...")
            if i > 0:
                print(" Copying previous month's data.")
                shutil.copy(f'hyper_time_{i-1}.tif', fname)
            else:
                raise RuntimeError(" CRITICAL: First time step failed. Cannot proceed.")

        # LOAD & PROCESS
        with rasterio.open(fname) as src:
            arr = src.read()  # (8, H, W)
        arr = np.transpose(arr, (1, 2, 0))  # (H, W, 8)
        if arr.shape[:2] != (target_h, target_w):
            # Align the composite with the mask grid (cv2 takes (width, height)).
            arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

        # Normalization
        s2_bands = np.clip(arr[:, :, :6] / 5000.0, 0, 1)
        s1_bands = np.clip((arr[:, :, 6:] - (-25.0)) / 25.0, 0, 1)

        # Indices (the small epsilon keeps the NDVI denominator non-zero)
        nir = s2_bands[:, :, 3] + 1e-6
        red = s2_bands[:, :, 2] + 1e-6
        blue = s2_bands[:, :, 0] + 1e-6
        ndvi = ((nir - red) / (nir + red) + 1) / 2.0
        evi = np.clip(2.5 * ((nir - red) / (nir + 6 * red - 7.5 * blue + 1)), 0, 1)

        combined = np.concatenate([s2_bands, s1_bands, ndvi[:, :, None], evi[:, :, None]], axis=2)
        # The final tensors are float32 anyway; casting here halves the peak
        # memory of the stacked cube without changing the resulting values.
        stack.append(combined.astype(np.float32))

    full_cube = np.stack(stack, axis=2)  # (H, W, T, C)

    x_out, y_out = [], []
    stride = PATCH_SIZE
    print(" Creating Patches...")
    # Only full-size patches: the range stops before any partial tile.
    for y0 in range(0, target_h - PATCH_SIZE + 1, stride):
        for x0 in range(0, target_w - PATCH_SIZE + 1, stride):
            img_p = full_cube[y0:y0 + PATCH_SIZE, x0:x0 + PATCH_SIZE]
            if np.isnan(img_p).any():
                continue
            x_out.append(img_p)
            y_out.append(mask[y0:y0 + PATCH_SIZE, x0:x0 + PATCH_SIZE])

    if not x_out:
        raise RuntimeError(
            f"No complete NaN-free {PATCH_SIZE}x{PATCH_SIZE} patch could be extracted "
            f"from a {target_h}x{target_w} cube."
        )

    X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)  # (B, C, T, H, W)
    y = np.array(y_out, dtype=np.float32)[:, None, :, :]
    X = np.nan_to_num(X, nan=0.0)
    print(f" Hyper-Dataset Ready. Shape: {X.shape}")
    return torch.tensor(X), torch.tensor(y)
# CALL THE FUNCTION
# X_data / y_data are module-level tensors consumed by the training cell below.
X_data, y_data = get_hyper_satmae_data()
Mounted at /content/drive Earth Engine Initialized Successfully. Starting Hyper-Stack Ingestion...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release. Please use the 'ee.Image.gd' accessor instead. img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW: 0%| |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'. return STACClient().get(self.id)
Downloading Step 1/6 (Attempt 1)...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 /usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'. return STACClient().get(self.id)
Downloading Step 2/6 (Attempt 1)...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading Step 3/6 (Attempt 1)...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading Step 4/6 (Attempt 1)...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading Step 5/6 (Attempt 1)...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Downloading Step 6/6 (Attempt 1)...
0%| |0/48 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. 
Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10 WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
Creating Patches... Hyper-Dataset Ready. Shape: (25, 10, 6, 224, 224)
In [5]:
# CELL 2: MODEL DEFINITION (PARTIAL FINE-TUNE: 10 Frozen / 2 Unfrozen)
import torch.nn as nn
from huggingface_hub import hf_hub_download
class SatMAEPatchEmbed(nn.Module):
    """Turn every frame of a (B, C, T, H, W) clip into a row of patch tokens.

    Each frame is tokenised independently by one strided convolution that is
    shared across time steps.
    """

    def __init__(self, in_chans=10, embed_dim=768, patch_size=16):
        super().__init__()
        # kernel == stride, so the convolution sees non-overlapping patches.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Return tokens of shape (B, T, num_patches, embed_dim)."""
        batch, chans, frames, height, width = x.shape
        # Fold time into the batch axis so the 2-D convolution handles each frame.
        frames_first = x.transpose(1, 2)
        stacked = frames_first.reshape(batch * frames, chans, height, width)
        feature_map = self.proj(stacked)
        # (B*T, D, h, w) -> (B*T, h*w, D)
        tokens = feature_map.flatten(2).transpose(1, 2)
        embed_dim = tokens.shape[-1]
        return tokens.reshape(batch, frames, -1, embed_dim)
class SatMAEBackbone(nn.Module):
    """Transformer encoder over the joint space-time token sequence.

    Tokens from every frame are concatenated into one sequence of length
    ``T * N + 1`` (one class token in front) and passed through ``depth``
    pre-norm Transformer encoder layers.

    Args:
        num_frames: number of time steps ``T`` (size of the time embedding).
        in_chans: number of input channels per frame.
        embed_dim: token width ``D``.
        depth: number of encoder layers.
        num_heads: attention heads per layer.
        img_size: side length of the (square) input frames in pixels.
        patch_size: side length of a square patch in pixels.
    """

    def __init__(self, num_frames=6, in_chans=10, embed_dim=768, depth=12, num_heads=12,
                 img_size=224, patch_size=16):
        super().__init__()
        self.patch_embed = SatMAEPatchEmbed(in_chans=in_chans, embed_dim=embed_dim, patch_size=patch_size)
        grid = img_size // patch_size
        num_patches = grid ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, 1, num_patches + 1, embed_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, num_frames, 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        # An all-zero positional table carries no information, and a user of
        # this backbone may freeze it, so give it a meaningful initial value.
        self._init_pos_embed(grid)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4, activation="gelu", batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)

    @staticmethod
    def _sincos_1d(dim, positions):
        """Return (M, dim) sine/cosine codes for the M flattened positions."""
        omega = torch.arange(dim // 2, dtype=torch.float32) / (dim / 2.0)
        omega = 1.0 / (10000.0 ** omega)
        angles = positions.reshape(-1, 1) * omega.reshape(1, -1)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    @staticmethod
    def _sincos_2d(embed_dim, grid):
        """Return (grid*grid, embed_dim) codes: half from the row, half from the column."""
        coords = torch.arange(grid, dtype=torch.float32)
        rows, cols = torch.meshgrid(coords, coords, indexing="ij")
        half = embed_dim // 2
        return torch.cat(
            [SatMAEBackbone._sincos_1d(half, rows), SatMAEBackbone._sincos_1d(half, cols)], dim=1
        )

    def _init_pos_embed(self, grid):
        """Fill the patch slots of ``pos_embed``; the class-token slot stays zero."""
        dim = self.pos_embed.shape[-1]
        with torch.no_grad():
            if dim % 4 == 0:
                self.pos_embed[0, 0, 1:] = self._sincos_2d(dim, grid)
            else:
                # The sin/cos layout needs a width divisible by 4; fall back
                # to a small random initialisation otherwise.
                nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Encode ``x`` of shape (B, C, T, H, W); returns (B, T*N + 1, D)."""
        x = self.patch_embed(x)
        B, T, N, D = x.shape
        # The time embedding broadcasts over patches, the spatial one over frames.
        x = x + self.time_embed
        x = x.reshape(B, T*N, D)
        spatial_pos = self.pos_embed[:, :, 1:, :].expand(B, T, -1, -1).reshape(B, T*N, D)
        x = x + spatial_pos
        # Slot 0 of pos_embed belongs to the class token.
        cls_token = self.cls_token.expand(B, -1, -1, -1).reshape(B, 1, D) + self.pos_embed[:, :, 0, :].expand(B, 1, D)
        x = torch.cat((cls_token, x), dim=1)
        x = self.blocks(x)
        x = self.norm(x)
        return x
class SatMAESegmentation(nn.Module):
    """Segmentation model: space-time ViT encoder plus a convolutional decoder.

    The encoder is initialised from the ``google/vit-base-patch16-224``
    checkpoint when that can be downloaded and its shapes fit; otherwise it
    keeps its random initialisation.  Only the time embedding, the patch
    projection weight, the last two encoder layers and the final LayerNorm
    are left trainable in the backbone; the decoder is fully trainable.

    Args:
        num_frames: number of time steps in the input clip.
        embed_dim: token width of the encoder.
    """

    def __init__(self, num_frames=6, embed_dim=768):
        super().__init__()
        print(f" Constructing SatMAE Hyper-Stack ({num_frames} Time Steps)...")
        self.num_frames = num_frames
        self.backbone = SatMAEBackbone(num_frames=num_frames, in_chans=10, embed_dim=embed_dim)
        try:
            print(" Adapting Google ViT Weights to 10 Channels...")
            self._load_pretrained_vit()
            print(" Weights Adapted.")
        except Exception as exc:
            # Best effort: training can proceed from random weights, but say
            # why the pretrained weights were not used.
            print(" Weights missing, using random init.")
            print(f" Reason: {exc!r}")

        # --- MODIFIED FREEZING STRATEGY (PARTIAL FINE-TUNE) ---
        print(" Applying Partial Fine-Tuning Strategy...")
        # 1. Freeze EVERYTHING first (Blocks 1-12, Embeddings, Norms)
        for param in self.backbone.parameters():
            param.requires_grad = False
        # 2. Unfreeze Adapters (Input + Time)
        self.backbone.time_embed.requires_grad = True
        self.backbone.patch_embed.proj.weight.requires_grad = True
        # 3. Unfreeze LAST 2 Transformer Blocks (Deep Tuning)
        for layer in self.backbone.blocks.layers[-2:]:
            for param in layer.parameters():
                param.requires_grad = True
        # 4. Unfreeze Norm Layer
        for param in self.backbone.norm.parameters():
            param.requires_grad = True
        print(" Config: Blocks 1-10 Frozen | Blocks 11-12 Unfrozen")

        # DECODER
        # 1x1 convolution that mixes the per-frame feature maps into one map.
        self.temporal_agg = nn.Conv2d(embed_dim * num_frames, embed_dim, kernel_size=1)
        # Four x2 upsampling stages: 16x the token grid.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2), nn.Conv2d(embed_dim, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(256, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.GELU(),
            nn.Upsample(scale_factor=2), nn.Conv2d(64, 32, 3, 1, 1), nn.BatchNorm2d(32), nn.GELU(),
            nn.Conv2d(32, 1, 1)
        )

    def _load_pretrained_vit(self):
        """Copy the pretrained ViT weights into the backbone.

        The RGB patch projection is widened to the backbone's channel count
        (extra channels receive the mean RGB filter), and the encoder layers,
        final norm, class token and positional table are translated to the
        ``nn.TransformerEncoder`` parameter names.  The translated tensors are
        all built and shape-checked before anything is loaded, so a failure
        leaves the backbone untouched.

        Raises:
            Exception: on download failure, missing checkpoint keys or any
                shape mismatch with the backbone.
        """
        ckpt_path = hf_hub_download("google/vit-base-patch16-224", "pytorch_model.bin")
        sd = torch.load(ckpt_path, map_location='cpu')
        target = self.backbone.state_dict()
        translated = {}

        # Patch projection: keep the RGB filters, fill the rest with their mean.
        rgb_w = sd['vit.embeddings.patch_embeddings.projection.weight']
        n_rgb = rgb_w.shape[1]
        new_w = torch.zeros_like(target['patch_embed.proj.weight'])
        new_w[:, :n_rgb] = rgb_w
        new_w[:, n_rgb:] = rgb_w.mean(dim=1, keepdim=True)
        translated['patch_embed.proj.weight'] = new_w
        translated['patch_embed.proj.bias'] = sd['vit.embeddings.patch_embeddings.projection.bias']

        # Encoder layers.  Without these, the frozen blocks would stay at
        # their random initialisation.  nn.MultiheadAttention stores q, k, v
        # stacked along dim 0 in a single in_proj tensor.
        for i in range(len(self.backbone.blocks.layers)):
            src = f'vit.encoder.layer.{i}.'
            dst = f'blocks.layers.{i}.'
            for kind in ('weight', 'bias'):
                translated[dst + f'self_attn.in_proj_{kind}'] = torch.cat(
                    [sd[src + f'attention.attention.{name}.{kind}'] for name in ('query', 'key', 'value')],
                    dim=0,
                )
                translated[dst + f'self_attn.out_proj.{kind}'] = sd[src + f'attention.output.dense.{kind}']
                translated[dst + f'linear1.{kind}'] = sd[src + f'intermediate.dense.{kind}']
                translated[dst + f'linear2.{kind}'] = sd[src + f'output.dense.{kind}']
                translated[dst + f'norm1.{kind}'] = sd[src + f'layernorm_before.{kind}']
                translated[dst + f'norm2.{kind}'] = sd[src + f'layernorm_after.{kind}']
        for kind in ('weight', 'bias'):
            translated[f'norm.{kind}'] = sd[f'vit.layernorm.{kind}']

        # The backbone keeps an extra leading axis on these two tables.
        translated['pos_embed'] = sd['vit.embeddings.position_embeddings'].unsqueeze(0)
        translated['cls_token'] = sd['vit.embeddings.cls_token'].unsqueeze(0)

        mismatched = [
            key for key, value in translated.items()
            if key not in target or target[key].shape != value.shape
        ]
        if mismatched:
            raise ValueError(f"pretrained tensors do not fit the backbone: {mismatched[:5]}")
        # strict=False: time_embed has no counterpart in the checkpoint.
        self.backbone.load_state_dict(translated, strict=False)

    def forward(self, x):
        """Map a (B, C, T, H, W) clip to (B, 1, 16*h, 16*w) logits, h x w being the token grid."""
        # Drop the class token; only patch tokens carry spatial layout.
        features = self.backbone(x)[:, 1:, :]
        B, L, D = features.shape
        side = int(round((L // self.num_frames) ** 0.5))
        # (B, T*h*w, D) -> (B, D*T, h, w): frames become extra channels.
        features = features.reshape(B, self.num_frames, side, side, D).permute(0, 4, 1, 2, 3).flatten(1, 2)
        features = self.temporal_agg(features)
        return self.decoder(features)
In [8]:
# CELL 3: TRAINING
import torch
import torch.optim as optim
import torch.optim.swa_utils as swa_utils
from torch.utils.data import TensorDataset, DataLoader, random_split
from scipy.ndimage import distance_transform_edt as distance
import os
import numpy as np # Added numpy import just in case
# --- CONFIGURATION ---
# Directory (under the mounted Google Drive) that receives training artefacts.
SAVE_DIR = '/content/drive/MyDrive/SatMAE_LongTrain_500/'
# exist_ok avoids the check-then-create race and is a no-op on re-runs.
os.makedirs(SAVE_DIR, exist_ok=True)
# Defined once, after SAVE_DIR.  The file name (including its spelling) is
# kept as-is so checkpoints written by earlier runs are still found.
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint_hyper_partial_finetunning.pth")
RESUME = False
BATCH_SIZE = 4
EPOCHS = 500
# 1. TIME-SAFE AUGMENTATION
def apply_augmentation(x, y):
    """Apply one random flip/rotation jointly to an input stack and its mask.

    Only the last two axes are transformed: axes 3 and 4 of the 5-D ``x``
    and axes 2 and 3 of the 4-D ``y``. Axes 0-2 of ``x`` are left alone, so
    the ordering along axis 2 is preserved. One transform is drawn per call
    and shared by every sample in the batch.

    Args:
        x: 5-D input tensor.
        y: 4-D target tensor whose last two axes match those of ``x``.

    Returns:
        The transformed ``(x, y)`` pair.
    """
    # Mirror along the last axis.
    if np.random.rand() > 0.5:
        x, y = x.flip(4), y.flip(3)

    # Mirror along the second-to-last axis.
    if np.random.rand() > 0.5:
        x, y = x.flip(3), y.flip(2)

    # Rotate by a random multiple of 90 degrees in the plane of the last two axes.
    quarter_turns = np.random.randint(0, 4)
    x = torch.rot90(x, k=quarter_turns, dims=(3, 4))
    y = torch.rot90(y, k=quarter_turns, dims=(2, 3))
    return x, y
# 2. LOSS FUNCTIONS (UNCHANGED)
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid probabilities.

    Logits and targets are flattened across the whole batch before the
    overlap is measured, so one Dice score is produced per call.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        # Added to numerator and denominator to keep the ratio finite.
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return ``1 - Dice`` for raw logits ``inputs`` against ``targets``."""
        probs = torch.sigmoid(inputs).reshape(-1)
        truth = targets.reshape(-1)
        overlap = (probs * truth).sum()
        numerator = 2. * overlap + self.smooth
        denominator = probs.sum() + truth.sum() + self.smooth
        return 1 - numerator / denominator
class HausdorffDTLoss(nn.Module):
    """Squared error weighted by distance to the ground-truth boundary.

    For each sample a signed distance map is built from the target mask
    (distance outside the mask minus distance inside it). Each pixel's
    squared error is then scaled by ``1 + alpha * |distance|``, so mistakes
    far from the true boundary cost more than mistakes next to it.
    """

    def __init__(self, alpha=2.0):
        super().__init__()
        # Strength of the distance weighting; 0 reduces the loss to plain MSE.
        self.alpha = alpha

    def forward(self, pred, gt):
        """Return the distance-weighted MSE between ``sigmoid(pred)`` and ``gt``.

        Args:
            pred: Raw logits, indexed as ``[sample, 0, ...]``.
            gt: Target mask of the same shape; values above 0.5 are foreground.

        Returns:
            Scalar loss tensor; gradients flow through ``pred`` only.
        """
        probs = torch.sigmoid(pred)

        # The distance weights are constants with respect to the prediction.
        with torch.no_grad():
            masks = gt.detach().cpu().numpy()
            # Always float32: the signed distances are fractional and can be
            # negative, so inheriting an integer mask dtype would wrap or
            # truncate them.
            dist_map = np.zeros(masks.shape, dtype=np.float32)
            for idx, sample in enumerate(masks):
                foreground = (sample[0] > 0.5).astype(np.uint8)
                if not foreground.any():
                    # No object in this sample: leave the weight at 1 (plain MSE).
                    continue
                # NOTE(review): a mask with no background pixel gives the
                # transform no zero to measure to -- confirm SciPy's output
                # is acceptable for fully covered tiles.
                dist_map[idx, 0] = distance(1 - foreground) - distance(foreground)
            weights = 1 + self.alpha * torch.abs(
                torch.as_tensor(dist_map, dtype=torch.float32, device=pred.device)
            )

        # Integer/bool masks are cast so the subtraction is well defined.
        target = gt if gt.is_floating_point() else gt.to(probs.dtype)
        return torch.mean((probs - target) ** 2 * weights)
class CompoundLoss(nn.Module):
    """Training objective: 0.7 x Dice loss + 0.3 x distance-weighted boundary loss."""

    def __init__(self):
        super().__init__()
        self.dice = DiceLoss()
        self.boundary = HausdorffDTLoss(alpha=2.0)

    def forward(self, p, t):
        """Blend both terms for logits ``p`` against targets ``t``."""
        region_term = self.dice(p, t)
        boundary_term = self.boundary(p, t)
        return 0.7 * region_term + 0.3 * boundary_term
# 3. SETUP MODEL & OPTIMIZER
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f" Training on: {device}")
model = SatMAESegmentation(num_frames=6, embed_dim=768).to(device)
criterion = CompoundLoss()
# Hand the optimizer only the parameters that are still trainable.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=1e-4)
# Cosine cycles of 50, 100, 200, ... epochs, each restarting from the base LR.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)
# SWA Config
swa_model = swa_utils.AveragedModel(model)
swa_start = 350   # first (0-based) epoch whose weights enter the SWA average
swa_scheduler = swa_utils.SWALR(optimizer, swa_lr=5e-5)
# Loaders
# A fixed seed makes the train/validation partition identical on every run of
# this cell. Without it, re-running the cell (e.g. to resume) re-draws the
# split and moves former validation samples into the training set.
SPLIT_SEED = 42
ds = TensorDataset(X_data, y_data)
tr_sz = int(0.85 * len(ds))
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
t_ds, v_ds = random_split(ds, [tr_sz, len(ds)-tr_sz], generator=split_rng)
train_loader = DataLoader(t_ds, BATCH_SIZE, shuffle=True)
val_loader = DataLoader(v_ds, BATCH_SIZE, shuffle=False)
# 4. RESUME LOGIC
start_epoch = 0
history = {'train_loss': [], 'val_loss': []}
best_loss = float('inf')
if RESUME and os.path.exists(CHECKPOINT_PATH):
    print(" Found Checkpoint. Resuming...")
    # map_location keeps a GPU-written checkpoint loadable on a CPU-only runtime.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(ckpt['model_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    # Optional keys: checkpoints written without them still resume, with the
    # cosine schedule and the SWA average starting over.
    if 'scheduler_state_dict' in ckpt:
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    if 'swa_model_state_dict' in ckpt:
        swa_model.load_state_dict(ckpt['swa_model_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    history = ckpt['history']
    best_loss = ckpt['best_loss']
elif not RESUME and os.path.exists(CHECKPOINT_PATH):
    print(" Force Restart: Deleting old checkpoint...")
    os.remove(CHECKPOINT_PATH)
    print(" Starting Fresh from Epoch 0.")
else:
    print(" No checkpoint found. Starting Fresh.")

# 5. TRAINING LOOP
print(f" Starting Hyper-Stack Partial Fine-Tuning ({EPOCHS} Epochs)...")
for ep in range(start_epoch, EPOCHS):
    # --- Train ---
    model.train()
    train_loss = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        x, y = apply_augmentation(x, y)
        optimizer.zero_grad()
        preds = model(x)
        loss = criterion(preds, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # --- Validate (no augmentation, no gradients) ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x)
            val_loss += criterion(preds, y).item()

    avg_t = train_loss / len(train_loader)
    avg_v = val_loss / len(val_loader)
    history['train_loss'].append(avg_t)
    history['val_loss'].append(avg_v)

    # --- LR schedule: cosine restarts first, then the SWA phase ---
    in_swa_phase = ep >= swa_start
    if in_swa_phase:
        swa_model.update_parameters(model)
        swa_scheduler.step()
        lr_stat = f"SWA-LR: {swa_scheduler.get_last_lr()[0]:.1e}"
    else:
        scheduler.step()
        lr_stat = f"LR: {scheduler.get_last_lr()[0]:.1e}"

    # --- Keep the weights with the lowest validation loss seen so far ---
    if avg_v < best_loss:
        best_loss = avg_v
        torch.save(model.state_dict(), SAVE_DIR + "SatMAE_Hyper_Partial_Best.pth")
        print(f" Best Model Updated (Loss: {best_loss:.4f})")

    # --- Every 5 epochs: progress line + resumable checkpoint ---
    if (ep+1) % 5 == 0:
        print(f"Ep {ep+1} | T: {avg_t:.4f} | V: {avg_v:.4f} | {lr_stat}")
        ckpt_state = {
            'epoch': ep,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Saved so a resumed run continues the cosine schedule where it
            # stopped instead of restarting it from epoch 0.
            'scheduler_state_dict': scheduler.state_dict(),
            'history': history,
            'best_loss': best_loss
        }
        if in_swa_phase:
            # Only worth storing once averaging has begun; before that the
            # averaged copy has never been updated.
            ckpt_state['swa_model_state_dict'] = swa_model.state_dict()
        torch.save(ckpt_state, CHECKPOINT_PATH)
        print(" Checkpoint Saved.")

print(" Finalizing SWA Model...")
# The averaged weights have never run a training-mode forward pass, so the
# BatchNorm statistics are recomputed over the training data before saving.
swa_utils.update_bn(train_loader, swa_model, device=device)
torch.save(swa_model.state_dict(), SAVE_DIR + "SatMAE_Hyper_Partial_SWA.pth")
print("Done.")
Training on: cuda Constructing SatMAE Hyper-Stack (6 Time Steps)... Adapting Google ViT Weights to 10 Channels... Weights Adapted. Applying Partial Fine-Tuning Strategy... Config: Blocks 1-10 Frozen | Blocks 11-12 Unfrozen No checkpoint found. Starting Fresh. Starting Hyper-Stack Partial Fine-Tuning (500 Epochs)... Best Model Updated (Loss: 1.1799) Best Model Updated (Loss: 1.1194) Best Model Updated (Loss: 0.9007) Best Model Updated (Loss: 0.7571) Best Model Updated (Loss: 0.6437) Ep 5 | T: 0.5609 | V: 0.6437 | LR: 9.8e-05 Checkpoint Saved. Best Model Updated (Loss: 0.6353) Best Model Updated (Loss: 0.5990) Best Model Updated (Loss: 0.5318) Best Model Updated (Loss: 0.4144) Ep 10 | T: 0.4611 | V: 0.4144 | LR: 9.1e-05 Checkpoint Saved. Best Model Updated (Loss: 0.3289) Ep 15 | T: 0.4386 | V: 0.4985 | LR: 8.0e-05 Checkpoint Saved. Ep 20 | T: 0.4073 | V: 0.4786 | LR: 6.6e-05 Checkpoint Saved. Ep 25 | T: 0.3988 | V: 0.3540 | LR: 5.1e-05 Checkpoint Saved. Ep 30 | T: 0.4121 | V: 0.5108 | LR: 3.5e-05 Checkpoint Saved. Ep 35 | T: 0.3939 | V: 0.3788 | LR: 2.1e-05 Checkpoint Saved. Ep 40 | T: 0.3789 | V: 0.3779 | LR: 1.0e-05 Checkpoint Saved. Ep 45 | T: 0.4246 | V: 0.3868 | LR: 3.4e-06 Checkpoint Saved. Ep 50 | T: 0.3889 | V: 0.3689 | LR: 1.0e-04 Checkpoint Saved. Best Model Updated (Loss: 0.3285) Ep 55 | T: 0.3689 | V: 0.3285 | LR: 9.9e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2839) Ep 60 | T: 0.3751 | V: 0.4794 | LR: 9.8e-05 Checkpoint Saved. Ep 65 | T: 0.3558 | V: 0.3094 | LR: 9.5e-05 Checkpoint Saved. Ep 70 | T: 0.3452 | V: 0.2978 | LR: 9.1e-05 Checkpoint Saved. Ep 75 | T: 0.3341 | V: 0.3550 | LR: 8.6e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2738) Ep 80 | T: 0.3221 | V: 0.3527 | LR: 8.0e-05 Checkpoint Saved. Ep 85 | T: 0.3124 | V: 0.2834 | LR: 7.3e-05 Checkpoint Saved. Ep 90 | T: 0.3013 | V: 0.3831 | LR: 6.6e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2709) Ep 95 | T: 0.3360 | V: 0.3079 | LR: 5.8e-05 Checkpoint Saved. 
Ep 100 | T: 0.2965 | V: 0.2945 | LR: 5.1e-05 Checkpoint Saved. Ep 105 | T: 0.2864 | V: 0.2798 | LR: 4.3e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2650) Ep 110 | T: 0.2940 | V: 0.3116 | LR: 3.5e-05 Checkpoint Saved. Ep 115 | T: 0.2789 | V: 0.3129 | LR: 2.8e-05 Checkpoint Saved. Ep 120 | T: 0.2812 | V: 0.2790 | LR: 2.1e-05 Checkpoint Saved. Ep 125 | T: 0.2992 | V: 0.2826 | LR: 1.5e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2634) Ep 130 | T: 0.2686 | V: 0.2634 | LR: 1.0e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2624) Best Model Updated (Loss: 0.2602) Best Model Updated (Loss: 0.2561) Ep 135 | T: 0.2757 | V: 0.2637 | LR: 6.4e-06 Checkpoint Saved. Ep 140 | T: 0.2695 | V: 0.2823 | LR: 3.4e-06 Checkpoint Saved. Ep 145 | T: 0.2683 | V: 0.2778 | LR: 1.6e-06 Checkpoint Saved. Ep 150 | T: 0.2710 | V: 0.2717 | LR: 1.0e-04 Checkpoint Saved. Ep 155 | T: 0.2708 | V: 0.2663 | LR: 1.0e-04 Checkpoint Saved. Best Model Updated (Loss: 0.2512) Ep 160 | T: 0.2679 | V: 0.3821 | LR: 9.9e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2484) Best Model Updated (Loss: 0.2343) Ep 165 | T: 0.2648 | V: 0.2846 | LR: 9.9e-05 Checkpoint Saved. Ep 170 | T: 0.2483 | V: 0.2678 | LR: 9.8e-05 Checkpoint Saved. Ep 175 | T: 0.2403 | V: 0.2849 | LR: 9.6e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2255) Ep 180 | T: 0.2433 | V: 0.2816 | LR: 9.5e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2238) Ep 185 | T: 0.2507 | V: 0.2372 | LR: 9.3e-05 Checkpoint Saved. Ep 190 | T: 0.2154 | V: 0.2418 | LR: 9.1e-05 Checkpoint Saved. Ep 195 | T: 0.2233 | V: 0.2375 | LR: 8.8e-05 Checkpoint Saved. Ep 200 | T: 0.2094 | V: 0.2827 | LR: 8.6e-05 Checkpoint Saved. Ep 205 | T: 0.2244 | V: 0.2774 | LR: 8.3e-05 Checkpoint Saved. Ep 210 | T: 0.2076 | V: 0.2498 | LR: 8.0e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2224) Best Model Updated (Loss: 0.2199) Ep 215 | T: 0.2045 | V: 0.2361 | LR: 7.6e-05 Checkpoint Saved. 
Best Model Updated (Loss: 0.2071) Ep 220 | T: 0.1906 | V: 0.2254 | LR: 7.3e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2070) Best Model Updated (Loss: 0.2024) Ep 225 | T: 0.1965 | V: 0.2024 | LR: 6.9e-05 Checkpoint Saved. Ep 230 | T: 0.2059 | V: 0.2101 | LR: 6.6e-05 Checkpoint Saved. Ep 235 | T: 0.1799 | V: 0.2168 | LR: 6.2e-05 Checkpoint Saved. Best Model Updated (Loss: 0.2003) Best Model Updated (Loss: 0.1961) Ep 240 | T: 0.1787 | V: 0.2143 | LR: 5.8e-05 Checkpoint Saved. Ep 245 | T: 0.1864 | V: 0.2065 | LR: 5.4e-05 Checkpoint Saved. Ep 250 | T: 0.1760 | V: 0.2076 | LR: 5.1e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1951) Ep 255 | T: 0.1771 | V: 0.2110 | LR: 4.7e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1948) Ep 260 | T: 0.1735 | V: 0.1948 | LR: 4.3e-05 Checkpoint Saved. Ep 265 | T: 0.1696 | V: 0.2056 | LR: 3.9e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1948) Ep 270 | T: 0.1613 | V: 0.1985 | LR: 3.5e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1903) Ep 275 | T: 0.1676 | V: 0.2012 | LR: 3.2e-05 Checkpoint Saved. Ep 280 | T: 0.1616 | V: 0.1928 | LR: 2.8e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1899) Ep 285 | T: 0.1603 | V: 0.1946 | LR: 2.5e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1884) Ep 290 | T: 0.1610 | V: 0.1975 | LR: 2.1e-05 Checkpoint Saved. Ep 295 | T: 0.1626 | V: 0.1953 | LR: 1.8e-05 Checkpoint Saved. Ep 300 | T: 0.1598 | V: 0.1905 | LR: 1.5e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1868) Ep 305 | T: 0.1563 | V: 0.1946 | LR: 1.3e-05 Checkpoint Saved. Ep 310 | T: 0.1552 | V: 0.1893 | LR: 1.0e-05 Checkpoint Saved. Best Model Updated (Loss: 0.1857) Ep 315 | T: 0.1576 | V: 0.1872 | LR: 8.3e-06 Checkpoint Saved. Ep 320 | T: 0.1590 | V: 0.1927 | LR: 6.4e-06 Checkpoint Saved. Best Model Updated (Loss: 0.1843) Ep 325 | T: 0.1536 | V: 0.1891 | LR: 4.8e-06 Checkpoint Saved. Ep 330 | T: 0.1527 | V: 0.1940 | LR: 3.4e-06 Checkpoint Saved. Ep 335 | T: 0.1521 | V: 0.1949 | LR: 2.4e-06 Checkpoint Saved. 
Ep 340 | T: 0.1523 | V: 0.1914 | LR: 1.6e-06 Checkpoint Saved. Ep 345 | T: 0.1546 | V: 0.1904 | LR: 1.2e-06 Checkpoint Saved. Ep 350 | T: 0.1568 | V: 0.1887 | LR: 1.0e-04 Checkpoint Saved. Ep 355 | T: 0.1702 | V: 0.2530 | SWA-LR: 7.5e-05 Checkpoint Saved. Ep 360 | T: 0.1566 | V: 0.2320 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 365 | T: 0.1516 | V: 0.2116 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 370 | T: 0.1604 | V: 0.1910 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 375 | T: 0.1468 | V: 0.1853 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 380 | T: 0.1480 | V: 0.1861 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 385 | T: 0.1445 | V: 0.1832 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 390 | T: 0.1426 | V: 0.1835 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 395 | T: 0.1486 | V: 0.1837 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 400 | T: 0.1393 | V: 0.1866 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 405 | T: 0.1360 | V: 0.1813 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 410 | T: 0.1357 | V: 0.1883 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 415 | T: 0.1336 | V: 0.1802 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 420 | T: 0.1292 | V: 0.1731 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 425 | T: 0.1319 | V: 0.1898 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 430 | T: 0.1339 | V: 0.1920 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 435 | T: 0.1263 | V: 0.1877 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 440 | T: 0.1288 | V: 0.1824 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 445 | T: 0.1235 | V: 0.1784 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 450 | T: 0.1223 | V: 0.1769 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 455 | T: 0.1209 | V: 0.1756 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 460 | T: 0.1194 | V: 0.1864 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 465 | T: 0.1168 | V: 0.1765 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 470 | T: 0.1186 | V: 0.1885 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 475 | T: 0.1138 | V: 0.1765 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 480 | T: 0.1149 | V: 0.1722 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 485 | T: 0.1144 | V: 0.1697 | SWA-LR: 5.0e-05 Checkpoint Saved. 
Ep 490 | T: 0.1138 | V: 0.1709 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 495 | T: 0.1125 | V: 0.1669 | SWA-LR: 5.0e-05 Checkpoint Saved. Ep 500 | T: 0.1106 | V: 0.1775 | SWA-LR: 5.0e-05 Checkpoint Saved. Finalizing SWA Model... Done.
In [9]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
# --- CONFIGURATION ---
# Periodic training checkpoint; the loss history plotted below is read from it.
_CKPT_FILENAME = "checkpoint_hyper_partial_finetunning.pth"
CHECKPOINT_PATH = SAVE_DIR + _CKPT_FILENAME
# 1. FUNCTION TO PLOT LOSS CURVES
def plot_training_curves(checkpoint_path):
    """Plot the per-epoch training and validation loss stored in a checkpoint.

    Prints a message and returns without plotting when the file is missing
    or the checkpoint carries no ``'history'`` entry.

    Args:
        checkpoint_path: Path to a checkpoint whose ``'history'`` entry holds
            ``'train_loss'`` and ``'val_loss'`` sequences.
    """
    print(f" Loading History from: {checkpoint_path}")
    if not os.path.exists(checkpoint_path):
        print(" Checkpoint not found. Cannot plot history.")
        return

    # Loaded on CPU: only the loss lists are needed, not the weights.
    history = torch.load(checkpoint_path, map_location='cpu').get('history', None)
    if history is None:
        print(" No history found in checkpoint.")
        return

    train_curve = history['train_loss']
    val_curve = history['val_loss']
    epoch_axis = range(1, len(train_curve) + 1)   # 1-based epoch numbers

    plt.figure(figsize=(10, 5))
    for series, line_style, legend_label in (
        (train_curve, 'b-', 'Training Loss'),
        (val_curve, 'r-', 'Validation Loss'),
    ):
        plt.plot(epoch_axis, series, line_style, label=legend_label)
    plt.title('Training vs. Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss (Compound)')
    plt.legend()
    plt.grid(True)
    plt.show()
# 2. FUNCTION TO VISUALIZE SAMPLES (Input vs Truth vs Pred)
def visualize_predictions(model, loader, device, num_samples=3):
    """Show input, ground truth and thresholded prediction side by side.

    Takes the first batch from ``loader``, runs ``model`` on it and draws one
    row per sample: a false-colour composite of the input, the target mask,
    and the prediction thresholded at probability 0.5.

    Args:
        model: Segmentation model returning logits indexed ``[sample, 0, ...]``.
        loader: Iterable yielding ``(x, y)`` batches; only the first is used.
        device: Device the batch is moved to before inference.
        num_samples: Maximum number of rows to draw. It is capped at the
            batch size, so a short batch no longer raises ``IndexError``.
    """
    print(" Generating Sample Predictions...")
    model.eval()

    # Get one batch
    x, y = next(iter(loader))
    x, y = x.to(device), y.to(device)

    # Run Inference
    with torch.no_grad():
        probs = torch.sigmoid(model(x))
    pred_masks = (probs > 0.5).float().cpu().numpy()

    # x is indexed below as [sample, channel, time, row, col].
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()

    n_rows = min(num_samples, x_np.shape[0])
    if n_rows < 1:
        return

    # squeeze=False keeps `axs` two-dimensional even for a single row, so the
    # axs[row, col] indexing below is always valid.
    _, axs = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)
    plt.suptitle("Model Predictions: Input vs. Ground Truth vs. Prediction", fontsize=16)

    # Time step shown in the composite: index 3, or the last one if the stack
    # is shorter. Adjust if a different acquisition is more informative.
    t_idx = min(3, x_np.shape[2] - 1)

    for i in range(n_rows):
        # A. False Color Composite (NIR, Red, Green)
        # Assumes the channel order [B2, B3, B4, B8, ...], i.e. 1=Green,
        # 2=Red, 3=NIR -- verify against the actual band stacking.
        nir = x_np[i, 3, t_idx, :, :]
        red = x_np[i, 2, t_idx, :, :]
        grn = x_np[i, 1, t_idx, :, :]
        img = np.stack([nir, red, grn], axis=2)
        # Min-max stretch to [0, 1]; the epsilon guards a constant image.
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)

        panels = (
            (img, None, "Satellite Input (False Color)"),
            (y_np[i, 0, :, :], 'gray', "Ground Truth (Target)"),
            (pred_masks[i, 0, :, :], 'gray', "Model Prediction"),
        )
        for col, (image, cmap, title) in enumerate(panels):
            axs[i, col].imshow(image, cmap=cmap)
            axs[i, col].set_title(title)
            axs[i, col].axis('off')

    plt.tight_layout()
    plt.show()
# --- EXECUTE BOTH ---
# 1. Loss curves recorded in the periodic checkpoint
plot_training_curves(checkpoint_path=CHECKPOINT_PATH)
# 2. Sample predictions: rebuild the architecture, then restore the
#    checkpointed weights into it
vis_model = SatMAESegmentation(num_frames=6, embed_dim=768)
vis_model = vis_model.to(device)
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
vis_model.load_state_dict(ckpt['model_state_dict'])
visualize_predictions(model=vis_model, loader=val_loader, device=device, num_samples=3)
Loading History from: /content/drive/MyDrive/SatMAE_LongTrain_500/checkpoint_hyper_partial_finetunning.pth
Constructing SatMAE Hyper-Stack (6 Time Steps)... Adapting Google ViT Weights to 10 Channels...
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Weights Adapted. Applying Partial Fine-Tuning Strategy... Config: Blocks 1-10 Frozen | Blocks 11-12 Unfrozen Generating Sample Predictions...
In [12]:
import time
import torch
import numpy as np
from scipy.spatial.distance import directed_hausdorff
from scipy.ndimage import distance_transform_edt
from sklearn.metrics import confusion_matrix
# --- CONFIGURATION ---
MODEL_TO_EVAL = "Best" # Options: "Best" or "SWA"
# Any value other than "SWA" selects the best-validation weights.
_weights_file = (
    "SatMAE_Hyper_Partial_SWA.pth"
    if MODEL_TO_EVAL == "SWA"
    else "SatMAE_Hyper_Partial_Best.pth"
)
LOAD_PATH = SAVE_DIR + _weights_file
# --- HELPER FUNCTIONS ---
def compute_boundary_iou(pred_mask, gt_mask, dilation=2):
    """Intersection-over-union of the boundary bands of two binary masks.

    A mask's boundary band is the set of pixels gained by dilating it
    ``dilation`` times, i.e. a ring hugging the outside of each region.

    Args:
        pred_mask: Predicted binary mask.
        gt_mask: Reference binary mask of the same shape.
        dilation: Number of dilation iterations; sets the band width.

    Returns:
        IoU of the two bands, or 1.0 when both bands are empty.
    """
    from scipy.ndimage import binary_dilation

    def _outer_band(mask):
        # Dilated mask XOR mask leaves only the newly added ring.
        return np.bitwise_xor(binary_dilation(mask, iterations=dilation), mask)

    pred_edges = _outer_band(pred_mask)
    gt_edges = _outer_band(gt_mask)

    union = np.bitwise_or(pred_edges, gt_edges).sum()
    if union == 0:
        return 1.0  # Perfect match (both empty)
    intersection = np.bitwise_and(pred_edges, gt_edges).sum()
    return intersection / union
def compute_hausdorff_distance(pred_mask, gt_mask, percentile=None):
    """Symmetric Hausdorff distance, in pixels, between two binary masks.

    By default this is the classic *maximum* Hausdorff distance: the larger
    of the two directed distances. It is not a 95th-percentile variant, and
    a single stray pixel can dominate it. Pass ``percentile=95`` for the
    outlier-robust HD95 instead.

    Args:
        pred_mask: Predicted binary mask.
        gt_mask: Reference binary mask.
        percentile: ``None`` (default) for the maximum distance, or a value
            in ``[0, 100]`` to take that percentile of the nearest-neighbour
            distances in each direction.

    Returns:
        The distance as a float. Returns 0.0 when either mask is empty, since
        the distance is undefined there; callers that want to penalise an
        empty prediction must check for it themselves.
    """
    if pred_mask.sum() == 0 or gt_mask.sum() == 0:
        return 0.0  # Edge case handling

    # (row, col) coordinates of every foreground pixel.
    pred_coords = np.argwhere(pred_mask)
    gt_coords = np.argwhere(gt_mask)

    if percentile is None:
        d_forward = directed_hausdorff(pred_coords, gt_coords)[0]
        d_backward = directed_hausdorff(gt_coords, pred_coords)[0]
        return max(d_forward, d_backward)

    # Percentile variant: distance from every pixel of one mask to its
    # nearest pixel in the other, reduced by percentile rather than max.
    from scipy.spatial import cKDTree

    pred_to_gt = cKDTree(gt_coords).query(pred_coords)[0]
    gt_to_pred = cKDTree(pred_coords).query(gt_coords)[0]
    return float(max(np.percentile(pred_to_gt, percentile),
                     np.percentile(gt_to_pred, percentile)))
# --- EVALUATION LOOP ---
def _extract_state_dict(checkpoint):
    """Return a plain model state dict from any weights file this notebook writes."""
    state = checkpoint
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state = checkpoint['model_state_dict']
    # An AveragedModel (SWA) state dict carries an `n_averaged` counter and
    # prefixes every wrapped parameter with `module.`; strip both so the
    # weights fit the bare segmentation model.
    if 'n_averaged' in state:
        prefix = 'module.'
        state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
            if key != 'n_averaged'
        }
    return state


def evaluate_model_depth(loader, model_path):
    """Evaluate saved weights on ``loader`` and print/return a metrics report.

    Computes pixel-wise metrics (accuracy, IoU, precision, recall, F1) from a
    confusion matrix pooled over all batches, per-image shape metrics
    (boundary IoU and Hausdorff distance, the latter only for images where
    both masks are non-empty), and inference throughput.

    Args:
        loader: Iterable of ``(x, y)`` batches; ``y`` is indexed ``[b, 0]``.
        model_path: Weights file. May be a bare state dict, a training
            checkpoint with a ``'model_state_dict'`` entry, or the SWA
            (``AveragedModel``) state dict.

    Returns:
        Dict with the rounded metrics and the raw confusion-matrix counts.
        Shape or speed metrics with no contributing samples are NaN.
    """
    import json

    print(f" Starting Deep Evaluation on: {model_path}")

    # 1. Load Model
    model = SatMAESegmentation(num_frames=6, embed_dim=768).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(_extract_state_dict(checkpoint))
    model.eval()

    # 2. Metrics Accumulators
    tp_total, fp_total, fn_total, tn_total = 0, 0, 0, 0
    boundary_ious = []
    hausdorff_dists = []
    inference_times = []
    use_cuda = torch.cuda.is_available()

    print(" Processing Validation Batches...", end="")
    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            x = x.to(device)
            y_true = y.cpu().numpy()  # indexed below as [b, 0, ...]

            # Speed test. CUDA work is asynchronous, so synchronise on both
            # sides of the forward pass: otherwise the pending host-to-device
            # copy is billed to the model, or the kernels are not finished
            # when the clock stops.
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            preds = model(x)
            if use_cuda:
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time
            inference_times.append(elapsed / x.shape[0])  # Time per image

            # Convert to Binary Masks
            probs = torch.sigmoid(preds)
            y_pred = (probs > 0.5).float().cpu().numpy()

            # --- GLOBAL METRICS (Pixel-wise) ---
            y_true_flat = y_true.flatten().astype(int)
            y_pred_flat = y_pred.flatten().astype(int)
            # labels=[0, 1] forces a 2x2 matrix even if a batch has one class.
            tn, fp, fn, tp = confusion_matrix(y_true_flat, y_pred_flat, labels=[0, 1]).ravel()
            tp_total += tp; fp_total += fp; fn_total += fn; tn_total += tn

            # --- SHAPE METRICS (Image-wise) ---
            for b in range(x.shape[0]):
                p_m = y_pred[b, 0].astype(bool)
                g_m = y_true[b, 0].astype(bool)
                boundary_ious.append(compute_boundary_iou(p_m, g_m))
                # Hausdorff is undefined when either mask is empty; such
                # images are left out of the average rather than scored 0.
                if p_m.any() and g_m.any():
                    hausdorff_dists.append(compute_hausdorff_distance(p_m, g_m))

            if i % 10 == 0: print(".", end="")
    print("\n Evaluation Complete.")

    # 3. CALCULATE FINAL SCORES
    epsilon = 1e-6  # keeps the ratios finite when a denominator is zero

    # Pixel-Based
    pixel_acc = (tp_total + tn_total) / (tp_total + tn_total + fp_total + fn_total + epsilon)
    iou = tp_total / (tp_total + fp_total + fn_total + epsilon)
    precision = tp_total / (tp_total + fp_total + epsilon)
    recall = tp_total / (tp_total + fn_total + epsilon)
    f1 = 2 * (precision * recall) / (precision + recall + epsilon)

    # Shape-Based (NaN, not a warning or crash, when nothing was collected)
    nan = float('nan')
    avg_boundary_iou = float(np.mean(boundary_ious)) if boundary_ious else nan
    avg_hausdorff = float(np.mean(hausdorff_dists)) if hausdorff_dists else nan

    # Performance
    avg_time_per_img = float(np.mean(inference_times)) if inference_times else nan
    fps = 1.0 / avg_time_per_img if avg_time_per_img > 0 else nan

    # 4. PRINT REPORT
    print("\n" + "="*40)
    print(f" FINAL METRICS REPORT: {MODEL_TO_EVAL} MODEL")
    print("="*40)
    results = {
        "Model": f"SatMAE_Partial_{MODEL_TO_EVAL}",
        "Pixel_Accuracy": round(pixel_acc, 4),
        "IoU_Score": round(iou, 4),
        "F1_Score": round(f1, 4),
        "Precision": round(precision, 4),
        "Recall": round(recall, 4),
        "Boundary_IoU": round(avg_boundary_iou, 4),
        "Hausdorff_Dist_px": round(avg_hausdorff, 2),
        "FPS": round(fps, 2),
        "Confusion_Matrix": {
            "TP": int(tp_total), "FP": int(fp_total),
            "FN": int(fn_total), "TN": int(tn_total)
        }
    }
    print(json.dumps(results, indent=4))
    return results
# --- RUN EVALUATION ---
# Score the weights chosen via MODEL_TO_EVAL on the validation split.
final_metrics = evaluate_model_depth(loader=val_loader, model_path=LOAD_PATH)
Starting Deep Evaluation on: /content/drive/MyDrive/SatMAE_LongTrain_500/SatMAE_Hyper_Partial_Best.pth Constructing SatMAE Hyper-Stack (6 Time Steps)...
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True warnings.warn(
Adapting Google ViT Weights to 10 Channels...
Weights Adapted.
Applying Partial Fine-Tuning Strategy...
Config: Blocks 1-10 Frozen | Blocks 11-12 Unfrozen
Processing Validation Batches....
Evaluation Complete.
========================================
FINAL METRICS REPORT: Best MODEL
========================================
{
"Model": "SatMAE_Partial_Best",
"Pixel_Accuracy": 0.9102,
"IoU_Score": 0.8625,
"F1_Score": 0.9262,
"Precision": 0.9264,
"Recall": 0.9259,
"Boundary_IoU": 0.2948,
"Hausdorff_Dist_px": 17.63,
"FPS": 10.82,
"Confusion_Matrix": {
"TP": 113079,
"FP": 8979,
"FN": 9051,
"TN": 69595
}
}